TensorFlow模型保存总结

事先声明,以下大部分内容来源于tensorflow 模型导出总结,并加上个人见解。

tensorflow 1.0 以及2.0 提供了多种不同的模型导出格式,例如说有checkpoint,SavedModel,Frozen GraphDef,Keras model(HDF5) 以及用于移动端,嵌入式的TFLite。

模型导出主要包含了:参数以及网络结构的导出,不同的导出格式可能是分别导出,或者是整合成一个独立的文件。

  • 参数和网络结构分开保存:checkpoint, SavedModel
  • 只保存权重:HDF5(可选)
  • 参数和网络结构保存在一个文件:Frozen GraphDef,HDF5(可选)

在tensorflow 1.0中,可以见下图,主要有三种主要的API:Keras、Estimator以及Legacy即最初的session模型,其中tf.Keras主要保存为HDF5,Estimator保存为SavedModel,而Lagacy主要保存的是Checkpoint,并且可以通过freeze_graph,将模型变量冻结,得到Frozen GradhDef的文件。这三种格式的模型,都可以通过TFLite Converter导出为 .tflite 的模型文件,用于安卓/ios/嵌入式设备的serving。

img

在tensorflow 2.0中,推荐使用SavedModel进行模型的保存,所以keras默认导出格式是SavedModel,也可以通过显性使用 .h5 后缀,使得保存的模型格式为HDF5 。 此外其他low level API,都支持导出为SavedModel格式,以及Concrete Functions。Concrete Function是一个签名函数,有固定格式的输入和输出。 最终转化成Flatbuffer,服务端运行结束。

checkpoint

checkpint(CKPT)的导出是网络结构和参数权重分开保存的。其组成:

1
2
3
4
5
6
7
8
9
checkpoint # 列出该目录下,保存的所有的checkpoint列表,下面有具体的例子
├── events.out.tfevents.1583930869.prod-cloudserver-gpu169 # tensorboad可视化所需文件,可以直观看出模型的结构
'''
model.ckpt-13000表示前缀,代表第13000 global steps时的保存结果,我们在指定checkpoint加载时,也只需要说明前缀即可。
你可以只用 .ckpt-meta 和 .ckpt-data 恢复一个模型
'''
├── model.ckpt-13000.index # 可能是内部需要的某种索引来正确映射前两个文件,它通常不是必需的
├── model.ckpt-13000.data-00000-of-00001 # 包含所有变量的值,没有结构
├── model.ckpt-13000.meta # 包含元图,即计算图的结构,不一定含有变量,就算有变量也没有变量的值(基本上你可以在tensorboard/graph中看到)。

所以一个checkpoint 组成是由两个部分,三个文件组成,其中网络结构部分(meta文件),以及参数部分(参数名:index,参数值:data)

其中checkpoint文件中

1
2
3
4
5
6
model_checkpoint_path: "model.ckpt-16329"
all_model_checkpoint_paths: "model.ckpt-13000"
all_model_checkpoint_paths: "model.ckpt-14000"
all_model_checkpoint_paths: "model.ckpt-15000"
all_model_checkpoint_paths: "model.ckpt-16000"
all_model_checkpoint_paths: "model.ckpt-16329"

使用tensorboard --logdir PATH_TO_CHECKPOINT --host=127.0.0.1: tensorboard 会调用最新的events.out.tfevents.*文件,并生成tensorboard,例如下图:

img

导出成CKPT

  • tensorflow 1.0
1
2
3
4
5
6
# in tensorflow 1.0
saver = tf.train.Saver()
saver.save(sess=session, save_path=args.save_path)

# 若不想保存meta文件
saver2save(sess=session, save_path=args.save_path, write_meta_graph=False)
  • estimator
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# estimator
"""
通过 RunConfig 配置多少时间或者多少个steps 保存一次模型,默认600s 保存一次。
具体参考 https://zhuanlan.zhihu.com/p/112062303
"""
run_config = tf.estimator.RunConfig(
model_dir=FLAGS.output_dir, # 模型保存路径
session_config=config,
save_checkpoints_steps=FLAGS.save_checkpoints_steps, # 多少steps保存一次ckpt
keep_checkpoint_max=1)
estimator = tf.estimator.Estimator(
model_fn=model_fn,
config=run_config,
params=None
)

关于estimator的介绍可以参考:https://zhuanlan.zhihu.com/p/112062303

加载CKPT

  • tf1.0
    ckpt加载的脚本如下,加载完后,session就会是保存的ckpt了。
1
2
3
4
5
# tf1.0
session = tf.Session()
session.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(sess=session, save_path=args.save_path) # 读取保存的模型
  • 对于estimator 会自动load output_dir 中的最新的ckpt。
  • 我们常用的model_file = tf.train.latest_checkpoint(FLAGS.output_dir) 获取最新的ckpt

Meta文件分析

TensorFlow的Meta文件不一定含有变量,也就是使用tf.global_variables()或者tf.trainable_variables()均返回的是空list。

被这个问题困扰了很久,但是一直查不到结果,我就从TensorFlow的源码入手分析一下。我调试了saver.restore的函数,最终定位到了meta_graph.py#L824

相关代码如下:

image-20220410162903108

代码中有一个函数为meta_graph_def.collection_def,若这里面有trainable_variables才可以,但是我遇到了一个meta文件,meta_graph_def.collection_def为空,也就没有了变量。但是这个文件咋来的,还不是很清楚。需要进一步探索。

SavedModel

SavedModel 格式是tensorflow 2.0 推荐的格式,他很好地支持了tf-serving等部署,并且可以简单被python,java等调用。

一个 SavedModel 包含了一个完整的 TensorFlow program, 包含了 weights 以及 计算图 computation. 它不需要原本的模型代码就可以加载所以很容易在 TFLite, TensorFlow.js, TensorFlow Serving, or TensorFlow Hub 上部署。

通常SavedModel由以下几个部分组成

1
2
3
4
5
6
├── assets/ # 所需的外部文件,例如说初始化的词汇表文件,一般无
├── assets.extra/ # TensorFlow graph 不需要的文件, 例如说给用户知晓的如何使用SavedModel的信息. Tensorflow 不使用这个目录下的文件。
├── saved_model.pb # 保存的是MetaGraph的网络结构, 或者说是saved_model.pbtxt
├── variables # 参数权重,包含了所有模型的变量(tf.Variable objects)参数
├── variables.data-00000-of-00001
└── variables.index

补充pb格式说明:GraphDef(*.pb)格式文件包含 protobuf 对象序列化后的数据,包含了计算图,可以从中得到所有运算符(operators)的细节,也包含张量(tensors)和 Variables 定义,但不包含 Variable 的值,因此只能从中恢复计算图,但一些训练的权值仍需要从 checkpoint 中恢复。

TensorFlow 一些例程中用到 *.pb 文件作为预训练模型,这和上面 GraphDef 格式稍有不同,属于冻结(Frozen)后的 GraphDef 文件,简称 FrozenGraphDef 格式。这种文件格式不包含 Variables 节点。将 GraphDef 中所有 Variable 节点转换为常量(其值从 checkpoint 获取),就变为 FrozenGraphDef 格式。代码可以参考 tensorflow/python/tools/freeze_graph.py

*.pb 为二进制文件,实际上 protobuf 也支持文本格式(*.pbtxt),但包含权值时文本格式会占用大量磁盘空间,一般不用。

导出为SavedModel

  • tf 1.0 方式
1
2
3
4
5
6
7
8
"""tf1.0"""
x = tf.placeholder(tf.float32, [None, 784], name="myInput")
y = tf.nn.softmax(tf.matmul(x, W) + b, name="myOutput")
tf.saved_model.simple_save(
sess,
export_dir,
inputs={"myInput": x},
outputs={"myOutput": y})

simple_save 是对于普通的tf 模型导出的最简单的方式,只需要补充简单的必要参数,有很多参数被省略,其中被省略的最重要的参数是tag(在下面saved_model.builder.SavedModelBuilder会介绍):tag 是用来区别不同的 MetaGraphDef,这是在加载模型所需要的参数。其默认值是tag_constants.SERVING (“serve”)。对于某些节点,如果没有办法直接加name,那么可以采用 tf.identity, 为节点加名字,例如说CRF的输出,以及使dataset后,无法直接加input的name,都可以采用这个方式:

1
2
def addNameToTensor(someTensor, theName):
return tf.identity(someTensor, name=theName)
  • estimator 方式
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
"""estimator"""
def serving_input_fn():
label_ids = tf.placeholder(tf.int32, [None], name='label_ids')
input_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_ids')
input_mask = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_mask')
segment_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='segment_ids')
input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
'label_ids': label_ids,
'input_ids': input_ids,
'input_mask': input_mask,
'segment_ids': segment_ids,
})
return input_fn

if do_export:
estimator._export_to_tpu = False
estimator.export_saved_model(Flags.export_dir, serving_input_fn)
  • 保存多个 MetaGraphDef's,使用到了tag
1
2
3
4
5
6
7
8
9
10
11
12
import tensorflow.python.saved_model
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def
builder = saved_model.builder.SavedModelBuilder(export_path)

signature = predict_signature_def(inputs={'myInput': x},
outputs={'myOutput': y})
""" using custom tag instead of: tags=[tag_constants.SERVING] """
builder.add_meta_graph_and_variables(sess=sess,
tags=["myTag"],
signature_def_map={'predict': signature})
builder.save()
  • ckpt转SavedModel
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def get_saved_model(bert_config, num_labels, use_one_hot_embeddings):
tf_config = tf.compat.v1.ConfigProto()
tf_config.gpu_options.allow_growth = True

model_file = tf.train.latest_checkpoint(FLAGS.output_dir)
with tf.Graph().as_default(), tf.Session(config=tf_config) as tf_sess:
label_ids = tf.placeholder(tf.int32, [None], name='label_ids')
input_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_ids')
input_mask = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_mask')
segment_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='segment_ids')

loss, per_example_loss, probabilities, predictions = \
create_model(bert_config, False, input_ids, input_mask, segment_ids, label_ids,
num_labels, use_one_hot_embeddings)
saver = tf.train.Saver()
print("restore;{}".format(model_file))
saver.restore(tf_sess, model_file)
tf.saved_model.simple_save(tf_sess,
FLAGS.output_dir,
inputs={
'label_ids': label_ids,
'input_ids': input_ids,
'input_mask': input_mask,
'segment_ids': segment_ids,
},
outputs={"probabilities": probabilities})
  • frozen graph to savedModel。注意这个方法我失败了,variables文件夹下面没有东西。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import tensorflow as tf
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants

export_dir = 'inference/pb2saved'
graph_pb = 'inference/robert_tiny_clue/frozen_model.pb'

builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

with tf.gfile.GFile(graph_pb, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())

sigs = {}

with tf.Session(graph=tf.Graph()) as sess:
# name="" is important to ensure we don't get spurious prefixing
tf.import_graph_def(graph_def, name="")
g = tf.get_default_graph()
input_ids = sess.graph.get_tensor_by_name(
"input_ids:0")
input_mask = sess.graph.get_tensor_by_name(
"input_mask:0")
segment_ids = sess.graph.get_tensor_by_name(
"segment_ids:0")
probabilities = g.get_tensor_by_name("loss/pred_prob:0")

sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
tf.saved_model.signature_def_utils.predict_signature_def(
{
"input_ids": input_ids,
"input_mask": input_mask,
"segment_ids": segment_ids
}, {
"probabilities": probabilities
})

builder.add_meta_graph_and_variables(sess,
[tag_constants.SERVING],
signature_def_map=sigs)

builder.save()
  • tf.keras 2.0
1
2
model.save('saved_model/my_model')  
"""saved as SavedModel by default"""

加载SavedModel

对于在java中加载SavedModel,我们首先需要知道我们模型输入和输出,可以通过以下的脚本在terminal中运行 saved_model_cli show --dir SavedModel路径 --tag_set serve --signature_def serving_default 得到类似以下的结果

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['serving_default']:
The given SavedModel SignatureDef contains the following input(s):
inputs['input_ids'] tensor_info:
dtype: DT_INT32
shape: (-1, 128)
name: input_ids:0
inputs['input_mask'] tensor_info:
dtype: DT_INT32
shape: (-1, 128)
name: input_mask:0
inputs['label_ids'] tensor_info:
dtype: DT_INT32
shape: (-1)
name: label_ids:0
inputs['segment_ids'] tensor_info:
dtype: DT_INT32
shape: (-1, 128)
name: segment_ids:0
The given SavedModel SignatureDef contains the following output(s):
outputs['probabilities'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 7)
name: loss/pred_prob:0
Method name is: tensorflow/serving/predict

首先我们可以看到有inputs,以及outputs,分别是一个key为string,value为tensor的字典,每个tensor都有各自的名字。

当然我们可以通过saved_model_cli show --dir SavedModel路径 --all得到所有的结果,包含了Concrete Functions

Python 加载

我们有常见两种方式可以加载savedModel,一种是采用 tf.contrib.predictor.from_saved_model 传入predictor模型的inputs dict,然后得到 outputs dict。 一种是直接类似tf1.0的方式,采用 tf.saved_model.loader.load, feed tensor然后fetch tensor。

  • 采用predictor

    采用predictor时, 需要传入的字典名字用的是 inputs的key,而不是tensor的names

1
2
3
4
5
6
7
8
predict_fn = tf.contrib.predictor.from_saved_model(args_in_use.model)
# 其中feature.xxxxxx 应该是需要feed_dict的数据
prediction = predict_fn({
"input_ids": [feature.input_ids],
"input_mask": [feature.input_mask],
"segment_ids": [feature.segment_ids],
})
probabilities = prediction["probabilities"]
  • tf 1.0 采用 loader

    采用loader的方式是采用 session 的feed_dict 方式,该方式feed的是tenor的names,fetch的同样也是tensor 的names。其中feed_dict的key 可以直接是tensor的name,或者是采用 sess.graph.get_tensor_by_name(TENSOR_NAME) 得到的tensor。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
with tf.Session(graph=tf.Graph()) as sess:
tf.saved_model.loader.load(sess, ["serve"], export_path)
graph = tf.get_default_graph()
feed_dict = {"input_ids_1:0": [feature.input_ids],
"input_mask_1:0": [feature.input_mask],
"segment_ids_1:0": [feature.segment_ids]}
"""
# alternative way
feed_dict = {sess.graph.get_tensor_by_name("input_ids_1:0"):
[feature.input_ids],
sess.graph.get_tensor_by_name("input_mask_1:0"):
[feature.input_mask],
sess.graph.get_tensor_by_name("segment_ids_1:0"):
[feature.segment_ids]}
"""
sess.run('loss/pred_prob:0',
feed_dict=feed_dict)
  • tf.keras 2.0
1
new_model = tf.keras.models.load_model('saved_model/my_model')

JAVA 加载

注意 java加载的时候,如果遇到Op not defined 的错误,是需要匹配模型训练python的tensorflow版本以及java的tensorflow版本的。

所以我们知道我们在tag-set 为serve的tag下,有4个inputs tensors,name分别为input_ids:0, input_mask:0, label_ids:0, segment_ids:0, 输出为1个,name是 loss/pred_prob:0。并且我们知道这些tensor的类型。

所以我们可以通过下面的java代码,进行加载,获得结果。注意我们需要传入的name中不需要传入:0

1
2
3
4
5
6
7
8
9
10
import org.tensorflow.*
SavedModelBundle savedModelBundle = SavedModelBundle.load("./export_path", "serve");
Graph graph = savedModelBundle.graph();

Tensor tensor = this.savedModelBundle.session().runner()
.feed("input_ids", inputIdTensor)
.feed("input_mask", inputMaskTensor)
.feed("segment_ids", inputSegmentTensor)
.fetch("loss/pred_prob")
.run().get(0);

CLI 加载

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
$ saved_model_cli show --dir export/1524906774 \
--tag_set serve --signature_def serving_default
The given SavedModel SignatureDef contains the following input(s):
inputs['inputs'] tensor_info:
dtype: DT_STRING
shape: (-1)
The given SavedModel SignatureDef contains the following output(s):
outputs['classes'] tensor_info:
dtype: DT_STRING
shape: (-1, 3)
outputs['scores'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 3)
Method name is: tensorflow/serving/classify

$ saved_model_cli run --dir export/1524906774 \
--tag_set serve --signature_def serving_default \
--input_examples 'inputs=[{"SepalLength":[5.1],"SepalWidth":[3.3],"PetalLength":[1.7],"PetalWidth":[0.5]}]'
Result for output key classes:
[[b'0' b'1' b'2']]
Result for output key scores:
[[9.9919027e-01 8.0969761e-04 1.2872645e-09]]

Frozen Graph

frozen Graphdef 将tensorflow导出的模型的权重都freeze住,使得其都变为常量。并且模型参数和网络结构保存在同一个文件中,可以在python以及java中自由调用。

导出为pb

python

  • 采用session方式保存frozen graph
1
2
3
4
5
6
7
8
"""tf1.0"""
from tensorflow.python.framework.graph_util import convert_variables_to_constants

output_graph_def = convert_variables_to_constants(
session,
session.graph_def,
output_node_names=['loss/pred_prob'])
tf.train.write_graph(output_graph_def, args.export_dir, args.model_name, as_text=False)
  • 采用ckpt 转换成frozen graph
    以下采用bert tensorflow模型做演示
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
"""
NB:首先我们要在create_model() 函数中,为我们需要的输出节点取个名字,
比如说我们要: probabilities = tf.nn.softmax(logits, axis=-1, name='pred_prob')
"""

def get_frozen_model(bert_config, num_labels, use_one_hot_embeddings):
tf_config = tf.compat.v1.ConfigProto()
tf_config.gpu_options.allow_growth = True
output_node_names = ['loss/pred_prob']

model_file = tf.train.latest_checkpoint(FLAGS.output_dir)
with tf.Graph().as_default(), tf.Session(config=tf_config) as tf_sess:
label_ids = tf.placeholder(tf.int32, [None], name='label_ids')
input_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_ids')
input_mask = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_mask')
segment_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='segment_ids')

create_model(bert_config, False, input_ids, input_mask, segment_ids, label_ids,
num_labels, use_one_hot_embeddings)
saver = tf.train.Saver()
print("restore;{}".format(model_file))
saver.restore(tf_sess, model_file)
tmp_g = tf_sess.graph.as_graph_def()
if FLAGS.use_opt:
input_tensors = [input_ids, input_mask, segment_ids]
dtypes = [n.dtype for n in input_tensors]
print('optimize...')
tmp_g = optimize_for_inference(tmp_g,
[n.name[:-2] for n in input_tensors],
output_node_names,
[dtype.as_datatype_enum for dtype in dtypes],
False)
print('freeze...')
frozen_graph = tf.graph_util.convert_variables_to_constants(tf_sess,
tmp_g, output_node_names)
out_graph_path = os.path.join(FLAGS.output_dir, "frozen_model.pb")
with tf.io.gfile.GFile(out_graph_path, "wb") as f:
f.write(frozen_graph.SerializeToString())
print(f'pb file saved in {out_graph_path}')
  • 采用savedModel 转换成 frozen graph
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from tensorflow.python.tools import freeze_graph
from tensorflow.python.saved_model import tag_constants

input_saved_model_dir = "./1583934987/"
output_node_names = "loss/pred_prob"
input_binary = False
input_saver_def_path = False
restore_op_name = None
filename_tensor_name = None
clear_devices = False
input_meta_graph = False
checkpoint_path = None
input_graph_filename = None
saved_model_tags = tag_constants.SERVING
output_graph_filename='frozen_graph.pb'

freeze_graph.freeze_graph(input_graph_filename,
input_saver_def_path,
input_binary,
checkpoint_path,
output_node_names,
restore_op_name,
filename_tensor_name,
output_graph_filename,
clear_devices,
"", "", "",
input_meta_graph,
input_saved_model_dir,
saved_model_tags)
  • HDF5 to pb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from keras import backend as K

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
"""
Freezes the state of a session into a pruned computation graph.

Creates a new computation graph where variable nodes are replaced by
constants taking their current value in the session. The new graph will be
pruned so subgraphs that are not necessary to compute the requested
outputs are removed.
@param session The TensorFlow session to be frozen.
@param keep_var_names A list of variable names that should not be frozen,
or None to freeze all the variables in the graph.
@param output_names Names of the relevant graph outputs.
@param clear_devices Remove the device directives from the graph for better portability.
@return The frozen graph definition.
"""
graph = session.graph
with graph.as_default():
freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
output_names = output_names or []
output_names += [v.op.name for v in tf.global_variables()]
input_graph_def = graph.as_graph_def()
if clear_devices:
for node in input_graph_def.node:
node.device = ""
frozen_graph = tf.graph_util.convert_variables_to_constants(
session, input_graph_def, output_names, freeze_var_names)
return frozen_graph

frozen_graph = freeze_session(K.get_session(),
output_names=[out.op.name for out in model.outputs])

tf.train.write_graph(frozen_graph, "some_directory", "my_model.pb", as_text=False)

CLI转换工具

以下的工具可以快速进行ckpt到pb的转换,但是不能再原本的基础上增加tensor 的名字。

1
2
3
4
5
6
freeze_graph --input_checkpoint model.ckpt-16329 \
--output_graph 0316_roberta.pb \
--output_node_names loss/pred_prob \
--checkpoint_version 1 \
--input_meta_graph model.ckpt-16329.meta \
--input_binary true

模型加载

获取frozen graph 中节点名字的脚本如下,但是一般来说,我们的inputs都是我们定义好的placeholders。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import tensorflow as tf

def printTensors(pb_file):

"""read pb into graph_def"""
with tf.gfile.GFile(pb_file, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())

"""import graph_def"""
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def)

"""print operations"""
for op in graph.get_operations():
print(op.name)

printTensors("path-to-my-pbfile.pb")

得到类似如下的结果

1
2
3
4
5
import/input_ids:0
import/input_mask:0
import/segment_ids:0
...
import/loss/pred_prob:0

当我们知道我们要feed以及fetch的节点名称之后,我们就可以通过python/java加载了。
跟savedModel一样,对于某些节点,如果没有办法直接加name,那么可以采用 tf.identity, 为节点加名字,例如说CRF的输出,以及使用dataset后,无法直接加input的name,都可以采用这个方式

1
2
def addNameToTensor(someTensor, theName):
return tf.identity(someTensor, name=theName)

Python 加载

我们保存完frozen graph 模型后,假设我们的模型包含以下的tensors:

那么我们通过python加载的代码如下, 采用的是session feed和fetch的方式。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
with tf.Graph().as_default():
output_graph_def = tf.GraphDef()

"""
load pb model
"""
with open(args_in_use.model, 'rb') as f:
output_graph_def.ParseFromString(f.read())
tf.import_graph_def(output_graph_def, name='') #name是必须的
"""
enter a text and predict
"""
with tf.Session() as sess:
tf.global_variables_initializer().run()
input_ids = sess.graph.get_tensor_by_name(
"input_ids:0")
input_mask = sess.graph.get_tensor_by_name(
"input_mask:0")
segment_ids = sess.graph.get_tensor_by_name(
"segment_ids:0")
output = "loss/pred_prob:0"

feed_dict = {
input_ids: [feature.input_ids],
input_mask: [feature.input_mask],
segment_ids: [feature.segment_ids],
}
# 也可以直接使用
# feed_dict = {
# "input_ids:0": [feature.input_ids],
# "input_mask:0": [feature.input_mask],
# "segment_ids:0": [feature.segment_ids],
# }
y_pred_cls = sess.run(output, feed_dict=feed_dict)

Java 加载

对于frozen graph,我们加载的方式和savedModel很类似,首先我们需要先启动一个session,然后在启动一个runner(),然后再feed模型的输入,以及fetch模型的输出。

注意 java加载的时候,如果遇到Op not defined 的错误,是需要匹配模型训练python的tensorflow版本以及java的tensorflow版本的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
// TensorUtil.class
public static Session generateSession(String modelPath) throws IOException {
Preconditions.checkNotNull(modelPath);
byte[] graphDef = ByteStreams.toByteArray(TensorUtil.class.getResourceAsStream(modelPath));
LOGGER.info("Graph Def Length: {}", graphDef.length);
Graph graph = new Graph();
graph.importGraphDef(graphDef);
return new Session(graph);
}

// model.class
this.session = TensorUtil.generateSession(modelPath);

Tensor tensor = this.session.runner()
.feed("input_ids", inputIdTensor)
.feed("input_mask", inputMaskTensor)
.feed("segment_ids", inputSegmentTensor)
.fetch("loss/pred_prob")
.run().get(0);

HDF5

HDF5 是keras or tf.keras 特有的存储格式。

HDF5导出

  • 导出整个模型
1
2
"""默认1.0 是HDF5,但是2.0中,是SavedModel,所以需要显性地指定`.h5`后缀"""
model.save('my_model.h5')
  • 导出模型weights
1
2
"""keras 1.0"""
model.save_weights('my_model_weights.h5')

HDF5加载

  • 加载整个模型(无自定义部分)

    • keras1.0
1
2
3
"""keras 1.0"""
from keras.models import load_model
model = load_model(model_path)
    • keras2.0
1
2
"""keras 2.0"""
new_model = tf.keras.models.load_model('my_model.h5')
  • 加载整个模型(含自定义部分)
    对于有自定义layers的或者实现的模型加载,需要增加dependencies 的映射字典,例如下面的例子:

    • keras1.0
1
2
dependencies = {'MyLayer': MyLayer(), 'auc': auc, 'log_loss': log_loss}
model = load_model(model_path, custom_objects=dependencies, compile=False)
    • keras 2.0
1
2
3
4
5
6
7
8
9
"""
To save custom objects to HDF5, you must do the following:

1. Define a get_config method in your object, and optionally a from_config classmethod.
get_config(self) returns a JSON-serializable dictionary of parameters needed to recreate the object.
from_config(cls, config) uses the returned config from get_config to create a new object. By default, this function will use the config as initialization kwargs (return cls(**config)).

2. Pass the object to the custom_objects argument when loading the model. The argument must be a dictionary mapping the string class name to the Python class. E.g. tf.keras.models.load_model(path, custom_objects={'CustomLayer': CustomLayer})
"""
  • 加载模型权重
    假设你有了相同的模型构建了,那么直接运行下面的代码,加载模型
1
model.load_weights('my_model_weights.h5')

如果你想要做transfer learning,即从其他的已保存的模型中加载部分的模型参数权重,自己目前的模型结构与保存的模型不同,可以通过参数的名字进行加载,加上 by_name=True

1
model.load_weights('my_model_weights.h5', by_name=True)

tfLite

TFlite转换

  • savedModel to TFLite
1
2
3
4
5
6
7
"""
--saved_model_dir: Type: string. Specifies the full path to the directory containing the SavedModel generated in 1.X or 2.X.
--output_file: Type: string. Specifies the full path of the output file.
"""
tflite_convert \
--saved_model_dir=1583934987 \
--output_file=rbt.tflite
  • frozen graph to TFLite
1
2
3
4
5
6
7
tflite_convert --graph_def_file albert_tiny_zh.pb \
--input_arrays 'input_ids,input_masks,segment_ids' \
--output_arrays 'finetune_mrc/add, finetune_mrc/add_1'\
--input_shapes 1,512:1,512:1,512 \
--output_file saved_model.tflite \
--enable_v1_converter \
--experimental_new_converter
  • HDF5 to TFLite
1
2
3
4
5
#--keras_model_file. Type: string. Specifies the full path of the HDF5 file containing the tf.keras model generated in 1.X or 2.X.   
#--output_file: Type: string. Specifies the full path of the output file.
tflite_convert \
--keras_model_file=h5_dir/ \
--output_file=rbt.tflite

另外,补充一个TFlite转frozen graph:

tensorflow在早期提供了转换工具(1.9版本后的tensorflow没有再提到这个功能了),具体操作可以看这里

有的模型TOCO工具可能会转换失败,可以参考这个链接

TFLite 加载

参考 https://www.tensorflow.org/lite/guide/inference
参考 https://github.com/tensorflow/t

这里介绍一个Python的加载。

  • 当从SavedModel转换得到,并且含有SignatureDef时:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class TestModel(tf.Module):
def __init__(self):
super(TestModel, self).__init__()

@tf.function(input_signature=[tf.TensorSpec(shape=[1, 10], dtype=tf.float32)])
def add(self, x):
'''
Simple method that accepts single input 'x' and returns 'x' + 4.
'''
# Name the output 'result' for convenience.
return {'result' : x + 4}


SAVED_MODEL_PATH = 'content/saved_models/test_variable'
TFLITE_FILE_PATH = 'content/test_variable.tflite'

# Save the model
module = TestModel()
# You can omit the signatures argument and a default signature name will be
# created with name 'serving_default'.
tf.saved_model.save(
module, SAVED_MODEL_PATH,
signatures={'my_signature':module.add.get_concrete_function()})

# Convert the model using TFLiteConverter
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_PATH)
tflite_model = converter.convert()
with open(TFLITE_FILE_PATH, 'wb') as f:
f.write(tflite_model)

# Load the TFLite model in TFLite Interpreter
interpreter = tf.lite.Interpreter(TFLITE_FILE_PATH)
# There is only 1 signature defined in the model,
# so it will return it by default.
# If there are multiple signatures then we can pass the name.
my_signature = interpreter.get_signature_runner()

# my_signature is callable with input as arguments.
output = my_signature(x=tf.constant([1.0], shape=(1,10), dtype=tf.float32))
# 'output' is dictionary with all outputs from the inference.
# In this case we have single output 'result'.
print(output['result'])
  • 当没有SignatureDef时
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import numpy as np
import tensorflow as tf

# Load the TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="converted_model.tflite")
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Test the model on random input data.
input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)

interpreter.invoke()

# The function `get_tensor()` returns a copy of the tensor data.
# Use `tensor()` in order to get a pointer to the tensor.
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)

参考

TensorFlow:.ckpt文件与.ckpt.meta和.ckpt.index以及.pb文件之间的关系是什么?
TF的三种模型的保存与加载方式
TensorFlow 到底有几种模型格式?
tensorflow 模型导出总结

------ 本文结束------
坚持原创技术分享,您的支持将鼓励我继续创作!

欢迎关注我的其它发布渠道