Tensorflow C++ 从训练到部署(2):简单图的保存、读取与 CMake 编译

经过了 上一篇文章,我们已经成功编译了 tensorflow c++ 的系统库文件并且安装到系统目录下了。这里我们将使用这个编译好的库进行基本的 C++ 模型加载执行等操作。
注意,在本篇文章会使用 Tensorflow 的 Python API,因为比较简单,这里不做介绍,安装详见官网教程:
https://www.tensorflow.org/install/

0、系统环境
Ubuntu 16.04
Tensorflow 1.10.1 (安装详见官网,建议使用 pip 方式安装)

1、一个简单网络的保存
只有 c = a * b 的网络:

在这段代码中,我们构建了一个非常简单的“网络”:
\(c = a * b\)
并且给 a 和 b 赋予了初值。这一网络虽然简单,但是更复杂的网络也是类似的道理,我们需要保存的无非是计算图的结构,和图中的参数。这里边我们 tf.train.write_graph 保存的仅仅是图的结构。注意其中每个 placeholder 的 name 非常重要,这是我们后面输入和获取这些值的基础。

在这一例子中,有两点注意:
a)tf.train.write_graph 函数就是指定了保存图的路径,as_text=False 表明用二进制格式保存(默认是文本格式保存)。
b)res = sess.run(c, feed_dict={‘a:0’: 2.0, ‘b:0’: 3.0}) 中,我们定义的 placeholder 的名字是 a 和 b,但是对于 tensorflow 来说 shape=None 的相当于一个 1×1 的向量,所以我们还是要指定一个下标表示 a 和 b 是这个向量的第一个元素,这里 tensorflow 用 a:0 和 b:0 表示。如果 a 和 b 本来就是一个多维向量,那么可以直接取出整个向量。

我们将这段代码保存为 simple/simple_net.py 文件并执行:

正常情况下会在 model 文件夹下保存一个 simple.pb 的文件,并且输出一个测试运行的结果:
res = 6.0

PS:你也可以指定为文本格式保存,通常文本格式我们保存为 .pbtxt 后缀,例如:

保存出来就是这样的,看起来很简单,不过通常我们还是用二进制格式比较高效:

2、使用 Python 读取这个网络
首先我们来看下用 Python 如何读取这个网络并进行计算:

我们将这段代码保存为 simple/load_simple_net.py 文件并执行:

如果运行成功的话,可以看到输出结果:
res = 6.0

3、使用 C++ 读取这个网络
现在我们使用 C++ API 来读取这个图并使用图进行计算:

将这一文件保存为 simple/load_simple_net.cpp。

4、编译运行 Tensorflow C++ API
1)编写 CMakeLists.txt
首先我们编写一个 CMakeLists.txt 来编译这个文件:

将这一文件保存为 CMakeLists.txt。

2)下载 eigen3
为了避免产生问题,我们这里要依赖自己的 eigen3 而不要用 tensorflow 下载的 eigen3,我们下载一个 eigen3 并放在 third_party/eigen3 目录,这里我使用的是 3.3.5 版本:
http://bitbucket.org/eigen/eigen/get/3.3.5.tar.bz2

现在我们将所有文件创建完毕,你的目录结构应该是这样的:
screenshot-from-2018-09-19-17-06-47

3)编译 & 运行
在工程根目录下使用如下命令编译运行:

然后在 build 目录运行:

如果运行成功会显示:
Session created successfully
Load graph protobuf successfully
Add graph to session successfully
Run session successfully
Tensor
output value: 6

到此你已经成功用 C++ API 加载了之前定义的图(尽管只是一个乘法)并利用这个图进行了计算。

后面我们将在此基础上尝试真正的神经网络的运算。

本文完整代码参见 Github:
https://github.com/skylook/tensorflow_cpp

参考文献
[1] https://github.com/formath/tensorflow-predictor-cpp
[2] https://github.com/zhangcliff/tensorflow-c-mnist
[3] https://blog.csdn.net/ztf312/article/details/72859075

About skylook

增强现实、图像识别技术爱好者。
This entry was posted in tensorflow. Bookmark the permalink.

发表评论

电子邮件地址不会被公开。




Optimization WordPress Plugins & Solutions by W3 EDGE