Building a Neural Network with a Fixed Template
Using the TensorFlow API: tf.keras
Building a neural network in six steps
Step 1: import the relevant modules, e.g. import tensorflow as tf.
```python
import tensorflow as tf
```
Step 2: specify the training and test sets fed into the network, i.e. the training inputs x_train and labels y_train, and the test inputs x_test and labels y_test.
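```python
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()  # NumPy arrays: images and integer labels
```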
Step 3: build the network structure layer by layer with model = tf.keras.models.Sequential().
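```python
model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),  # layers as in the summary below; activations assumed
                                    tf.keras.layers.Dense(128, activation='relu'), tf.keras.layers.Dense(10, activation='softmax')])
```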
Step 4: configure the training method in model.compile(), choosing the optimizer, the loss function, and the evaluation metric used during training.
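```python
model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])  # metric name matches the training log below
```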
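Step 4 names the loss and the metric but does not show what they compute. The following minimal NumPy sketch computes sparse categorical cross-entropy and sparse categorical accuracy for integer labels such as MNIST's y_train. It assumes the network ends in a softmax, so its outputs are already probabilities. The function names are illustrative, and Keras's built-in versions differ in details such as clipping and averaging.

```python
import numpy as np


def sparse_categorical_crossentropy(y_true, y_prob, eps=1e-7):
    """Mean cross-entropy when labels are integer class indices (not one-hot)."""
    y_prob = np.clip(y_prob, eps, 1.0 - eps)           # avoid log(0)
    picked = y_prob[np.arange(len(y_true)), y_true]    # probability given to the true class
    return float(-np.mean(np.log(picked)))


def sparse_categorical_accuracy(y_true, y_prob):
    """Fraction of samples whose highest-probability class equals the integer label."""
    return float(np.mean(np.argmax(y_prob, axis=1) == y_true))


y_true = np.array([0, 2])                  # integer labels, like y_train from MNIST
y_prob = np.array([[0.7, 0.2, 0.1],        # softmax outputs for two samples, 3 classes
                   [0.1, 0.8, 0.1]])
print(sparse_categorical_crossentropy(y_true, y_prob))
print(sparse_categorical_accuracy(y_true, y_prob))
```

The "sparse" in both names means the labels stay as integers such as 0 to 9, so there is no need to one-hot encode y_train first.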
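Resuming training from a checkpoint: loading the model
Define the path and file name where the model is saved, using a .ckpt file name.
When the ckpt file is written, an .index file is written alongside it, so checking whether that index file exists tells us whether saved model parameters exist.
If the index file exists, load the model parameters directly from the ckpt file.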
```python
checkpoint_save_path = "./checkpoint/mnist.ckpt"  # path and file name of the saved weights
```
--------------load the model---------------
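The index-file check described above can be sketched without TensorFlow. This is a minimal sketch assuming the TensorFlow checkpoint format used by TF 2.x tf.keras, which writes mnist.ckpt.index plus data files but no file named exactly mnist.ckpt. The helper has_checkpoint is hypothetical. In the training script, the True branch would print the "load the model" message and call model.load_weights(checkpoint_save_path). Newer Keras 3 releases expect weight files ending in .weights.h5 instead.

```python
import os
import tempfile


def has_checkpoint(checkpoint_save_path: str) -> bool:
    """Hypothetical helper: True if the checkpoint's .index file exists next to the prefix."""
    return os.path.exists(checkpoint_save_path + ".index")


# Demonstrate in a throwaway directory instead of ./checkpoint
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "mnist.ckpt")
    before = has_checkpoint(path)          # nothing saved yet: train from scratch
    open(path + ".index", "w").close()     # simulate the .index file written with the ckpt
    after = has_checkpoint(path)           # index present: resume from saved weights
    print(before, after)
```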
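Saving the model
Step 5: run the training in model.fit(), passing the inputs and labels of the training and test sets, the batch size (batch_size), and the number of passes over the dataset (epochs).
Step 6: use model.summary() to print the network structure and count the parameters.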
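```python
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,  # save to the path defined above
                                                 save_weights_only=True, save_best_only=True)  # assumed settings
```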
```text
Epoch 1/5
1875/1875 [==============================] - 2s 886us/step - loss: 0.0118 - sparse_categorical_accuracy: 0.9962 - val_loss: 0.0801 - val_sparse_categorical_accuracy: 0.9797
Epoch 2/5
1875/1875 [==============================] - 1s 667us/step - loss: 0.0101 - sparse_categorical_accuracy: 0.9967 - val_loss: 0.0899 - val_sparse_categorical_accuracy: 0.9798
Epoch 3/5
1875/1875 [==============================] - 1s 695us/step - loss: 0.0086 - sparse_categorical_accuracy: 0.9974 - val_loss: 0.0879 - val_sparse_categorical_accuracy: 0.9807
Epoch 4/5
1875/1875 [==============================] - 1s 694us/step - loss: 0.0095 - sparse_categorical_accuracy: 0.9970 - val_loss: 0.0979 - val_sparse_categorical_accuracy: 0.9782
Epoch 5/5
1875/1875 [==============================] - 1s 687us/step - loss: 0.0060 - sparse_categorical_accuracy: 0.9981 - val_loss: 0.0997 - val_sparse_categorical_accuracy: 0.9800
```
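The 1875/1875 counter in the log counts batches, not images. Assume the 60,000-image MNIST training set and the model.fit() default batch size of 32 (the fit() call itself is not shown here). Under those assumptions the number of batches per epoch works out exactly:

```python
import math

num_train = 60000   # images in the MNIST training set
batch_size = 32     # Keras model.fit() default when batch_size is not given
steps_per_epoch = math.ceil(num_train / batch_size)   # a final partial batch would still count as one step
print(steps_per_epoch)
```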
```text
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
flatten (Flatten)            (None, 784)               0
_________________________________________________________________
dense (Dense)                (None, 128)               100480
_________________________________________________________________
dense_1 (Dense)              (None, 10)                1290
=================================================================
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________
```
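The Param # column follows from the layer shapes. A Dense layer stores an input-by-output kernel plus one bias per output unit, and Flatten has no weights. This small Python check reproduces the totals in the summary; dense_params is a hypothetical helper name.

```python
def dense_params(n_in: int, n_out: int) -> int:
    """A Dense layer has an (n_in, n_out) kernel plus an n_out bias vector."""
    return n_in * n_out + n_out


flatten = 0                          # Flatten only reshapes 28x28 images to 784 values
dense = dense_params(28 * 28, 128)   # 784 * 128 + 128
dense_1 = dense_params(128, 10)      # 128 * 10 + 10
total = flatten + dense + dense_1
print(dense, dense_1, total)         # 100480 1290 101770
```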
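Setting the print format
np.set_printoptions(threshold=N): arrays with more than N elements are printed in abbreviated form, with ... in place of the middle values.
np.set_printoptions(threshold=np.inf)  # np.inf means infinity, so arrays are never abbreviated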
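You can see the effect of threshold without a model by printing a NumPy array larger than the default threshold of 1000 elements. This sketch uses the np.printoptions context manager, which applies the setting only inside the with block, rather than changing it globally as np.set_printoptions does.

```python
import numpy as np

a = np.arange(2000)   # more elements than the default threshold of 1000

summarized = "..." in np.array2string(a)   # default options: the middle is elided
with np.printoptions(threshold=np.inf):    # temporary setting, restored after the block
    full = np.array2string(a)
print(summarized, "..." in full)
```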
```python
import numpy as np
# np.set_printoptions(threshold=np.inf)  # uncomment to print every element instead of an abbreviated view
print(model.trainable_variables)
```
```text
[<tf.Variable 'dense/kernel:0' shape=(784, 128) dtype=float32, numpy=
array([[-0.05238347, -0.06266207, -0.07440181, ..., 0.0417512 ,
        -0.05344176, 0.02906115],
       [-0.00369083, 0.01099763, -0.00766775, ..., 0.07014749,
        -0.03556629, 0.01387362],
       [-0.07415341, -0.01817401, 0.00419831, ..., 0.06153186,
        -0.01100198, 0.0544705 ],
       ...,
       [ 0.06579664, -0.06812809, -0.05979012, ..., -0.0540199 ,
         0.04981285, 0.066493 ],
       [-0.0601579 , 0.06772352, -0.0692725 , ..., -0.04544504,
        -0.08102902, 0.02741539],
       [-0.044352 , -0.07048865, 0.00934549, ..., 0.032233 ,
        -0.00784087, 0.05623148]], dtype=float32)>, <tf.Variable 'dense/bias:0' shape=(128,) dtype=float32, numpy=
array([-0.14855091, -0.07285158, -0.09825671, ..., 0.07643671,
       -0.1354494 , 0.08794942], dtype=float32)>, <tf.Variable 'dense_1/kernel:0' shape=(128, 10) dtype=float32, numpy=
array([[ 2.43996501e-01, -2.15584278e-01, 2.29047656e-01, ...,
        -6.07578419e-02, 9.98528376e-02, 1.91292807e-01],
       [ 2.44029671e-01, 4.49361429e-02, -9.02478278e-01, ...,
        -9.88776796e-03, 4.69152890e-02, 1.93796545e-01],
       [-1.17883611e+00, 3.15114379e-01, 3.47505659e-01, ...,
         3.69858891e-01, 3.26448739e-01, -8.94050360e-01],
       ...,
       [-5.94317436e-01, -2.05278710e-01, -7.61512935e-01, ...,
         3.76689643e-01, 1.37598768e-01, 1.51904374e-01],
       [-2.08765984e-01, 1.04028150e-01, 1.08290091e-01, ...,
        -6.63495797e-04, 1.37945980e-01, 1.77999035e-01],
       [ 4.85304482e-02, -2.46528938e-01, -5.67862451e-01, ...,
        -4.70214367e-01, -3.69332522e-01, 1.24029376e-01]], dtype=float32)>, <tf.Variable 'dense_1/bias:0' shape=(10,) dtype=float32, numpy=
array([-0.1546798 , -0.22612654, -0.10529487, -0.1593509 , 0.18012638,
        0.10057895, -0.01941594, -0.1304893 , 0.35250106, 0.02564381],
      dtype=float32)>]
```
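The four variables printed above are what a forward pass uses: each Dense layer computes x @ kernel + bias. This NumPy sketch rebuilds the Flatten, Dense(128), Dense(10) pipeline with randomly initialised arrays of the same shapes, not the trained values shown above. It also assumes ReLU and softmax activations, which the summary does not show.

```python
import numpy as np

rng = np.random.default_rng(0)

# Random stand-ins with the printed shapes (not the trained weights).
kernel_0 = rng.normal(0, 0.05, size=(784, 128)).astype(np.float32)   # 'dense/kernel'
bias_0 = np.zeros(128, dtype=np.float32)                              # 'dense/bias'
kernel_1 = rng.normal(0, 0.05, size=(128, 10)).astype(np.float32)    # 'dense_1/kernel'
bias_1 = np.zeros(10, dtype=np.float32)                               # 'dense_1/bias'


def forward(images):
    """Flatten, Dense(128) with ReLU (assumed), Dense(10) with softmax (assumed)."""
    x = images.reshape(len(images), -1)          # Flatten: (batch, 28, 28) to (batch, 784)
    h = np.maximum(x @ kernel_0 + bias_0, 0.0)   # hidden layer: (batch, 128)
    logits = h @ kernel_1 + bias_1               # output layer: (batch, 10)
    logits -= logits.max(axis=1, keepdims=True)  # shift for a numerically stable softmax
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True)


batch = rng.random((4, 28, 28), dtype=np.float32)   # 4 fake images with values in [0, 1)
probs = forward(batch)
print(probs.shape)   # (4, 10): one probability per digit class for each image
```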
Saving the parameters
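```python
with open('./weights.txt', 'w') as file:  # write each trainable variable's name, shape and values
    file.write(''.join(f"{v.name}\n{v.shape}\n{v.numpy()}\n" for v in model.trainable_variables))
```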
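Saving weights as text interacts with the print options above. Under the default threshold, str() of a large array contains ..., so the file would not hold every value. This sketch writes the name, shape, and values of each variable. It uses zero-filled NumPy arrays as a hypothetical stand-in for model.trainable_variables and raises the threshold only while writing.

```python
import os
import tempfile

import numpy as np

# Hypothetical stand-in for model.trainable_variables: (name, array) pairs with the same shapes.
variables = [
    ("dense/kernel:0", np.zeros((784, 128), dtype=np.float32)),
    ("dense/bias:0", np.zeros(128, dtype=np.float32)),
    ("dense_1/kernel:0", np.zeros((128, 10), dtype=np.float32)),
    ("dense_1/bias:0", np.zeros(10, dtype=np.float32)),
]

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "weights.txt")
    with np.printoptions(threshold=np.inf):   # write every value, not an abbreviated view
        with open(path, "w") as file:
            for name, value in variables:
                file.write(name + "\n")
                file.write(str(value.shape) + "\n")
                file.write(str(value) + "\n")
    with open(path) as file:
        text = file.read()

print(len(text))
```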
Plotting the accuracy and loss curves for the training and validation sets
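```python
acc = history.history['sparse_categorical_accuracy']          # training accuracy per epoch
val_acc = history.history['val_sparse_categorical_accuracy']  # validation accuracy per epoch
```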
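The following minimal plotting sketch assumes matplotlib is installed. To keep it self-contained, history_dict is a hypothetical stand-in for history.history, filled with the five epochs from the training log above. In the real script you would read the same keys from the History object returned by model.fit(). The figure is saved to a file (Agg backend) rather than shown interactively.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")   # render to a file; no display needed
import matplotlib.pyplot as plt

# Values copied from the 5-epoch training log, keyed like history.history.
history_dict = {
    "sparse_categorical_accuracy": [0.9962, 0.9967, 0.9974, 0.9970, 0.9981],
    "val_sparse_categorical_accuracy": [0.9797, 0.9798, 0.9807, 0.9782, 0.9800],
    "loss": [0.0118, 0.0101, 0.0086, 0.0095, 0.0060],
    "val_loss": [0.0801, 0.0899, 0.0879, 0.0979, 0.0997],
}

acc = history_dict["sparse_categorical_accuracy"]
val_acc = history_dict["val_sparse_categorical_accuracy"]
loss = history_dict["loss"]
val_loss = history_dict["val_loss"]

fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(10, 4))
ax_acc.plot(acc, label="Training Accuracy")
ax_acc.plot(val_acc, label="Validation Accuracy")
ax_acc.set_title("Training and Validation Accuracy")
ax_acc.legend()
ax_loss.plot(loss, label="Training Loss")
ax_loss.plot(val_loss, label="Validation Loss")
ax_loss.set_title("Training and Validation Loss")
ax_loss.legend()

out_path = os.path.join(tempfile.gettempdir(), "acc_loss.png")
fig.savefig(out_path)   # in an interactive session, plt.show() would display it instead
plt.close(fig)
```

With these numbers, training loss keeps falling (0.0118 to 0.0060) while validation loss rises (0.0801 to 0.0997). That pattern is a typical sign of overfitting, and the right-hand plot makes it easy to spot.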