芯来科技
直播中

刘芳

7年用户 1266经验值
私信 关注
[经验]

TFllite模型的格式简介

简单来说:所谓模型就是一个滤波器,训练的权重就是滤波系数,输入经过滤波器后得到一个输出。所以嵌入式AI部署一般就是解析模型得到“滤波系数”,输入信号进行一系列类似"滤波"运算,得到最终输出。
所以需要搞明白模型怎么解析,这篇讲TFllite模型的格式以及它的解析。
1 TFLite格式简介

Tflite文件由Tensorflow提供的TOCO工具生成的轻量级模型,存储格式是flatbuffer,flatbuffer是google开源的一种二进制序列化格式,与protobuf类似。
下图(来自于参考2)描述了 模型训练->模型转化为Tflite格式->模型部署 的大致流程。从图中可以看到获取Tflite的三种方式:
#  TensorFlow 2.x
tf.lite.TFLiteConverter.from_saved_model():          # 由SavedModel转化
tf.lite.TFLiteConverter.from_keras_model():           # 由Keras model转化
tf.lite.TFLiteConverter.from_concrete_functions(): # 由具体函数转化

2 TFLite格式分析

例如我们已经训练得到了一个tflite模型(mnist_model.tflite),下面分析其格式:
方法1: Netron查看tflite模型
Netron 是一款常见的可视化工具,支持网页查看常见的AI模型,支持非常丰富的格式(ONNX, Tensorflow, Pytorch, Keras, Caffe等)
网页地址: https://netron.app/
将mnist_model.tflite导入,可以得到下图,可见mnist_model.tflite含有一个Reshape层,2个FullyConnected层,一个Relu层以及一个Softmax层



方法2:利用flatbuffer开源工具flatc
Tflite格式是flatbuffer格式,其优点是:解码速度极快、内存占用小,缺点是:数据没有可读性,需要借助其他工具实现可视化。
可使用google flatbuffer开源工具flatc,flatc可以实现tflite格式到jason文件的自动转换,解析时需要用到schema.fbs协议文件。
step1:安装flatc
# flatbuffer源码 https://github.com/google/flatbuffers
# 下载后进入文件夹,执行如下命令
mkdir build && cd build
cmake ../              # 生成Makefile
make                   # 编译
make install           # 安装flatcstep2:获取schema.fbs
schema.fbs是二进制协议文件,一般改动较小。直接从Tensorflow的源码中获取(如果后面的转换步骤出现问题,可以找到对应TensorFlow版本的schema.fbs文件试试)
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs
step3:转化为json
flatc -t schema.fbs -- mnist_model.tflite这样获取得到mnist_model.json:
{
    version: 3,
    operator_codes: [
    ...
    ],
    subgraphs: [
        tensors: [],
        inputs: [],
        outputs: [],
        operators: [],
    ],
    description: "MLIR Converted.",   
    buffers: [],
}这个数据结构描述了tflite的整体框架及所有细节,这个放到另一篇文档里讲。
方法3: 利用tensorflow提供的接口分析
tf.lite.Interpreter可以读取tflite模型,但是python接口没有描述模型结构(op node节点间的连接关系)
import tensorflow as tf
import numpy as np
#加载模型
interpreter = tf.lite.Interpreter(model_path="./mnist_model.tflite")
interpreter.allocate_tensors()

# 模型输入和输出细节
# input_details = interpreter.get_input_details()
# output_details = interpreter.get_output_details()

# 获取模型的tensor的详细信息
tensor = interpreter.get_tensor_details()

print(tensor)得到的结果如下:
[
{'name': 'serving_default_flatten_2_input:0',
  'index': 0,
  'shape': array([ 1, 28, 28], dtype=int32),
  'shape_signature': array([-1, 28, 28], dtype=int32),
  'dtype': <class 'numpy.float32'>,
  'quantization': (0.0, 0),
  'quantization_parameters': {
      'scales': array([], dtype=float32),
      'zero_points': array([], dtype=int32),
      'quantized_dimension': 0
    },
  'sparsity_parameters': {}
  },

  {'name': 'sequential_2/dense_5/BiasAdd/ReadVariableOp',
   'index': 1,
   'shape': array([10], dtype=int32),
   'shape_signature': array([10], dtype=int32),
   'dtype': <class 'numpy.float32'>,
   'quantization': (0.0, 0),
   'quantization_parameters': {
        'scales': array([], dtype=float32),
        'zero_points': array([], dtype=int32),
        'quantized_dimension': 0
      },
    'sparsity_parameters': {}
  },

  {'name': 'sequential_2/dense_4/BiasAdd/ReadVariableOp',
   'index': 2,
   'shape': array([128], dtype=int32),
   'shape_signature': array([128], dtype=int32),
   'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0),
   'quantization_parameters': {
      'scales': array([], dtype=float32),
      'zero_points': array([], dtype=int32),
      'quantized_dimension': 0
    },
    'sparsity_parameters': {}
  },
  ...
]方法4:文本解析tflite文件
Flatbuffer格式的tflite文件,转成可读的python dict格式,并可描述模型完整推理流程。
直接下载Tensorflow提供的visualize.py工具,
下载地址:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/visualize.py

更多回帖

发帖
×
20
完善资料,
赚取积分