TensorFlow C API from Training to Deployment: Prediction and Deployment with the C API
The previous posts TensorFlow C++ from Training to Deployment (2): Saving, Loading, and CMake-Building a Simple Graph and TensorFlow C++ from Training to Deployment (3): Training and Deploying a CNN with Keras trained models with the TensorFlow/Keras Python API and ran prediction with the C++ API. The C++ API requires building TensorFlow from source, which is fairly cumbersome. TensorFlow, however, officially ships prebuilt C API library files, which makes deployment much easier: you simply copy the library files into your own project. This post covers prediction with the C API. The Python training part is the same as in the earlier posts and is not repeated here.
0. System Environment
Ubuntu 16.04
Tensorflow 1.12.0
1. Installing Dependencies
1) GPU support (optional)
CUDA 9.0
cuDNN 7.x
2) TensorFlow 1.12.0
Download page:
https://www.tensorflow.org/install/lang_c
The download links for 1.12.0 are listed below (I have mirrored several builds, including an aarch64 build for the TX2):
TensorFlow C library | CUDA | cuDNN | URL |
---|---|---|---|
Linux x86_64 CPU | x | x | https://pan.baidu.com/s/1FDdXCgtJJlDJP8ziDs6dow |
Linux x86_64 GPU | 9.0 | 7.x | https://pan.baidu.com/s/1qxDntkQ-rcgvp1xxrSKW0w |
macOS CPU | x | x | https://pan.baidu.com/s/1F6NdNtCxg11P_EpEdsqttA |
Linux aarch64 GPU (TX2) | 9.0 | 7.0.5 | https://pan.baidu.com/s/1mI76203wY9Nd5US4sH5-pg |
Extract the library into the third_party/libtensorflow directory.
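As a sketch of how a project might pick up the extracted library (the paths and target names here are assumptions; adjust them to your own layout), a minimal CMake fragment could look like this:

```cmake
# Sketch: point the build at the extracted C library.
# third_party/libtensorflow is assumed to contain include/ and lib/ after extraction.
set(TENSORFLOW_ROOT ${CMAKE_SOURCE_DIR}/third_party/libtensorflow)

include_directories(${TENSORFLOW_ROOT}/include)
link_directories(${TENSORFLOW_ROOT}/lib)

# Hypothetical executable built from the sources shown later in this post.
add_executable(load_simple_net_c_api simple/load_simple_net_c_api.cc utils/TFUtils.cpp)
target_link_libraries(load_simple_net_c_api tensorflow)
```

At runtime the loader must also be able to find libtensorflow.so, e.g. via LD_LIBRARY_PATH or an rpath pointing at the lib/ directory.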
If none of the builds above fit your needs, you can follow this article to build the version you need.
2. The TFUtils Helper Class
For simplicity, we first wrap the commonly used C API calls into a helper class, TFUtils.
1) File utils/TFUtils.hpp:
```cpp
// Licensed under the MIT License <http://opensource.org/licenses/MIT>.
// Copyright (c) 2018 Liu Xiao <liuxiao@foxmail.com> and Daniil Goncharov <neargye@gmail.com>.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

#pragma once

#if defined(_MSC_VER)
#  if !defined(COMPILER_MSVC)
#    define COMPILER_MSVC // Set MSVC visibility of exported symbols in the shared library.
#  endif
#  pragma warning(push)
#  pragma warning(disable : 4190)
#endif

#include <tensorflow/c/c_api.h>

#include <iostream>
#include <cstddef>
#include <cstdint>
#include <vector>
#include <string>

class TFUtils {
 public:
  enum STATUS {
    SUCCESS = 0,
    SESSION_CREATE_FAILED = 1,
    MODEL_LOAD_FAILED = 2,
    FAILED_RUN_SESSION = 3,
    MODEL_NOT_LOADED = 4,
  };

  TFUtils();

  STATUS LoadModel(std::string model_file);

  ~TFUtils();

  TF_Output GetOperationByName(std::string name, int idx);

  STATUS RunSession(const std::vector<TF_Output>& inputs,
                    const std::vector<TF_Tensor*>& input_tensors,
                    const std::vector<TF_Output>& outputs,
                    std::vector<TF_Tensor*>& output_tensors);

  // Static functions
  template <typename T>
  static TF_Tensor* CreateTensor(TF_DataType data_type,
                                 const std::vector<std::int64_t>& dims,
                                 const std::vector<T>& data) {
    return CreateTensor(data_type,
                        dims.data(), dims.size(),
                        data.data(), data.size() * sizeof(T));
  }

  static void DeleteTensor(TF_Tensor* tensor);

  static void DeleteTensors(const std::vector<TF_Tensor*>& tensors);

  template <typename T>
  static std::vector<std::vector<T>> GetTensorsData(const std::vector<TF_Tensor*>& tensors) {
    std::vector<std::vector<T>> data;
    data.reserve(tensors.size());
    for (const auto t : tensors) {
      data.push_back(GetTensorData<T>(t));
    }
    return data;
  }

  static TF_Tensor* CreateTensor(TF_DataType data_type,
                                 const std::int64_t* dims, std::size_t num_dims,
                                 const void* data, std::size_t len);

  template <typename T>
  static std::vector<T> GetTensorData(const TF_Tensor* tensor) {
    const auto data = static_cast<T*>(TF_TensorData(tensor));
    if (data == nullptr) {
      return {};
    }
    return {data, data + (TF_TensorByteSize(tensor) / TF_DataTypeSize(TF_TensorType(tensor)))};
  }

  // STATUS GetErrorCode();

  static void PrinStatus(STATUS status);

 private:
  TF_Graph* graph_def;
  TF_Session* sess;
  STATUS init_error_code;

 private:
  TF_Graph* LoadGraphDef(const char* file);

  TF_Session* CreateSession(TF_Graph* graph);

  bool CloseAndDeleteSession(TF_Session* sess);

  bool RunSession(TF_Session* sess,
                  const TF_Output* inputs, TF_Tensor* const* input_tensors, std::size_t ninputs,
                  const TF_Output* outputs, TF_Tensor** output_tensors, std::size_t noutputs);

  bool RunSession(TF_Session* sess,
                  const std::vector<TF_Output>& inputs,
                  const std::vector<TF_Tensor*>& input_tensors,
                  const std::vector<TF_Output>& outputs,
                  std::vector<TF_Tensor*>& output_tensors);
}; // End class TFUtils

#if defined(_MSC_VER)
#  pragma warning(pop)
#endif
```
2) File utils/TFUtils.cpp:
```cpp
// Licensed under the MIT License <http://opensource.org/licenses/MIT>.
// Copyright (c) 2018 Liu Xiao <liuxiao@foxmail.com> and Daniil Goncharov <neargye@gmail.com>.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

#if defined(_MSC_VER)
#  pragma warning(push)
#  pragma warning(disable : 4996)
#endif

#include "TFUtils.hpp"

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <string>

// Public functions
TFUtils::TFUtils() {
  init_error_code = MODEL_NOT_LOADED;
}

TFUtils::STATUS TFUtils::LoadModel(std::string model_file) {
  // Load graph
  graph_def = LoadGraphDef(model_file.c_str());
  if (graph_def == nullptr) {
    std::cerr << "loading model failed ......" << std::endl;
    init_error_code = MODEL_LOAD_FAILED;
    return MODEL_LOAD_FAILED;
  }

  // Create session
  sess = CreateSession(graph_def);
  if (sess == nullptr) {
    init_error_code = SESSION_CREATE_FAILED;
    std::cerr << "create sess failed ......" << std::endl;
    return SESSION_CREATE_FAILED;
  }

  init_error_code = SUCCESS;
  return init_error_code;
}

TFUtils::~TFUtils() {
  if (sess) CloseAndDeleteSession(sess);
  if (graph_def) TF_DeleteGraph(graph_def);
}

TF_Output TFUtils::GetOperationByName(std::string name, int idx) {
  return {TF_GraphOperationByName(graph_def, name.c_str()), idx};
}

TFUtils::STATUS TFUtils::RunSession(const std::vector<TF_Output>& inputs,
                                    const std::vector<TF_Tensor*>& input_tensors,
                                    const std::vector<TF_Output>& outputs,
                                    std::vector<TF_Tensor*>& output_tensors) {
  if (init_error_code != SUCCESS) return init_error_code;

  bool run_ret = RunSession(sess, inputs, input_tensors, outputs, output_tensors);
  if (run_ret == false) return FAILED_RUN_SESSION;

  return SUCCESS;
}

void TFUtils::PrinStatus(STATUS status) {
  switch (status) {
    case SUCCESS:
      std::cout << "status = SUCCESS" << std::endl;
      break;
    case SESSION_CREATE_FAILED:
      std::cout << "status = SESSION_CREATE_FAILED" << std::endl;
      break;
    case MODEL_LOAD_FAILED:
      std::cout << "status = MODEL_LOAD_FAILED" << std::endl;
      break;
    case FAILED_RUN_SESSION:
      std::cout << "status = FAILED_RUN_SESSION" << std::endl;
      break;
    case MODEL_NOT_LOADED:
      std::cout << "status = MODEL_NOT_LOADED" << std::endl;
      break;
    default:
      std::cout << "status = NOT FOUND" << std::endl;
  }
}

// Static functions
static void DeallocateBuffer(void* data, size_t) {
  std::free(data);
}

static TF_Buffer* ReadBufferFromFile(const char* file) {
  const auto f = std::fopen(file, "rb");
  if (f == nullptr) {
    return nullptr;
  }

  std::fseek(f, 0, SEEK_END);
  const auto fsize = ftell(f);
  std::fseek(f, 0, SEEK_SET);

  if (fsize < 1) {
    std::fclose(f);
    return nullptr;
  }

  const auto data = std::malloc(fsize);
  std::fread(data, fsize, 1, f);
  std::fclose(f);

  TF_Buffer* buf = TF_NewBuffer();
  buf->data = data;
  buf->length = fsize;
  buf->data_deallocator = DeallocateBuffer;

  return buf;
}

// Private functions
TF_Graph* TFUtils::LoadGraphDef(const char* file) {
  if (file == nullptr) {
    return nullptr;
  }

  TF_Buffer* buffer = ReadBufferFromFile(file);
  if (buffer == nullptr) {
    return nullptr;
  }

  TF_Graph* graph = TF_NewGraph();
  TF_Status* status = TF_NewStatus();
  TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();

  TF_GraphImportGraphDef(graph, buffer, opts, status);
  TF_DeleteImportGraphDefOptions(opts);
  TF_DeleteBuffer(buffer);

  if (TF_GetCode(status) != TF_OK) {
    TF_DeleteGraph(graph);
    graph = nullptr;
  }

  TF_DeleteStatus(status);

  return graph;
}

TF_Session* TFUtils::CreateSession(TF_Graph* graph) {
  TF_Status* status = TF_NewStatus();
  TF_SessionOptions* options = TF_NewSessionOptions();
  TF_Session* sess = TF_NewSession(graph, options, status);
  TF_DeleteSessionOptions(options);

  if (TF_GetCode(status) != TF_OK) {
    TF_DeleteStatus(status);
    return nullptr;
  }

  return sess;
}

bool TFUtils::CloseAndDeleteSession(TF_Session* sess) {
  TF_Status* status = TF_NewStatus();
  TF_CloseSession(sess, status);
  if (TF_GetCode(status) != TF_OK) {
    TF_CloseSession(sess, status);
    TF_DeleteSession(sess, status);
    TF_DeleteStatus(status);
    return false;
  }

  TF_DeleteSession(sess, status);
  if (TF_GetCode(status) != TF_OK) {
    TF_DeleteStatus(status);
    return false;
  }

  TF_DeleteStatus(status);

  return true;
}

bool TFUtils::RunSession(TF_Session* sess,
                         const TF_Output* inputs, TF_Tensor* const* input_tensors, std::size_t ninputs,
                         const TF_Output* outputs, TF_Tensor** output_tensors, std::size_t noutputs) {
  if (sess == nullptr ||
      inputs == nullptr || input_tensors == nullptr ||
      outputs == nullptr || output_tensors == nullptr) {
    return false;
  }

  TF_Status* status = TF_NewStatus();

  TF_SessionRun(sess,
                nullptr, // Run options.
                inputs, input_tensors, static_cast<int>(ninputs), // Input tensors, input tensor values, number of inputs.
                outputs, output_tensors, static_cast<int>(noutputs), // Output tensors, output tensor values, number of outputs.
                nullptr, 0, // Target operations, number of targets.
                nullptr, // Run metadata.
                status // Output status.
  );

  if (TF_GetCode(status) != TF_OK) {
    TF_DeleteStatus(status);
    return false;
  }

  TF_DeleteStatus(status);

  return true;
}

bool TFUtils::RunSession(TF_Session* sess,
                         const std::vector<TF_Output>& inputs,
                         const std::vector<TF_Tensor*>& input_tensors,
                         const std::vector<TF_Output>& outputs,
                         std::vector<TF_Tensor*>& output_tensors) {
  return RunSession(sess,
                    inputs.data(), input_tensors.data(), input_tensors.size(),
                    outputs.data(), output_tensors.data(), output_tensors.size());
}

TF_Tensor* TFUtils::CreateTensor(TF_DataType data_type,
                                 const std::int64_t* dims, std::size_t num_dims,
                                 const void* data, std::size_t len) {
  if (dims == nullptr || data == nullptr) {
    return nullptr;
  }

  TF_Tensor* tensor = TF_AllocateTensor(data_type, dims, static_cast<int>(num_dims), len);
  if (tensor == nullptr) {
    return nullptr;
  }

  void* tensor_data = TF_TensorData(tensor);
  if (tensor_data == nullptr) {
    TF_DeleteTensor(tensor);
    return nullptr;
  }

  std::memcpy(tensor_data, data, std::min(len, TF_TensorByteSize(tensor)));

  return tensor;
}

void TFUtils::DeleteTensor(TF_Tensor* tensor) {
  if (tensor == nullptr) {
    return;
  }
  TF_DeleteTensor(tensor);
}

void TFUtils::DeleteTensors(const std::vector<TF_Tensor*>& tensors) {
  for (auto t : tensors) {
    TF_DeleteTensor(t);
  }
}

#if defined(_MSC_VER)
#  pragma warning(pop)
#endif
```
3. Loading and Running the Simple Graph
In the earlier post TensorFlow C++ from Training to Deployment (2): Saving, Loading, and CMake-Building a Simple Graph we already walked through how a simple "network" computing c = a*b is evaluated.
The Python graph-construction and prediction parts are not repeated here; see that post for details. The C API code is given directly below:
File: simple/load_simple_net_c_api.cc
```cpp
// Licensed under the MIT License <http://opensource.org/licenses/MIT>.
// Copyright (c) 2018 Liu Xiao <liuxiao@foxmail.com> and Daniil Goncharov <neargye@gmail.com>.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

#include "../utils/TFUtils.hpp"

#include <iostream>
#include <vector>

int main(int argc, char* argv[]) {
  if (argc != 2) {
    std::cerr << std::endl << "Usage: ./project path_to_graph.pb" << std::endl;
    return 1;
  }

  std::string graph_path = argv[1];

  // TFUtils init
  TFUtils TFU;
  TFUtils::STATUS status = TFU.LoadModel(graph_path);
  if (status != TFUtils::SUCCESS) {
    std::cerr << "Can't load graph" << std::endl;
    return 1;
  }

  // Input Tensor Create
  const std::vector<std::int64_t> input_a_dims = {1, 1};
  const std::vector<float> input_a_vals = {2.0};

  const std::vector<std::int64_t> input_b_dims = {1, 1};
  const std::vector<float> input_b_vals = {3.0};

  const std::vector<TF_Output> input_ops = {TFU.GetOperationByName("a", 0),
                                            TFU.GetOperationByName("b", 0)};
  const std::vector<TF_Tensor*> input_tensors = {TFUtils::CreateTensor(TF_FLOAT, input_a_dims, input_a_vals),
                                                 TFUtils::CreateTensor(TF_FLOAT, input_b_dims, input_b_vals)};

  // Output Tensor Create
  const std::vector<TF_Output> output_ops = {TFU.GetOperationByName("c", 0)};
  std::vector<TF_Tensor*> output_tensors = {nullptr};

  status = TFU.RunSession(input_ops, input_tensors,
                          output_ops, output_tensors);

  TFUtils::PrinStatus(status);

  if (status == TFUtils::SUCCESS) {
    const std::vector<std::vector<float>> data = TFUtils::GetTensorsData<float>(output_tensors);
    const std::vector<float> result = data[0];
    std::cout << "Output value: " << result[0] << std::endl;
  } else {
    std::cout << "Error run session";
    return 2;
  }

  TFUtils::DeleteTensors(input_tensors);
  TFUtils::DeleteTensors(output_tensors);

  return 0;
}
```
A quick walkthrough:
```cpp
TFUtils::STATUS status = TFU.LoadModel(graph_path);
```
This line loads the pb file.
```cpp
// Input Tensor Create
const std::vector<std::int64_t> input_a_dims = {1, 1};
const std::vector<float> input_a_vals = {2.0};

const std::vector<std::int64_t> input_b_dims = {1, 1};
const std::vector<float> input_b_vals = {3.0};

const std::vector<TF_Output> input_ops = {TFU.GetOperationByName("a", 0),
                                          TFU.GetOperationByName("b", 0)};
const std::vector<TF_Tensor*> input_tensors = {TFUtils::CreateTensor(TF_FLOAT, input_a_dims, input_a_vals),
                                               TFUtils::CreateTensor(TF_FLOAT, input_b_dims, input_b_vals)};

// Output Tensor Create
const std::vector<TF_Output> output_ops = {TFU.GetOperationByName("c", 0)};
std::vector<TF_Tensor*> output_tensors = {nullptr};
```
This section creates the two input tensors and the corresponding input ops. Note that every tensor created with CreateTensor must later be released with DeleteTensors. The output tensors have not been created yet, so they are initialized to nullptr.
```cpp
status = TFU.RunSession(input_ops, input_tensors,
                        output_ops, output_tensors);
```
This line runs the network.
```cpp
const std::vector<std::vector<float>> data = TFUtils::GetTensorsData<float>(output_tensors);
const std::vector<float> result = data[0];
```
These two lines read the data out of output_tensors into a two-dimensional vector (std::vector<std::vector<float>>), then take the first output as a std::vector<float>.
Compile and run this file; if everything works you should see the following output:
```
status = SUCCESS
Output value: 6
```
4. Loading and Running a CNN
A CNN follows the same flow as Section 3. We again use the basic fashion_mnist example; see the earlier post for how the network is trained and saved. Here we only cover prediction with the C API. Since we need to read an image and convert it into a tensor to feed the network, we write a small helper function, Mat2Tensor, for the conversion:
1) Mat2Tensor, file: fashion_mnist/utils/mat2tensor_c_cpi.h
```cpp
// Licensed under the MIT License <http://opensource.org/licenses/MIT>.
// Copyright (c) 2018 Liu Xiao <liuxiao@foxmail.com>.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

#ifndef TENSORFLOW_CPP_MAT2TENSOR_C_H
#define TENSORFLOW_CPP_MAT2TENSOR_C_H

#include <tensorflow/c/c_api.h>

#include <cstdlib>
#include <iostream>
#include <vector>

#include "opencv2/core/core.hpp"

TF_Tensor* Mat2Tensor(cv::Mat &img, float normal = 1/255.0) {
  const std::vector<std::int64_t> input_dims = {1, img.size().height, img.size().width, img.channels()};

  // Convert to float 32 and do normalize ops
  cv::Mat fake_mat(img.rows, img.cols, CV_32FC(img.channels()));
  img.convertTo(fake_mat, CV_32FC(img.channels()));
  fake_mat *= normal;

  TF_Tensor* image_input = TFUtils::CreateTensor(TF_FLOAT,
                                                 input_dims.data(), input_dims.size(),
                                                 fake_mat.data,
                                                 (fake_mat.size().height * fake_mat.size().width * fake_mat.channels() * sizeof(float)));

  return image_input;
}

#endif //TENSORFLOW_CPP_MAT2TENSOR_C_H
```
2) Loading the network and predicting. This part is essentially the same as Section 3, so no further commentary is given:
```cpp
// Licensed under the MIT License <http://opensource.org/licenses/MIT>.
// Copyright (c) 2018 Liu Xiao <liuxiao@foxmail.com> and Daniil Goncharov <neargye@gmail.com>.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

#include "../utils/TFUtils.hpp"
#include "utils/mat2tensor_c_cpi.h"

#include <iostream>
#include <vector>

// OpenCV
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>

std::string class_names[] = {"T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
                             "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"};

int ArgMax(const std::vector<float> result);

int main(int argc, char* argv[]) {
  if (argc != 3) {
    std::cerr << std::endl << "Usage: ./project path_to_graph.pb path_to_image.png" << std::endl;
    return 1;
  }

  // Load graph
  std::string graph_path = argv[1];

  // TFUtils init
  TFUtils TFU;
  TFUtils::STATUS status = TFU.LoadModel(graph_path);
  if (status != TFUtils::SUCCESS) {
    std::cerr << "Can't load graph" << std::endl;
    return 1;
  }

  // Load image and convert to tensor
  std::string image_path = argv[2];
  cv::Mat image = cv::imread(image_path, CV_LOAD_IMAGE_GRAYSCALE);

  const std::vector<std::int64_t> input_dims = {1, image.size().height, image.size().width, image.channels()};

  TF_Tensor* input_image = Mat2Tensor(image, 1/255.0);

  // Input Tensor/Ops Create
  const std::vector<TF_Tensor*> input_tensors = {input_image};
  const std::vector<TF_Output> input_ops = {TFU.GetOperationByName("input_image_input", 0)};

  // Output Tensor/Ops Create
  const std::vector<TF_Output> output_ops = {TFU.GetOperationByName("output_class/Softmax", 0)};
  std::vector<TF_Tensor*> output_tensors = {nullptr};

  status = TFU.RunSession(input_ops, input_tensors,
                          output_ops, output_tensors);

  if (status == TFUtils::SUCCESS) {
    const std::vector<std::vector<float>> data = TFUtils::GetTensorsData<float>(output_tensors);
    const std::vector<float> result = data[0];

    int pred_index = ArgMax(result);

    // Print predicted class
    printf("Predict: %d Label: %s", pred_index, class_names[pred_index].c_str());
  } else {
    std::cout << "Error run session";
    return 2;
  }

  TFUtils::DeleteTensors(input_tensors);
  TFUtils::DeleteTensors(output_tensors);

  return 0;
}

int ArgMax(const std::vector<float> result) {
  float max_value = -1.0;
  int max_index = -1;
  const long count = result.size();
  for (int i = 0; i < count; ++i) {
    const float value = result[i];
    if (value > max_value) {
      max_index = i;
      max_value = value;
    }
    std::cout << "value[" << i << "] = " << value << std::endl;
  }
  return max_index;
}
```
Compile and run this file; if everything works you should see output like the following:
```
value[0] = 6.40457e-09
value[1] = 2.41816e-07
value[2] = 3.60118e-08
value[3] = 1.18324e-09
value[4] = 6.13108e-11
value[5] = 0.00021271
value[6] = 2.01991e-11
value[7] = 3.94614e-05
value[8] = 1.17029e-10
value[9] = 0.999748
Predict: 9 Label: Ankle boot
```
This completes the workflow of running a TensorFlow model with the C API.
All code in this post is open source:
https://github.com/skylook/tensorflow_cpp
References
[1] https://github.com/Neargye/hello_tf_c_api
[2] https://www.tensorflow.org/api_docs/python/tf/contrib/saved_model/save_keras_model
[3] https://www.tensorflow.org/install/lang_c