芯来科技
直播中

曹利娟

7年用户 957经验值
私信 关注
[经验]

Keras搭建神经网络的一般步骤

1 keras是什么?

Keras 是一个用 Python 编写的高级神经网络 API,它能够以 TensorFlow, CNTK, 或者 Theano 作为后端运行。
Keras 的主要优点是: Keras 是为人类而不是为机器设计的 API,它把用户体验放在首要和中心位置,能够以最短时间把你的想法转换为实验结果。
这一列文档主要讲述keras的一些入门知识,keras环境搭建可以参考https://keras.io/zh/ 的安装指引。
2 使用Keras搭建神经网络的步骤



2 示例

搭建一个最简单的手写数字识别MNIST
step1:选择顺序模型并初始化
model = keras.models.Sequential()step2:构建网络层
添加一个全连接层与一个激活层
model.add(keras.layers.Flatten(input_shape=(x_train.shape[1], x_train.shape[2])))
model.add(keras.layers.Dense(units=784, activation="relu", input_dim=784))
model.add(keras.layers.Dense(units=10, activation="softmax"))step3:编译
model.compile(optimizer="Adam", loss='sparse_categorical_crossentropy', metrics=['accuracy'])step4: 训练模型
model.fit(x_train, y_train, batch_size=64, epochs=5)step5: 预测
model.evaluate(x_valid, y_valid)
完整代码如下:
import tensorflow as tf
import tensorflow.keras as keras

print(keras.__version__)

(x_train, y_train), (x_valid, y_valid) = keras.datasets.mnist.load_data()
assert x_train.shape == (60000, 28, 28)
assert x_valid.shape == (10000, 28, 28)
assert y_train.shape == (60000,)
assert y_valid.shape == (10000,)

# step1: use sequential
model = keras.models.Sequential()

# step2: add layer
model.add(keras.layers.Flatten(input_shape=(x_train.shape[1], x_train.shape[2])))
model.add(keras.layers.Dense(units=784, activation="relu", input_dim=784))
model.add(keras.layers.Dense(units=10, activation="softmax"))

# step3: compile model
model.compile(optimizer="Adam", loss='sparse_categorical_crossentropy', metrics=['accuracy'])

print("model:")
model.summary()

# step4: train
model.fit(x_train, y_train, batch_size=64, epochs=5)

# step5: evaluate model
model.evaluate(x_valid, y_valid)

# save model
model.save('keras_mnist.h5')
执行log如下:
model:
Model: "sequential"
_________________________________________________________________
Layer (type)                Output Shape              Param #
=================================================================
flatten (Flatten)           (None, 784)               0

dense (Dense)               (None, 784)               615440

dense_1 (Dense)             (None, 10)                7850

=================================================================
Total params: 623,290
Trainable params: 623,290
Non-trainable params: 0
_________________________________________________________________
Epoch 1/5
938/938 [==============================] - 3s 3ms/step - loss: 3.2538 - accuracy: 0.9165
Epoch 2/5
938/938 [==============================] - 3s 3ms/step - loss: 0.3467 - accuracy: 0.9538
Epoch 3/5
938/938 [==============================] - 3s 3ms/step - loss: 0.2521 - accuracy: 0.9595
Epoch 4/5
938/938 [==============================] - 3s 3ms/step - loss: 0.2533 - accuracy: 0.9591
Epoch 5/5
938/938 [==============================] - 3s 3ms/step - loss: 0.2359 - accuracy: 0.9599
313/313 [==============================] - 1s 2ms/step - loss: 0.3514 - accuracy: 0.9494这样我们就得到一个最简单的mnist模型,准确率为0.9494,从日志中还可以看出,最终的准确率并不是训练过程中最高的,我们可以提前中止训练,后面章节会讲到。

3 模型的检验

我们可以直接使用已经训练好的模型进行预测。
方法1:model.evaluate()方法
import tensorflow as tf
import tensorflow.keras as keras

(x_train, y_train), (x_valid, y_valid) = keras.datasets.mnist.load_data()
assert x_train.shape == (60000, 28, 28)
assert x_valid.shape == (10000, 28, 28)
assert y_train.shape == (60000,)
assert y_valid.shape == (10000,)

# load model
model = keras.models.load_model('keras_mnist.h5')

# evaluate model
model.evaluate(x_valid, y_valid)执行后:
313/313 [==============================] - 1s 2ms/step - loss: 0.3514 - accuracy: 0.9494
方法2:model.predict()方法
这样验证不是很直观,我们也可采用predict方法
如下图所示,是一张28x28的手写数字9,我们导入到模型中预测:

代码如下:
# predict_num.py
import tensorflow as tf
import tensorflow.keras as keras

from PIL import Image
import os,sys
import numpy as np

if __name__ == '__main__':
    img_name  = sys.argv[1]
    img = Image.open(img_name)
    img = np.array(img)
    print(img.shape)
    img = np.reshape(img, (-1, 28, 28))
    # load model
    model = keras.models.load_model('keras_mnist.h5')
    predict_num = np.argmax(model.predict(img), axis = 1)
    print("predict num is %d" % predict_num)执行方法:
python predict_num.py xx.bmp结果如下:
1/1 [==============================] - 0s 90ms/step

更多回帖

发帖
×
20
完善资料,
赚取积分