阿犇

记录生活中的点点滴滴

0%

八股搭建神经网络

使用八股搭建神经网络

用Tensorflow API: tf. keras

六步法搭建神经网络

第一步:import相关模块,如import tensorflow as tf。

1
2
3
4
import tensorflow as tf
import os
import numpy as np
import matplotlib.pyplot as plt

第二步:指定输入网络的训练集和测试集,如指定训练集的输入x_train和标签y_train,测试集的输
入x_test和标签y_test。

1
2
3
mnist = tf.keras.datasets.mnist
(x_train,y_train),(x_test,y_test)=mnist.load_data()
x_train,x_test=x_train/255.0,x_test/255.0

第三步:逐层搭建网络结构,model = tf.keras.models.Sequential()。

1
2
3
4
5
model=tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128,activation='relu'),
tf.keras.layers.Dense(10,activation='softmax')
])

第四步:在model.compile()中配置训练方法,选择训练时使用的优化器、损失函数和最终评价指标。

1
2
3
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy'])

断点续训,读取模型

定义存放模型的路径和文件名,命名为ckpt文件

生成ckpt文件时会同步生成index索引表,所以判断索引表是否存在,来判断是否存在模型参数

如有索引表,则直接读取ckpt文件中的模型参数

1
2
3
4
5
checkpoint_save_path="./checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path+'.index'):
print('--------------load the model---------------')
model.load_weights(checkpoint_save_path)

--------------load the model---------------

保存模型

第五步:在model.fit()中执行训练过程,告知训练集和测试集的输入值和标签、每个batch的大小(batchsize)和数据集的迭代次数(epoch)。

第六步:使用model.summary()打印网络结构,统计参数数目。

1
2
3
4
5
6
7
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
save_weights_only=True,
save_best_only=True)

history = model.fit(x_train,y_train,batch_size=32,epochs=5,validation_data=(x_test,y_test),validation_freq=1,callbacks=[cp_callback])
model.summary()

Epoch 1/5
1875/1875 [==============================] - 2s 886us/step - loss: 0.0118 - sparse_categorical_accuracy: 0.9962 - val_loss: 0.0801 - val_sparse_categorical_accuracy: 0.9797
Epoch 2/5
1875/1875 [==============================] - 1s 667us/step - loss: 0.0101 - sparse_categorical_accuracy: 0.9967 - val_loss: 0.0899 - val_sparse_categorical_accuracy: 0.9798
Epoch 3/5
1875/1875 [==============================] - 1s 695us/step - loss: 0.0086 - sparse_categorical_accuracy: 0.9974 - val_loss: 0.0879 - val_sparse_categorical_accuracy: 0.9807
Epoch 4/5
1875/1875 [==============================] - 1s 694us/step - loss: 0.0095 - sparse_categorical_accuracy: 0.9970 - val_loss: 0.0979 - val_sparse_categorical_accuracy: 0.9782
Epoch 5/5
1875/1875 [==============================] - 1s 687us/step - loss: 0.0060 - sparse_categorical_accuracy: 0.9981 - val_loss: 0.0997 - val_sparse_categorical_accuracy: 0.9800
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 128)               100480    
_________________________________________________________________
dense_1 (Dense)              (None, 10)                1290      
=================================================================
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________

设置输出格式

np.set printoptions(threshold=超过多少省略显示)

np.set_printoptions(threshold=np.inf) # np. inf表示无限大

1
2
# np.set_printoptions(threshold=np.inf)
np.set_printoptions(threshold=10)
1
print(model.trainable_variables)
[<tf.Variable 'dense/kernel:0' shape=(784, 128) dtype=float32, numpy=
array([[-0.05238347, -0.06266207, -0.07440181, ...,  0.0417512 ,
        -0.05344176,  0.02906115],
       [-0.00369083,  0.01099763, -0.00766775, ...,  0.07014749,
        -0.03556629,  0.01387362],
       [-0.07415341, -0.01817401,  0.00419831, ...,  0.06153186,
        -0.01100198,  0.0544705 ],
       ...,
       [ 0.06579664, -0.06812809, -0.05979012, ..., -0.0540199 ,
         0.04981285,  0.066493  ],
       [-0.0601579 ,  0.06772352, -0.0692725 , ..., -0.04544504,
        -0.08102902,  0.02741539],
       [-0.044352  , -0.07048865,  0.00934549, ...,  0.032233  ,
        -0.00784087,  0.05623148]], dtype=float32)>, <tf.Variable 'dense/bias:0' shape=(128,) dtype=float32, numpy=
array([-0.14855091, -0.07285158, -0.09825671, ...,  0.07643671,
       -0.1354494 ,  0.08794942], dtype=float32)>, <tf.Variable 'dense_1/kernel:0' shape=(128, 10) dtype=float32, numpy=
array([[ 2.43996501e-01, -2.15584278e-01,  2.29047656e-01, ...,
        -6.07578419e-02,  9.98528376e-02,  1.91292807e-01],
       [ 2.44029671e-01,  4.49361429e-02, -9.02478278e-01, ...,
        -9.88776796e-03,  4.69152890e-02,  1.93796545e-01],
       [-1.17883611e+00,  3.15114379e-01,  3.47505659e-01, ...,
         3.69858891e-01,  3.26448739e-01, -8.94050360e-01],
       ...,
       [-5.94317436e-01, -2.05278710e-01, -7.61512935e-01, ...,
         3.76689643e-01,  1.37598768e-01,  1.51904374e-01],
       [-2.08765984e-01,  1.04028150e-01,  1.08290091e-01, ...,
        -6.63495797e-04,  1.37945980e-01,  1.77999035e-01],
       [ 4.85304482e-02, -2.46528938e-01, -5.67862451e-01, ...,
        -4.70214367e-01, -3.69332522e-01,  1.24029376e-01]], dtype=float32)>, <tf.Variable 'dense_1/bias:0' shape=(10,) dtype=float32, numpy=
array([-0.1546798 , -0.22612654, -0.10529487, -0.1593509 ,  0.18012638,
        0.10057895, -0.01941594, -0.1304893 ,  0.35250106,  0.02564381],
      dtype=float32)>]

参数保存

1
2
3
4
5
6
file=open('./weights.txt','w')
for v in model.trainable_variables:
file.write(str(v.name)+'\n')
file.write(str(v.shape)+'\n')
file.write(str(v.numpy())+'\n')
file.close()

显示训练集和验证集的acc和loss曲线

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
acc=history.history['sparse_categorical_accuracy']
val_acc=history.history['val_sparse_categorical_accuracy']
loss=history.history['loss']
val_loss=history.history['val_loss']

plt.subplot(1,2,1)
plt.plot(acc,label='Training Accuracy')
plt.plot(val_acc,label='Validation Accuracy')
plt.title("Training and Validation Accuracy")
plt.legend()

plt.subplot(1,2,2)
plt.plot(loss,label='Training loss')
plt.plot(val_loss,label='Validation Loss')
plt.title('Training and Validation Loss')

plt.legend()
plt.show()

您的支持是我继续创作的最大动力!