模型 滤波器
简单来说:所谓模型就是一个滤波器,训练的权重就是滤波系数,输入经过滤波器后得到一个输出。所以嵌入式AI部署一般就是解析模型得到“滤波系数”,输入信号进行一系列类似"滤波"运算,得到最终输出。
所以需要搞明白模型怎么解析,这篇讲TFllite模型的格式以及它的解析。
1 TFLite格式简介
Tflite文件由Tensorflow提供的TOCO工具生成的轻量级模型,存储格式是flatbuffer,flatbuffer是google开源的一种二进制序列化格式,与protobuf类似。
下图(来自于参考2)描述了 模型训练->模型转化为Tflite格式->模型部署 的大致流程。从图中可以看到获取Tflite的三种方式:
# TensorFlow 2.x
tf.
lite.
TFLiteConverter.
from_saved_model():
# 由SavedModel转化
tf.
lite.
TFLiteConverter.
from_keras_model():
# 由Keras model转化
tf.
lite.
TFLiteConverter.
from_concrete_functions():
# 由具体函数转化
2 TFLite格式分析
例如我们已经训练得到了一个tflite模型(mnist_model.tflite),下面分析其格式:
方法1: Netron查看tflite模型
Netron 是一款常见的可视化工具,支持网页查看常见的AI模型,支持非常丰富的格式(ONNX, Tensorflow, Pytorch, Keras, Caffe等)
网页地址: https://netron.app/
将mnist_model.tflite导入,可以得到下图,可见mnist_model.tflite含有一个Reshape层,2个FullyConnected层,一个Relu层以及一个Softmax层
方法2:利用flatbuffer开源工具flatc
Tflite格式是flatbuffer格式,其优点是:解码速度极快、内存占用小,缺点是:数据没有可读性,需要借助其他工具实现可视化。
可使用google flatbuffer开源工具flatc,flatc可以实现tflite格式到jason文件的自动转换,解析时需要用到schema.fbs协议文件。
step1:安装flatc
# flatbuffer源码 https://github.com/google/flatbuffers
# 下载后进入文件夹,执行如下命令
mkdir build &&
cd build
cmake ../
# 生成Makefile
make # 编译
make install
# 安装flatcstep2:获取schema.fbs
schema.fbs是二进制协议文件,一般改动较小。直接从Tensorflow的源码中获取(如果后面的转换步骤出现问题,可以找到对应TensorFlow版本的schema.fbs文件试试)
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs
step3:转化为json
flatc
-t schema.fbs
-- mnist_model.tflite
这样获取得到mnist_model.json:
{
version:
3,
operator_codes: [
...
],
subgraphs: [
tensors: [],
inputs: [],
outputs: [],
operators: [],
],
description:
"MLIR Converted.",
buffers: [],
}
这个数据结构描述了tflite的整体框架及所有细节,这个放到另一篇文档里讲。
方法3: 利用tensorflow提供的接口分析
tf.lite.Interpreter可以读取tflite模型,但是python接口没有描述模型结构(op node节点间的连接关系)
import tensorflow as tf
import numpy as np
#加载模型
interpreter =
tf.
lite.
Interpreter(
model_path=
"./mnist_model.tflite")
interpreter.
allocate_tensors()
# 模型输入和输出细节
# input_details = interpreter.get_input_details()
# output_details = interpreter.get_output_details()
# 获取模型的tensor的详细信息
tensor =
interpreter.
get_tensor_details()
print(
tensor)
得到的结果如下:
[
{
'name':
'serving_default_flatten_2_input:0',
'index':
0,
'shape':
array([
1,
28,
28],
dtype=int32),
'shape_signature':
array([
-1,
28,
28],
dtype=int32),
'dtype':
<class 'numpy.float32'>,
'quantization': (
0.0,
0),
'quantization_parameters': {
'scales':
array([],
dtype=float32),
'zero_points':
array([],
dtype=int32),
'quantized_dimension':
0
},
'sparsity_parameters': {}
},
{
'name':
'sequential_2/dense_5/BiasAdd/ReadVariableOp',
'index':
1,
'shape':
array([
10],
dtype=int32),
'shape_signature':
array([
10],
dtype=int32),
'dtype':
<class 'numpy.float32'>,
'quantization': (
0.0,
0),
'quantization_parameters': {
'scales':
array([],
dtype=float32),
'zero_points':
array([],
dtype=int32),
'quantized_dimension':
0
},
'sparsity_parameters': {}
},
{
'name':
'sequential_2/dense_4/BiasAdd/ReadVariableOp',
'index':
2,
'shape':
array([
128],
dtype=int32),
'shape_signature':
array([
128],
dtype=int32),
'dtype':
<class 'numpy.float32'>,
'quantization': (
0.0,
0),
'quantization_parameters': {
'scales':
array([],
dtype=float32),
'zero_points':
array([],
dtype=int32),
'quantized_dimension':
0
},
'sparsity_parameters': {}
},
...
]
方法4:文本解析tflite文件
Flatbuffer格式的tflite文件,转成可读的python dict格式,并可描述模型完整推理流程。
直接下载Tensorflow提供的visualize.py工具,
下载地址:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/visualize.py
更多回帖