Char8-Keras高层接口
第八章中讲解的是高层接口Keras的使用。Keras的几个特点
Python
语言开发- 前后端分离
- 后端基于现有的
TF、CNTK
等框架 - 前端有自己的接口
API
- 后端基于现有的
TF
的高层唯一API
接口Keras
被实现在tf.keras
子模块中
常见功能模块
Keras提供常见的神经网络类和函数
- 数据集加载函数
- 网络层类
- 模型容器
- 损失函数
- 优化器类
- 经典模型
常见网络层
- 张量方式
tf.nn
模块中 - 层方式
tf.keras.layers
提供大量的接口,需要完成__call__()
- 全连接层
- 激活含水层
- 池化层
- 卷积层
1 | import tensorflow as tf |
网络容器
主要使用的Sequential
类
2层全连接层加上激活函数层通过Sequntial
容器构成一个网络
1 | import tensorflow as tf |
模型装配、训练和测试
装配
通过两个主要的类实现:
-
keras.Model,网络的母类,
Sequentail
类是其子类 -
keras.layers.Layer,网络层的母类
通过compile()
函数指定优化器、损失函数等
1 | # 创建全连接层网络 |
训练
通过fit()
函数实现
train_db
为tf.data_Dataset
对象- epoch:训练5个epoch,每2个epoch验证一次
1 | history = network.fit(train_db, epoch=5, validation=val_db,validation_freq=2) |
测试
1 | x, y = next(iter(db_test)) |
模型加载
张量方式
文件中保存的仅仅是参数张量的数值,没有其他的结构参数,需要使用相同的网络结构才能恢复网络数据,一般在拥有源文件的情况下使用。
1 | network.save_weights('weights.ckpt') # 保存模型到参数文件上 |
网络方式
- 不需要网络源文件
- 仅仅是需要模型参数文件就可以恢复网络模型
- 通过Model.save()
1 | network.save('model.h5') |
SaveModel方式
通过 tf.keras.experimental.export_saved_model(network, path)
即可将模型以 SavedModel
方式保存到 path
目录中:
1 | tf.keras.experimental.export_saved_model(network, 'model-savedmodel') # 保存模型结构与参数 |
自定义类
自定义网络类
需要实现call()
和__init__()
方法
1 | # 初始化工作 |
自定义网络
1 | network = Squential([MyDense(784, 256), # 使用自定义的网络类MyDense |
- 通过堆叠使用自定义的网络类
- 5层全连接没有偏置张量,同时使用激活啊函数ReLU
使用基类实现
可以继承基类来实现任意逻辑的自定义网络类
1 | class MyModel(keras.Model): |
测量工具
- 新建测量器
loss_meter = metrics.Mean()
- 写入数据
loss_meter.update_state(float(loss))
- 读取统计信息
loss_meter.result()
- 清除历史状态的信息
loss_meter.reset_states()
可视化
- TensorBoard
- Visdom
模型端
需要写入监控数据、图片数据、查看数据的直方分布图、文本信息。
1 | # 监控标量数据 |
浏览器端
通过tensorboard --logdir path
来指定web后端监控的文件目录,浏览器端口默认是6006
1 | # 查看张量的数据脂肪分布图和打印文本信息 |