将keras的h5文件转为tensorflow的pb文件

背景

目前keras框架使用简单,很容易上手,深得广大算法工程师的喜爱,但是当部署到客户端时,可能会出现各种各样的bug,甚至不支持使用keras,本文来解决的是将keras的h5模型转换为客户端常用的tensorflow的pb模型并使用tensorflow加载pb模型。

将keras的h5转为tensorflow加载pb

🎋. 网上各种给的代码乱七八糟,直接用下面这个链接的程序转就好了:https://github.com/amir-abdi/keras_to_tensorflow

🎋. keras中:

1
2
3
4
cnn.save("model.h5")   # 保存模型和权值
cnn.save_weights("model_weights.h5") # 保存权值
cnn = load_model("model.h5") # 导入模型和权值
cnn.load_weights("model_weights.h5") # 导入权值

🎋. 在python中的使用:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def recognize(input_1, input_2):
with tf.Graph().as_default():
output_graph_def = tf.GraphDef()

with open('model.pb', "rb") as f:
output_graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(output_graph_def, name="")

# 这里是看pb文件中保存的模型的每一层的名字
tensor_name = [tensor.name for tensor in output_graph_def.node]
print(tensor_name)

with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)

input_x1 = sess.graph.get_tensor_by_name("input_1:0")
input_x2 = sess.graph.get_tensor_by_name("input_2:0")
output = sess.graph.get_tensor_by_name("dense_3/Sigmoid:0")
print(sess.run(output, feed_dict={input_x1:input_1, input_x2:input_2}))

参考代码:https://blog.csdn.net/Butertfly/article/details/80952987