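1 What is Keras?
Keras is a high-level neural network API written in Python. It can run on top of TensorFlow, CNTK, or Theano as its backend.
The main advantage of Keras is that it is an API designed for humans, not machines. It puts user experience front and center, so you can turn an idea into experimental results in as little time as possible.
This series of articles covers the basics of Keras. To set up a Keras environment, follow the installation guide at https://keras.io/zh/.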
2 Steps to build a neural network with Keras
Example
Build the simplest possible handwritten digit recognizer, using the MNIST dataset.
step1: choose the Sequential model and initialize it
model = keras.models.Sequential()
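Conceptually, a Sequential model is an ordered stack of layers in which each layer's output becomes the next layer's input. The sketch below illustrates only that idea in plain Python. `TinySequential` is a hypothetical class, not part of Keras, and its "layers" are ordinary functions rather than real Keras layers with weights.

```python
class TinySequential:
    """Hypothetical stand-in for keras.models.Sequential: applies layers in the order they were added."""

    def __init__(self):
        self.layers = []

    def add(self, layer):
        # Append a layer to the end of the stack, like Sequential.add
        self.layers.append(layer)

    def __call__(self, x):
        # Forward pass: feed each layer's output into the next layer
        for layer in self.layers:
            x = layer(x)
        return x


model = TinySequential()
model.add(lambda x: x + 1)   # first "layer"
model.add(lambda x: x * 2)   # second "layer"
print(model(3))              # (3 + 1) * 2 = 8
```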
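step2: build the network layers
Add a Flatten layer that turns each 28x28 image into a 784-element vector, a fully connected (Dense) hidden layer with ReLU activation, and a fully connected output layer with softmax activation that outputs one probability per digit.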
model.add(keras.layers.Flatten(input_shape=(x_train.shape[1], x_train.shape[2])))
# hidden layer; its 784 inputs are inferred from the Flatten layer above
model.add(keras.layers.Dense(units=784, activation="relu"))
model.add(keras.layers.Dense(units=10, activation="softmax"))
step3: compile the model
model.compile(optimizer="Adam",
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
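The loss `sparse_categorical_crossentropy` fits here because `y_train` stores each label as an integer from 0 to 9 rather than as a one-hot vector. For one sample, the loss is the negative log of the probability the model assigns to the true class, which equals categorical cross-entropy computed on the one-hot form of the same label. Below is a minimal sketch with hypothetical probabilities. Keras also averages over the batch and protects against log(0), and this sketch omits both.

```python
import math


def sparse_categorical_crossentropy(label, probs):
    """Loss for one sample: integer class label plus the softmax probability vector."""
    return -math.log(probs[label])


def categorical_crossentropy(one_hot, probs):
    """The same loss, with the label given as a one-hot vector."""
    return -sum(t * math.log(p) for t, p in zip(one_hot, probs) if t > 0)


probs = [0.05, 0.05, 0.8, 0.1]   # hypothetical softmax output over 4 classes
label = 2                         # integer label, the format stored in y_train
one_hot = [0, 0, 1, 0]            # the same label as a one-hot vector

sparse_loss = sparse_categorical_crossentropy(label, probs)
categorical_loss = categorical_crossentropy(one_hot, probs)
print(sparse_loss, categorical_loss)  # both equal -log(0.8)
```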
step4: train the model
model.fit(x_train, y_train, batch_size=64, epochs=5)
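With `batch_size=64`, each training step processes 64 images, and one epoch is one full pass over all 60,000 training images. The plain-Python sketch below shows how the data divides into steps. This is where the `938/938` in the training log comes from, and `evaluate` uses a default batch size of 32, which gives 313 steps on the 10,000 validation images. The sketch only tracks index ranges and ignores shuffling, which `fit` applies to the training data each epoch by default.

```python
import math


def iterate_minibatches(num_samples, batch_size):
    """Yield (start, end) index ranges, one per training step.

    The last batch is smaller when num_samples is not a multiple of batch_size.
    """
    for start in range(0, num_samples, batch_size):
        yield start, min(start + batch_size, num_samples)


batches = list(iterate_minibatches(60000, 64))
print(len(batches))    # 938 steps per epoch
print(batches[-1])     # (59968, 60000): the last batch holds only 32 images

eval_steps = math.ceil(10000 / 32)  # evaluate() with its default batch_size of 32
print(eval_steps)      # 313
```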
step5: evaluate the model
model.evaluate(x_valid, y_valid)
The complete code is as follows:
import tensorflow as tf
import tensorflow.keras as keras

print(keras.__version__)

# Load MNIST: 60,000 training images and 10,000 test images (used here as the validation set)
(x_train, y_train), (x_valid, y_valid) = keras.datasets.mnist.load_data()
assert x_train.shape == (60000, 28, 28)
assert x_valid.shape == (10000, 28, 28)
assert y_train.shape == (60000,)
assert y_valid.shape == (10000,)

# step1: use a Sequential model
model = keras.models.Sequential()

# step2: add layers
model.add(keras.layers.Flatten(input_shape=(x_train.shape[1], x_train.shape[2])))
# hidden layer; its 784 inputs are inferred from the Flatten layer above
model.add(keras.layers.Dense(units=784, activation="relu"))
model.add(keras.layers.Dense(units=10, activation="softmax"))

# step3: compile the model
model.compile(optimizer="Adam",
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
print("model:")
model.summary()

# step4: train
model.fit(x_train, y_train, batch_size=64, epochs=5)

# step5: evaluate the model
model.evaluate(x_valid, y_valid)

# save the model in HDF5 format
model.save('keras_mnist.h5')
The execution log is as follows:
model:
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
flatten (Flatten)            (None, 784)               0
dense (Dense)                (None, 784)               615440
dense_1 (Dense)              (None, 10)                7850
=================================================================
Total params: 623,290
Trainable params: 623,290
Non-trainable params: 0
_________________________________________________________________
Epoch 1/5
938/938 [==============================] - 3s 3ms/step - loss: 3.2538 - accuracy: 0.9165
Epoch 2/5
938/938 [==============================] - 3s 3ms/step - loss: 0.3467 - accuracy: 0.9538
Epoch 3/5
938/938 [==============================] - 3s 3ms/step - loss: 0.2521 - accuracy: 0.9595
Epoch 4/5
938/938 [==============================] - 3s 3ms/step - loss: 0.2533 - accuracy: 0.9591
Epoch 5/5
938/938 [==============================] - 3s 3ms/step - loss: 0.2359 - accuracy: 0.9599
313/313 [==============================] - 1s 2ms/step - loss: 0.3514 - accuracy: 0.9494
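The `Param #` column in the summary can be checked by hand. A Dense layer has one weight for every input-unit pair plus one bias per unit, and Flatten only reshapes its input, so it has no parameters. A short sketch of that arithmetic (`dense_params` is an illustrative helper, not a Keras function):

```python
def dense_params(inputs, units):
    """Parameters of a Dense layer: inputs * units weights plus units biases."""
    return inputs * units + units


flatten_params = 0                   # Flatten has no weights
hidden = dense_params(28 * 28, 784)  # layer "dense"
output = dense_params(784, 10)       # layer "dense_1"
total = flatten_params + hidden + output
print(hidden, output, total)         # 615440 7850 623290
```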
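This gives us the simplest MNIST model, with an accuracy of 0.9494 on the validation set, slightly below the final training accuracy of 0.9599. The log also shows that training accuracy levels off after epoch 3 (0.9595, 0.9591, 0.9599) and even dips slightly at epoch 4, so extra epochs bring little benefit. In such cases training can be stopped early, which a later chapter covers.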
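Keras supports this through the `keras.callbacks.EarlyStopping` callback, which takes parameters such as `monitor`, `patience`, and `min_delta`. The plain-Python sketch below only illustrates the patience idea behind it. It applies the rule to the training accuracies from the log. The function name and the thresholds are illustrative choices, not Keras defaults, and in practice you would usually monitor a validation metric instead of training accuracy.

```python
def early_stopping_epoch(history, patience=1, min_delta=0.001):
    """Return the 1-based epoch after which training would stop, or None if it never stops.

    The metric is assumed to be "higher is better" (e.g. accuracy). Training stops once
    the metric has failed to beat the best value by at least min_delta for
    `patience` consecutive epochs.
    """
    best = float("-inf")
    wait = 0
    for epoch, value in enumerate(history, start=1):
        if value > best + min_delta:
            best = value   # improvement: remember it and reset the counter
            wait = 0
        else:
            wait += 1      # no meaningful improvement this epoch
            if wait >= patience:
                return epoch
    return None


# Training accuracies copied from the log
acc = [0.9165, 0.9538, 0.9595, 0.9591, 0.9599]
print(early_stopping_epoch(acc))              # 4
print(early_stopping_epoch(acc, patience=2))  # 5
```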
3 Checking the model
We can load the already trained model and use it directly, without retraining.
Method 1: model.evaluate()
import tensorflow as tf
import tensorflow.keras as keras

(x_train, y_train), (x_valid, y_valid) = keras.datasets.mnist.load_data()
assert x_train.shape == (60000, 28, 28)
assert x_valid.shape == (10000, 28, 28)
assert y_train.shape == (60000,)
assert y_valid.shape == (10000,)

# load the model saved by the training script
model = keras.models.load_model('keras_mnist.h5')

# evaluate model
model.evaluate(x_valid, y_valid)
Output:
313/313 [==============================] - 1s 2ms/step - loss: 0.3514 - accuracy: 0.9494
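The `accuracy` reported by `evaluate` is the fraction of images whose most probable class (the argmax of the softmax output) matches the true label. Below is a pure-Python sketch using made-up probabilities for four samples.

```python
def accuracy(probabilities, labels):
    """Fraction of samples whose highest-probability class equals the true label."""
    correct = 0
    for probs, label in zip(probabilities, labels):
        predicted = max(range(len(probs)), key=lambda k: probs[k])  # argmax
        correct += int(predicted == label)
    return correct / len(labels)


# Hypothetical softmax outputs for 4 samples over 3 classes
probabilities = [
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.3, 0.3, 0.4],
    [0.6, 0.3, 0.1],
]
labels = [0, 1, 2, 1]  # the last sample is misclassified as class 0
print(accuracy(probabilities, labels))  # 0.75
```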
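Method 2: model.predict()
Checking the model this way is not very intuitive, so we can also use the predict method.
The figure below shows a 28x28 image of a handwritten digit 9, which we feed into the model for prediction:
The code is as follows: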
# predict_num.py
import sys

import numpy as np
import tensorflow.keras as keras
from PIL import Image

if __name__ == '__main__':
    img_name = sys.argv[1]
    # convert to single-channel grayscale so the image matches the 28x28 MNIST format
    img = Image.open(img_name).convert('L')
    img = np.array(img)
    print(img.shape)
    # add a batch dimension: (28, 28) -> (1, 28, 28)
    img = np.reshape(img, (-1, 28, 28))
    # load model
    model = keras.models.load_model('keras_mnist.h5')
    # predict returns one row of 10 probabilities per image; argmax picks the most likely digit
    predict_num = np.argmax(model.predict(img), axis=1)
    print("predict num is %d" % predict_num[0])
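`model.predict` returns a batch of predictions with one row of 10 probabilities per input image. That is why `np.argmax(..., axis=1)` yields one digit per image, and why the single prediction for one image is element 0. Below is a NumPy-free sketch of the same step. The probabilities are hypothetical, and `argmax_rows` is an illustrative helper.

```python
def argmax_rows(batch):
    """Index of the largest value in each row, like np.argmax(batch, axis=1)."""
    return [max(range(len(row)), key=row.__getitem__) for row in batch]


# Hypothetical output of model.predict for a batch containing one image
prediction = [[0.01, 0.0, 0.02, 0.03, 0.05, 0.01, 0.0, 0.08, 0.1, 0.7]]
predict_num = argmax_rows(prediction)          # one entry per image in the batch
print("predict num is %d" % predict_num[0])    # predict num is 9
```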
How to run:
python predict_num.py xx.bmp
The result is as follows:
1/1 [==============================] - 0s 90ms/step