背景
目前keras框架使用简单,很容易上手,深得广大算法工程师的喜爱,但是当部署到客户端时,可能会出现各种各样的bug,甚至不支持使用keras,本文来解决的是将keras的h5模型转换为客户端常用的tensorflow的pb模型并使用tensorflow加载pb模型。
将keras的h5转为tensorflow加载pb
🎋. 网上各种给的代码乱七八糟,直接用下面这个链接的程序转就好了:https://github.com/amir-abdi/keras_to_tensorflow
🎋. keras中:
1 2 3 4
| cnn.save("model.h5") cnn.save_weights("model_weights.h5") cnn = load_model("model.h5") cnn.load_weights("model_weights.h5")
|
🎋. 在python中的使用:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
| def recognize(input_1, input_2): with tf.Graph().as_default(): output_graph_def = tf.GraphDef()
with open('model.pb', "rb") as f: output_graph_def.ParseFromString(f.read()) _ = tf.import_graph_def(output_graph_def, name="")
tensor_name = [tensor.name for tensor in output_graph_def.node] print(tensor_name)
with tf.Session() as sess: init = tf.global_variables_initializer() sess.run(init)
input_x1 = sess.graph.get_tensor_by_name("input_1:0") input_x2 = sess.graph.get_tensor_by_name("input_2:0") output = sess.graph.get_tensor_by_name("dense_3/Sigmoid:0") print(sess.run(output, feed_dict={input_x1:input_1, input_x2:input_2}))
|
参考代码:https://blog.csdn.net/Butertfly/article/details/80952987