tensorFlowLite 是一款TensorFlow用于移動設備和嵌入式設備的輕量級解決方案。
我們知道TensorFlow可以在多個平臺上運行,從機架式服務器到小型IoT設備。但是隨著近年來機器學習模型的廣泛使用,出現(xiàn)了在移動和嵌入式設備上部署它們的需求,而TensorFlowLite 允許設備端的機器學習模型的低延遲推斷。
飛凌OK1052-C開發(fā)板以高性能處理器i.MXRT1052作為硬件軀體,再用tensorflow_lite武裝軟件大腦,也搭上了開往機器學習,邊緣計算等朝陽行業(yè)領域的快車。
▲OK1052-C開發(fā)板接口圖
OK1052-C的用戶資料SDK中,已經(jīng)有了關(guān)于tensorflow_lite使用的demo例程,但是這些例程使用的都是現(xiàn)成訓練之后的實驗模型,并不適用于我們實際的應用場景。我們在實際應用項目中,必然是需要使用適合本項目的模型,網(wǎng)上現(xiàn)成模型資源下載或者自己訓練模型,然后搭載自己的應用程序,完成我們的應用項目需求。
由于模型的訓練需要算力的支持,通常我們在計算機上進行模型的訓練,我們將訓練好的模型稱為預訓練模型。我們也可以在網(wǎng)上download預訓練模型,然后通過格式轉(zhuǎn)換轉(zhuǎn)換為tflite格式的模型,也就是開發(fā)板可執(zhí)行的模型格式;也可以自己搭建計算網(wǎng)絡,通過訓練之后,形成預訓練模型,再轉(zhuǎn)換為tflite格式,運行到開發(fā)板。
今天我們通過一個簡單的例子,來介紹怎么建立計算網(wǎng)絡和進行模型的轉(zhuǎn)換。
01、常見的模型格式
我們常見的訓練模型格式有.PB,cpkt,saveModel,H5等,若想將模型運用于tensor flowlite,需要轉(zhuǎn)換為tflite格式,一般的轉(zhuǎn)換過程是:
由上圖可知,如果要將Checkpoints(cpkt)格式轉(zhuǎn)換為tflite,需經(jīng)過freeze_graph.py工具將cpkt格式模型轉(zhuǎn)換為Frozen GraphDef(.pb)格式,然后再經(jīng)過TFLite Converter轉(zhuǎn)換工具轉(zhuǎn)換為tflite。
02、建立模型并進行轉(zhuǎn)換
現(xiàn)在通過例子介紹如何將模型轉(zhuǎn)換為tflite格式,首先我們先建立兩個網(wǎng)絡流圖,并分別生成.pb和cpkt格式的模型。
1)建立計算網(wǎng)絡,并保存為.PB格式模型
# coding=UTF-8
import tensorflow as tf
import shutil
import os.path
from tensorflow.python.framework import graph_util
output_graph = "easy_model/add_model.pb"
#下面的過程你可以替換成CNN、RNN等你想做的訓練過程,這里只是簡單的一個計算公式
input_holder = tf.placeholder(tf.float32, shape=[1], name="input_holder")
W1 = tf.Variable(tf.constant(5.0, shape=[1]), name="W1")
B1 = tf.Variable(tf.constant(1.0, shape=[1]), name="B1")
_y = (input_holder * W1) + B1
# predictions = tf.greater(_y, 50, name="predictions") #比50大返回true,否則返回false
predictions = tf.add(_y, 10,name="predictions") #做一個加法運算
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
print ("predictions :", sess.run(predictions, feed_dict={input_holder: [10.0]}))
graph_def = tf.get_default_graph().as_graph_def() #得到當前的圖的 GraphDef 部分,通過這個部分就可以完成重輸入層到輸出層的計算過程
output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,將變量值固定
sess,
graph_def,
["predictions"] #需要保存節(jié)點的名字
)
with tf.gfile.GFile(output_graph, "wb") as f: # 保存模型
f.write(output_graph_def.SerializeToString()) # 序列化輸出
print("%d ops in the final graph." % len(output_graph_def.node))
print (predictions)
此計算網(wǎng)絡只是一個簡單的數(shù)學運算,不需要進行訓練,該運算公式為:
predictions = (input_holder * W1) + B1 + 10
其中input_holder是網(wǎng)絡的輸入節(jié)點,predictions是網(wǎng)絡的輸出節(jié)點,W1,B1是兩個變量,分別被賦值為5.0,,1.0
程序運行之后,會在easy_model/文件夾下生成add_model.pb模型文件。我們通過Netron軟件(Netron是一個很方便的軟件)來看一下add_model.pb模型文件中存儲的網(wǎng)絡圖及參數(shù)信息:
2)建立網(wǎng)絡圖,并保存為.ckpt格式模型:
# coding=UTF-8 支持中文編碼格式
import tensorflow as tf
import shutil
import os.path
MODEL_DIR = "easy_model"
MODEL_NAME = "model.ckpt"
if not tf.gfile.Exists(MODEL_DIR):
tf.gfile.MakeDirs(MODEL_DIR)
下面的過程你可以替換成CNN、RNN等你想做的訓練過程,這里只是簡單的一個計算公式
input_holder = tf.placeholder(tf.float32, shape=[1], name="input_holder")
W1 = tf.Variable(tf.constant(5.0, shape=[1]), name="W1")
B1 = tf.Variable(tf.constant(1.0, shape=[1]), name="B1")
_y = (input_holder * W1) + B1
#predictions = tf.greater(_y, 50, name="predictions")
predictions = tf.add(_y, 10,name="predictions")
init = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
print ("predictions : ", sess.run(predictions, feed_dict={input_holder: [10.0]}))
saver.save(sess, os.path.join(MODEL_DIR, MODEL_NAME))
print("%d ops in the final graph." % len(tf.get_default_graph().as_graph_def().node))
for op in tf.get_default_graph().get_operations():
print (op.name, op.values())
此網(wǎng)絡圖同上一節(jié)一樣是一個簡單的數(shù)學運算。
不同的是,此程序最后保存為ckpt格式的模型文件:
checkpoint :記錄目錄下所有模型文件列表
ckpt.data :保存模型中每個變量的取值
ckpt.meta :保存整個計算網(wǎng)絡圖的結(jié)構(gòu)
同樣可通過Netron軟件查看ckpt文件中存儲的網(wǎng)絡圖結(jié)構(gòu):
3)將生成的cpkt格式模型文件轉(zhuǎn)換為.pb文件:
import tensorflow as tf
from tensorflow.python.framework import graph_util
#from create_tf_record import *
resize_height = 100 # 指定圖片高度
resize_width = 100 # 指定圖片寬度
def freeze_graph(input_checkpoint, output_graph):
'''
:param input_checkpoint:
:param output_graph: PB 模型保存路徑
:return:
'''
# 檢查目錄下ckpt文件狀態(tài)是否可用
# checkpoint = tf.train.get_checkpoint_state(model_folder)
# 得ckpt文件路徑
# input_checkpoint = checkpoint.model_checkpoint_path
# 指定輸出的節(jié)點名稱,該節(jié)點名稱必須是元模型中存在的節(jié)點
output_node_names = "predictions"
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
graph = tf.get_default_graph() # 獲得默認的圖
input_graph_def = graph.as_graph_def() # 返回一個序列化的圖代表當前的圖
with tf.Session() as sess:
saver.restore(sess, input_checkpoint) # 恢復圖并得到數(shù)據(jù)
# 模型持久化,將變量值固定
output_graph_def = graph_util.convert_variables_to_constants(
sess=sess,
# 等于:sess.graph_def
input_graph_def=input_graph_def,
# 如果有多個輸出節(jié)點,以逗號隔開
output_node_names=output_node_names.split(","))
# 保存模型
with tf.gfile.GFile(output_graph, "wb") as f:
f.write(output_graph_def.SerializeToString()) # 序列化輸出
# 得到當前圖有幾個操作節(jié)點
print("%d ops in the final graph." % len(output_graph_def.node))
input_checkpoint='easy_model/model.ckpt'
# 輸出pb模型的路徑
out_pb_path="easy_model/frozen_model.pb"
freeze_graph(input_checkpoint,out_pb_path)
這里需要輸入輸出節(jié)點名稱output_node_names= "predictions",
程序運行之后在easy_model文件夾下生成了frozen_model.pb文件,我們通過Netron軟件查看一下frozen_model.pb網(wǎng)絡圖,可以發(fā)現(xiàn)跟第一節(jié)中直接生成的.pb模型文件一樣的:
cpkt文件格式將模型保存為4個文件,pb文件格式為一個。ckpt模型持久化方式將圖結(jié)構(gòu)與權(quán)重參數(shù)分開保存,多了模型更多的細節(jié),適合模型訓練階段;而pb持久化方式完成了從輸入到輸出的前向傳播,完成了端到端的形式,更適合離線使用。
4)最后,我們將.pb文件轉(zhuǎn)換為.tflite:
我們運行此段代碼:
function showSnackbar() {
var $snackbar = $('#snackbar');
$snackbar.addClass('show');
setTimeout(() => {
$snackbar.removeClass('show');
}, 3000);
}
注意這里需要寫上輸入輸出節(jié)點名稱,這個在構(gòu)建網(wǎng)絡模型時,已經(jīng)定義。
運行之后,會報錯:
module 'tensorflow' has no attribute 'lite'
意思是目前的tensorflow版本不支持lite類,沒關(guān)系,我們重新安裝1.14版本的tensorflow即可成功運行。最后生成tflite格式的模型文件:easy_model.lite。
03、運行驗證環(huán)節(jié)
將轉(zhuǎn)換后模型運行到OK1052-C開發(fā)板進行驗證
首先需要將將easy_model.lite轉(zhuǎn)換為二進制數(shù)組的.h文件easy_frozen_0.h,轉(zhuǎn)換完成之后,將其放入OK1052-C用戶資料的,SDK的middlewareeiqtensorflow-liteexampleslabel_image目錄下,
打開SDK中boardsevkbimxrt1050eiq_examplestensorflow_lite_label_image下工程,在文件label_iamge.cpp中做如下修改:
#include "board.h"
#include "pin_mux.h"
#include "clock_config.h"
#include "fsl_debug_console.h"
#include
#include
#include
#include "timer.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/optional_debug_tools.h"
#include "tensorflow/lite/string_util.h"
//#include "Sine_mode.h"
//#include "add_model.h"
#include "easy_frozen_0.h"
int inference_count = 0;
// This is a small number so that it's easy to read the logs
const int kInferencesPerCycle = 30;
const float kXrange = 2.f * 3.14159265359f;
#define LOG(x) std::cout
void RunInference()
{
std::unique_ptr
std::unique_ptr
model = tflite::FlatBufferModel::BuildFromBuffer(mobilenet_model, mobilenet_model_len);
if (!model) {
LOG(FATAL) << "Failed to load modelrn";
exit(-1);
}
model->error_reporter();
tflite::ops::builtin::BuiltinOpResolver resolver;
tflite::InterpreterBuilder(*model, resolver)(&interpreter);
if (!interpreter) {
LOG(FATAL) << "Failed to construct interpreterrn";
exit(-1);
}
float input = interpreter->inputs()[0];
if (interpreter->AllocateTensors() != kTfLiteOk) {
LOG(FATAL) << "Failed to allocate tensors!rn";
}
while(true)
{
// Calculate an x value to feed into the model. We compare the current
// inference_count to the number of inferences per cycle to determine
// our position within the range of possible x values the model was
// trained on, and use this to calculate a value.
float position = static_cast
static_cast
float x_val = position * kXrange;
float* input_tensor_data = interpreter->typed_tensor
*input_tensor_data = x_val;
// Delay_time(1000);
// Run inference, and report any error
TfLiteStatus invoke_status = interpreter->Invoke();
if (invoke_status != kTfLiteOk)
{
LOG(FATAL) << "Failed to invoke tflite!rn";
return;
}
// Read the predicted y value from the model's output tensor
float* y_val = interpreter->typed_output_tensor
PRINTF("rn x_value: %f, y_value: %f rn", x_val, y_val[0]);
//PRINTF("rn x_value: %d, y_value: %d rn", (int)x_val, (int)y_val[0]);
// Increment the inference_counter, and reset it if we have reached
// the total number per cycle
inference_count += 1;
if (inference_count >= kInferencesPerCycle) inference_count = 0;
}
}
/*
* @brief Application entry point.
*/
int main(void)
{
/* Init board hardware */
BOARD_ConfigMPU();
BOARD_InitPins();
BOARD_InitDEBUG_UARTPins();
BOARD_BootClockRUN();
BOARD_InitDebugConsole();
NVIC_SetPriorityGrouping(3);
InitTimer();
std::cout << "The hello_world demo of TensorFlow Lite modelrn";
RunInference();
std::flush(std::cout);
for (;;) {}
}
此工程運行之后打印信息如下:
可以看到,輸入的值x_value通過模型計算之后得到y(tǒng)_value
使用的就是計算網(wǎng)絡中的公式:
predictions = (input_holder * W1) + B1 + 10
由此可知,我們模型轉(zhuǎn)換成功。