diff --git a/include/api/model.h b/include/api/model.h index c926027470b..22c8d2940ff 100644 --- a/include/api/model.h +++ b/include/api/model.h @@ -75,6 +75,17 @@ class MS_API Model { Status Predict(const std::vector &inputs, std::vector *outputs, const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr); + /// \brief Inference model, only for cv model inference. + /// + /// \param[in] inputs A string represents the file path of input image. + /// \param[out] outputs Which is a pointer to a vector. The model outputs are filled in the container in sequence. + /// \param[in] before CallBack before predict. + /// \param[in] after CallBack after predict. + /// + /// \return Status. + inline Status Predict(const std::string &input, std::vector *outputs, + const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr); + /// \brief Load config file. /// /// \param[in] config_path config file path. @@ -190,6 +201,8 @@ class MS_API Model { std::vector> GetOutputTensorNamesChar(); MSTensor GetOutputByTensorName(const std::vector &tensor_name); std::vector GetOutputsByNodeName(const std::vector &node_name); + Status Predict(const std::vector &input, std::vector *outputs, const MSKernelCallBack &before, + const MSKernelCallBack &after); std::shared_ptr impl_; }; @@ -207,5 +220,10 @@ MSTensor Model::GetOutputByTensorName(const std::string &tensor_name) { std::vector Model::GetOutputsByNodeName(const std::string &node_name) { return GetOutputsByNodeName(StringToChar(node_name)); } + +Status Model::Predict(const std::string &input, std::vector *outputs, const MSKernelCallBack &before, + const MSKernelCallBack &after) { + return Predict(StringToChar(input), outputs, before, after); +} } // namespace mindspore #endif // MINDSPORE_INCLUDE_API_MODEL_H diff --git a/mindspore/ccsrc/cxx_api/CMakeLists.txt b/mindspore/ccsrc/cxx_api/CMakeLists.txt index 631ec3180d0..da27374fbff 100644 --- a/mindspore/ccsrc/cxx_api/CMakeLists.txt +++ b/mindspore/ccsrc/cxx_api/CMakeLists.txt @@ -1,4 +1,6 @@ # build mindspore_shared_lib +include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc) +include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset) if(NOT(BUILD_LITE)) set(LOAD_MINDIR_SRC ${CMAKE_SOURCE_DIR}/mindspore/core/load_mindir/load_model.cc diff --git a/mindspore/ccsrc/cxx_api/graph/graph_data.cc b/mindspore/ccsrc/cxx_api/graph/graph_data.cc index 85e1d8e6ce6..e2d6155a6b7 100644 --- a/mindspore/ccsrc/cxx_api/graph/graph_data.cc +++ b/mindspore/ccsrc/cxx_api/graph/graph_data.cc @@ -21,7 +21,7 @@ namespace mindspore { Graph::GraphData::GraphData(const FuncGraphPtr &func_graph, enum ModelType model_type) - : func_graph_(nullptr), om_data_(), model_type_(ModelType::kUnknownType) { + : func_graph_(nullptr), om_data_(), model_type_(ModelType::kUnknownType), data_graph_({}) { if (model_type != ModelType::kMindIR) { MS_LOG(EXCEPTION) << "Invalid ModelType " << model_type; } @@ -30,7 +30,7 @@ Graph::GraphData::GraphData(const FuncGraphPtr &func_graph, enum ModelType model } Graph::GraphData::GraphData(const Buffer &om_data, enum ModelType model_type) - : func_graph_(nullptr), om_data_(om_data), model_type_(model_type) { + : func_graph_(nullptr), om_data_(om_data), model_type_(model_type), data_graph_({}) { if (model_type_ != ModelType::kOM) { MS_LOG(EXCEPTION) << "Invalid ModelType " << model_type_; } @@ -70,4 +70,8 @@ Buffer Graph::GraphData::GetOMData() const { return om_data_; } + +void Graph::GraphData::SetPreprocess(const std::vector> &data_graph) { + data_graph_ = data_graph; +} } // namespace mindspore diff --git a/mindspore/ccsrc/cxx_api/graph/graph_data.h b/mindspore/ccsrc/cxx_api/graph/graph_data.h index 8eefd2f6751..168eeb531ae 100644 --- a/mindspore/ccsrc/cxx_api/graph/graph_data.h +++ b/mindspore/ccsrc/cxx_api/graph/graph_data.h @@ -22,6 +22,7 @@ #include #include "include/api/graph.h" #include "include/api/types.h" +#include "include/dataset/execute.h" #include "ir/func_graph.h" namespace mindspore { @@ -41,10 +42,15 @@ class Graph::GraphData { Buffer GetOMData() const; + void SetPreprocess(const std::vector> &data_graph); + + std::vector> GetPreprocess() { return data_graph_; } + private: FuncGraphPtr func_graph_; Buffer om_data_; enum ModelType model_type_; + std::vector> data_graph_; }; } // namespace mindspore #endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_DATA_H diff --git a/mindspore/ccsrc/cxx_api/model/model.cc b/mindspore/ccsrc/cxx_api/model/model.cc index f6282fa5177..a02bb176294 100644 --- a/mindspore/ccsrc/cxx_api/model/model.cc +++ b/mindspore/ccsrc/cxx_api/model/model.cc @@ -94,6 +94,15 @@ Status Model::Predict(const std::vector &inputs, std::vector return impl_->Predict(inputs, outputs); } +Status Model::Predict(const std::vector &input, std::vector *outputs, const MSKernelCallBack &before, + const MSKernelCallBack &after) { + if (impl_ == nullptr) { + MS_LOG(ERROR) << "Failed because this model has not been built."; + return kMCFailed; + } + return impl_->Predict(CharToString(input), outputs); +} + std::vector Model::GetInputs() { if (impl_ == nullptr) { MS_LOG(ERROR) << "Failed because this model has not been built."; diff --git a/mindspore/ccsrc/cxx_api/model/model_impl.cc b/mindspore/ccsrc/cxx_api/model/model_impl.cc index 176777a9475..07e8054bc7a 100644 --- a/mindspore/ccsrc/cxx_api/model/model_impl.cc +++ b/mindspore/ccsrc/cxx_api/model/model_impl.cc @@ -15,6 +15,9 @@ */ #include "cxx_api/model/model_impl.h" +#include +#include "debug/common.h" + namespace mindspore { Status ModelImpl::Predict(const std::vector &inputs, std::vector *outputs) { MS_EXCEPTION_IF_NULL(outputs); @@ -41,4 +44,85 @@ Status ModelImpl::Predict(const std::vector &inputs, std::vector *outputs) { +#if !defined(_WIN32) && !defined(_WIN64) + auto realpath = Common::GetRealPath(input); + if (!realpath.has_value()) { + MS_LOG(ERROR) << "Get real path failed, path=" << input; + return Status(kMEInvalidInput, "Get real path failed, path=" + input); + } + MS_EXCEPTION_IF_NULL(outputs); + + // Read image file + auto file = realpath.value(); + if (file.empty()) { + return Status(kMEInvalidInput, "can not find any input file."); + } + + std::ifstream ifs(file, std::ios::in | std::ios::binary); + if (!ifs.good()) { + return Status(kMEInvalidInput, "File: " + file + " does not exist."); + } + if (!ifs.is_open()) { + return Status(kMEInvalidInput, "File: " + file + " open failed."); + } + + auto &io_seekg1 = ifs.seekg(0, std::ios::end); + if (!io_seekg1.good() || io_seekg1.fail() || io_seekg1.bad()) { + ifs.close(); + return Status(kMEInvalidInput, "Failed to seekg file: " + file); + } + + size_t size = ifs.tellg(); + MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast(size)}, nullptr, size); + + auto &io_seekg2 = ifs.seekg(0, std::ios::beg); + if (!io_seekg2.good() || io_seekg2.fail() || io_seekg2.bad()) { + ifs.close(); + return Status(kMEInvalidInput, "Failed to seekg file: " + file); + } + + auto &io_read = ifs.read(reinterpret_cast(buffer.MutableData()), size); + if (!io_read.good() || io_read.fail() || io_read.bad()) { + ifs.close(); + return Status(kMEInvalidInput, "Failed to read file: " + file); + } + ifs.close(); + + // Run preprocess + std::vector transform_inputs; + std::vector transform_outputs; + transform_inputs.emplace_back(std::move(buffer)); + MS_LOG(DEBUG) << "transform_inputs[0].Shape: " << transform_inputs[0].Shape(); + auto preprocessor = graph_->graph_data_->GetPreprocess(); + if (!preprocessor.empty()) { + for (auto exes : preprocessor) { + MS_EXCEPTION_IF_NULL(exes); + Status ret = exes->operator()(transform_inputs, &transform_outputs); + if (ret != kSuccess) { + MS_LOG(ERROR) << "Run preprocess failed."; + return ret; + } + MS_LOG(DEBUG) << "transform_outputs[0].Shape: " << transform_outputs[0].Shape(); + transform_inputs = transform_outputs; + } + } else { + std::string msg = "Attempt to predict with data preprocess, but no preprocess operation is defined in MindIR."; + MS_LOG(ERROR) << msg; + return Status(kMEFailed, msg); + } + + // Run prediction + Status ret = Predict(transform_outputs, outputs); + if (ret != kSuccess) { + MS_LOG(ERROR) << ret.GetErrDescription(); + return ret; + } + return kSuccess; +#else + MS_LOG(ERROR) << "Predict with data preprocess is not supported on Windows yet."; + return Status(kMEFailed, "Predict with data preprocess is not supported on Windows yet."); +#endif +} } // namespace mindspore diff --git a/mindspore/ccsrc/cxx_api/model/model_impl.h b/mindspore/ccsrc/cxx_api/model/model_impl.h index 8dde62d7f57..d672c096433 100644 --- a/mindspore/ccsrc/cxx_api/model/model_impl.h +++ b/mindspore/ccsrc/cxx_api/model/model_impl.h @@ -39,6 +39,8 @@ class ModelImpl { virtual Status Predict(const std::vector &inputs, std::vector *outputs); + virtual Status Predict(const std::string &input, std::vector *outputs); + virtual std::vector GetInputs() = 0; virtual std::vector GetOutputs() = 0; diff --git a/mindspore/ccsrc/cxx_api/serialization.cc b/mindspore/ccsrc/cxx_api/serialization.cc index ec45b813f69..b6c9b72351e 100644 --- a/mindspore/ccsrc/cxx_api/serialization.cc +++ b/mindspore/ccsrc/cxx_api/serialization.cc @@ -19,6 +19,10 @@ #include "cxx_api/graph/graph_data.h" #include "utils/log_adapter.h" #include "mindspore/core/load_mindir/load_model.h" +#if !defined(_WIN32) && !defined(_WIN64) +#include "minddata/dataset/engine/serdes.h" +#include "minddata/dataset/include/dataset/execute.h" +#endif #include "utils/crypto.h" namespace mindspore { @@ -187,7 +191,24 @@ Status Serialization::Load(const std::vector &file, ModelType model_type, MS_LOG(ERROR) << err_msg.str(); return Status(kMEInvalidInput, err_msg.str()); } - *graph = Graph(std::make_shared(anf_graph, kMindIR)); + auto graph_data = std::make_shared(anf_graph, kMindIR); +#if !defined(_WIN32) && !defined(_WIN64) + std::string preprocessor = LoadPreprocess(file_path); + if (!preprocessor.empty()) { + std::vector> data_graph; + status = dataset::Serdes::ParseMindIRPreprocess(preprocessor, "image", &data_graph); + if (status != kSuccess) { + MS_LOG(ERROR) << status.GetErrDescription(); + return status; + } + if (!data_graph.empty()) { + graph_data->SetPreprocess(data_graph); + } else { + MS_LOG(WARNING) << "Load preprocess failed, no data preprocess operations found in MindIR."; + } + } +#endif + *graph = Graph(graph_data); return kSuccess; } else if (model_type == kOM) { Buffer data = ReadFile(file_path); @@ -256,7 +277,24 @@ Status Serialization::Load(const std::vector> &files, ModelTyp MS_LOG(ERROR) << err_msg.str(); return Status(kMEInvalidInput, err_msg.str()); } - results.emplace_back(std::make_shared(anf_graphs[i], kMindIR)); + auto graph_data = std::make_shared(anf_graphs[i], kMindIR); +#if !defined(_WIN32) && !defined(_WIN64) + std::string preprocessor = LoadPreprocess(files_path[i]); + if (!preprocessor.empty()) { + std::vector> data_graph; + auto status = dataset::Serdes::ParseMindIRPreprocess(preprocessor, "image", &data_graph); + if (status != kSuccess) { + MS_LOG(ERROR) << status.GetErrDescription(); + return status; + } + if (!data_graph.empty()) { + graph_data->SetPreprocess(data_graph); + } else { + MS_LOG(WARNING) << "Load preprocess failed, no data preprocess operations found in MindIR."; + } + } +#endif + results.emplace_back(graph_data); } *graphs = std::move(results); diff --git a/mindspore/ccsrc/minddata/dataset/core/tensor.h b/mindspore/ccsrc/minddata/dataset/core/tensor.h index 3c6833049a8..e00ec00cff4 100644 --- a/mindspore/ccsrc/minddata/dataset/core/tensor.h +++ b/mindspore/ccsrc/minddata/dataset/core/tensor.h @@ -159,7 +159,7 @@ class Tensor { template static Status CreateFromVector(const std::vector &items, const TensorShape &shape, TensorPtr *out) { CHECK_FAIL_RETURN_UNEXPECTED( - items.size() == shape.NumOfElements(), + static_cast(items.size()) == shape.NumOfElements(), "Number of elements in the vector does not match the number of elements of the shape required"); DataType type = DataType::FromCType(); // if items is empty, items_ptr would be nullptr. CreateFromMemory will handle this case. @@ -419,7 +419,7 @@ class Tensor { return {}; } std::vector indices(index_vector.size(), 0); - for (int i = 0; i < index_vector.size(); i++) { + for (size_t i = 0; i < index_vector.size(); i++) { indices[i] = HandleNeg(index_vector[i], length[i]); } return indices; @@ -786,7 +786,7 @@ inline Status Tensor::CreateFromVector(const std::vector(items.size()) == shape.NumOfElements(), "Number of elements in the vector does not match the number of elements of the shape required"); const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); *out = std::allocate_shared(*alloc, TensorShape({static_cast(items.size())}), diff --git a/mindspore/ccsrc/minddata/dataset/core/tensor_helpers.h b/mindspore/ccsrc/minddata/dataset/core/tensor_helpers.h index 7713414b579..e686760e943 100644 --- a/mindspore/ccsrc/minddata/dataset/core/tensor_helpers.h +++ b/mindspore/ccsrc/minddata/dataset/core/tensor_helpers.h @@ -19,7 +19,7 @@ #include #include -#include "mindspore/ccsrc/minddata/dataset/include/dataset/transforms.h" +#include "minddata/dataset/include/dataset/transforms.h" #include "minddata/dataset/include/dataset/constants.h" namespace mindspore { diff --git a/mindspore/ccsrc/minddata/dataset/engine/connector.h b/mindspore/ccsrc/minddata/dataset/engine/connector.h index b262ba63dbb..98c2333e202 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/connector.h +++ b/mindspore/ccsrc/minddata/dataset/engine/connector.h @@ -160,7 +160,7 @@ class Connector { // Get current size of connector. int32_t size() const { int32_t size = 0; - for (int32_t i = 0; i < queues_.size(); ++i) { + for (size_t i = 0; i < queues_.size(); ++i) { size += queues_[i]->size(); } return size; @@ -168,7 +168,7 @@ class Connector { int32_t capacity() const { int32_t capacity = 0; - for (int32_t i = 0; i < queues_.size(); ++i) { + for (size_t i = 0; i < queues_.size(); ++i) { capacity += queues_[i]->capacity(); } return capacity; diff --git a/mindspore/ccsrc/minddata/dataset/engine/serdes.cc b/mindspore/ccsrc/minddata/dataset/engine/serdes.cc index 1cc567aac8c..3a8e30a746c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/serdes.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/serdes.cc @@ -13,6 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include +#include #include "minddata/dataset/engine/serdes.h" #include "debug/common.h" @@ -307,5 +309,54 @@ Serdes::InitializeFuncPtr() { return ops_ptr; } +Status Serdes::ParseMindIRPreprocess(const std::string &dataset_json, const std::string &process_column, + std::vector> *data_graph) { + CHECK_FAIL_RETURN_UNEXPECTED(!dataset_json.empty(), "Invalid data, no json data in dataset_json."); + + nlohmann::json dataset_js; + try { + dataset_js = nlohmann::json::parse(dataset_json); + } catch (const std::exception &err) { + MS_LOG(ERROR) << "Invalid json content, failed to parse JSON data."; + RETURN_STATUS_UNEXPECTED("Invalid json content, failed to parse JSON data."); + } + + // Note1: We have to consider if pipeline has multibranch, how to deal with this situation? + // op1 - map - | + // op2 - map - concat - map - ... + std::stack reverse_traversal; + nlohmann::json dataset_nodes = dataset_js; + while (dataset_nodes != nullptr) { + reverse_traversal.push(dataset_nodes); + if (dataset_nodes["children"].size() > 1) { + MS_LOG(WARNING) << "Need to support dataset_node with more than one child."; + } + dataset_nodes = dataset_nodes["children"][0]; + } + + // Note2: We have to consider if the "image" column does not named with "image", how to select its map ops? + // In MindRecord, TFRecord, GeneratorDataset or RenameDataset, it seems that the column names are not fixed. + while (!reverse_traversal.empty()) { + nlohmann::json node = reverse_traversal.top(); + reverse_traversal.pop(); + if (node["op_type"] == "Map") { + std::vector> tensor_ops; + RETURN_IF_NOT_OK(ConstructTensorOps(node["operations"], &tensor_ops)); + if (node["input_columns"][0] == process_column) { + std::vector op_names; + std::transform(tensor_ops.begin(), tensor_ops.end(), std::back_inserter(op_names), + [](const auto &op) { return op->Name(); }); + MS_LOG(INFO) << "Find valid preprocess operations: " << op_names; + data_graph->push_back(std::make_shared(tensor_ops)); + } + } + } + + if (!data_graph->size()) { + MS_LOG(WARNING) << "Can not find any valid preprocess operation."; + } + return Status::OK(); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/serdes.h b/mindspore/ccsrc/minddata/dataset/engine/serdes.h index 20fa9300ab2..d63ca92fa64 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/serdes.h +++ b/mindspore/ccsrc/minddata/dataset/engine/serdes.h @@ -71,6 +71,7 @@ #include "minddata/dataset/include/dataset/constants.h" #include "minddata/dataset/include/dataset/datasets.h" +#include "minddata/dataset/include/dataset/execute.h" #include "minddata/dataset/include/dataset/iterator.h" #include "minddata/dataset/include/dataset/samplers.h" #include "minddata/dataset/include/dataset/transforms.h" @@ -176,6 +177,14 @@ class Serdes { /// \return Status The status code returned static Status ConstructTensorOps(nlohmann::json json_obj, std::vector> *result); + /// \brief helper function to load tensor operations from dataset JSON and construct Execute object. + /// \param[in] dataset_json JSON string of dataset. + /// \param[in] process_column Select all map operations which process this column. + /// \param[out] data_graph Execute object contains tensor operations of map. + /// \return Status The status code returned. + static Status ParseMindIRPreprocess(const std::string &dataset_json, const std::string &process_column, + std::vector> *data_graph); + protected: /// \brief Helper function to save JSON to a file /// \param[in] json_string The JSON string to be saved to the file diff --git a/mindspore/ccsrc/minddata/dataset/util/allocator.h b/mindspore/ccsrc/minddata/dataset/util/allocator.h index 6df5b1d6925..66893110134 100644 --- a/mindspore/ccsrc/minddata/dataset/util/allocator.h +++ b/mindspore/ccsrc/minddata/dataset/util/allocator.h @@ -101,13 +101,13 @@ Status MakeUnique(std::unique_ptr> *out, C alloc, return Status(StatusCode::kMDOutOfMemory); } if (!std::is_arithmetic::value) { - for (auto i = 0; i < n; i++) { + for (size_t i = 0; i < n; i++) { std::allocator_traits::construct(alloc, &(data[i]), std::forward(args)...); } } auto deleter = [](T *p, C f_alloc, size_t f_n) { if (!std::is_arithmetic::value && std::is_destructible::value) { - for (auto i = 0; i < f_n; ++i) { + for (size_t i = 0; i < f_n; ++i) { std::allocator_traits::destroy(f_alloc, &p[i]); } } diff --git a/mindspore/core/CMakeLists.txt b/mindspore/core/CMakeLists.txt index 6ec5835c886..617a967e87b 100644 --- a/mindspore/core/CMakeLists.txt +++ b/mindspore/core/CMakeLists.txt @@ -1,6 +1,8 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}) include_directories(${CMAKE_BINARY_DIR}) include_directories(${CMAKE_SOURCE_DIR}/mindspore/core) +include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc) +include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset) add_subdirectory(gvar) if("${ENABLE_HIDDEN}" STREQUAL "OFF" AND NOT MSVC) diff --git a/mindspore/core/load_mindir/load_model.cc b/mindspore/core/load_mindir/load_model.cc index afc37e9ad45..8a835ceacf1 100644 --- a/mindspore/core/load_mindir/load_model.cc +++ b/mindspore/core/load_mindir/load_model.cc @@ -169,6 +169,33 @@ bool ParseGraphProto(mind_ir::GraphProto *graph, const std::string &path, const return true; } +std::string LoadPreprocess(const std::string &file_name) { + if (file_name.length() > PATH_MAX) { + MS_LOG(ERROR) << "The length of the file name exceeds the limit."; + return nullptr; + } + const char *file_path = file_name.c_str(); + char abs_path_buff[PATH_MAX]; + +#ifdef _WIN32 + _fullpath(abs_path_buff, file_path, PATH_MAX); +#else + if (!realpath(file_path, abs_path_buff)) { + MS_LOG(ERROR) << "Load MindIR get absolute path failed"; + } +#endif + + // Read graph + mind_ir::ModelProto origin_model; + std::fstream mindir_stream(std::string(std::string(abs_path_buff)), std::ios::in | std::ios::binary); + if (!mindir_stream || !origin_model.ParseFromIstream(&mindir_stream)) { + MS_LOG(ERROR) << "Load MindIR file failed, please check the correctness of the file."; + return std::string(); + } + + return origin_model.preprocessor(); +} + std::vector> LoadMindIRs(std::vector file_names, bool is_lite, const unsigned char *dec_key, const size_t key_len, const std::string &dec_mode, bool inc_load) { diff --git a/mindspore/core/load_mindir/load_model.h b/mindspore/core/load_mindir/load_model.h index 21833b87e40..22f9e1fe1c6 100644 --- a/mindspore/core/load_mindir/load_model.h +++ b/mindspore/core/load_mindir/load_model.h @@ -30,6 +30,7 @@ std::vector> LoadMindIRs(const std::vector> ReadProtoFile(const std::string &file); std::shared_ptr ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite = false); } // namespace mindspore diff --git a/mindspore/core/proto/mind_ir.proto b/mindspore/core/proto/mind_ir.proto index 8d9c9ecc434..e540d835893 100644 --- a/mindspore/core/proto/mind_ir.proto +++ b/mindspore/core/proto/mind_ir.proto @@ -76,6 +76,7 @@ message ModelProto { optional string doc_string = 6; optional GraphProto graph = 7; repeated GraphProto functions = 8; // all the graphs without the main graph. + optional string preprocessor = 9; // data graph from MindData. } diff --git a/mindspore/lite/tools/cropper/build_cropper_config.sh b/mindspore/lite/tools/cropper/build_cropper_config.sh index f2e1fe6a1f9..90b574d1054 100644 --- a/mindspore/lite/tools/cropper/build_cropper_config.sh +++ b/mindspore/lite/tools/cropper/build_cropper_config.sh @@ -43,7 +43,8 @@ HEADER_LOCATION="-I${MINDSPORE_HOME} -I${FLATBUFFERS} -I${MINDSPORE_HOME}/mindspore/lite/build/schema -I${MINDSPORE_HOME}/mindspore/lite/build/schema/inner --I${MINDSPORE_HOME}/mindspore/ccsrc/backend/kernel_compiler/cpu" +-I${MINDSPORE_HOME}/mindspore/ccsrc/backend/kernel_compiler/cpu +-I${MINDSPORE_HOME}/mindspore/ccsrc/minddata/dataset" REMOVE_LISTS_STR="" getDeep() { diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 873429d49d8..325c0127c3f 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -20,12 +20,14 @@ import math import shutil import time import copy +import json import threading from threading import Thread, Lock from collections import defaultdict import numpy as np +import mindspore import mindspore.nn as nn from mindspore import context from mindspore import log as logger @@ -715,6 +717,7 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs): - enc_key (byte): Byte type key used for encryption. Tha valid length is 16, 24, or 32. - enc_mode (str): Specifies the encryption mode, take effect when enc_key is set. Option: 'AES-GCM' | 'AES-CBC'. Default: 'AES-GCM'. + - dataset (str): Specifies the preprocess methods of network. Examples: >>> import numpy as np @@ -737,9 +740,10 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs): enc_mode = 'AES-GCM' if 'enc_mode' in kwargs.keys(): enc_mode = Validator.check_isinstance('enc_mode', kwargs['enc_mode'], str) - _export(net, file_name, file_format, *inputs, enc_key=enc_key, enc_mode=enc_mode) + dataset = kwargs['dataset'] if 'dataset' in kwargs.keys() else None + _export(net, file_name, file_format, *inputs, enc_key=enc_key, enc_mode=enc_mode, dataset=dataset) else: - _export(net, file_name, file_format, *inputs) + _export(net, file_name, file_format, *inputs, **kwargs) def _export(net, file_name, file_format, *inputs, **kwargs): @@ -748,6 +752,8 @@ def _export(net, file_name, file_format, *inputs, **kwargs): """ logger.info("exporting model file:%s format:%s.", file_name, file_format) check_input_data(*inputs, data_class=Tensor) + if 'dataset' in kwargs.keys() and kwargs['dataset'] is not None: + check_input_data(kwargs['dataset'], data_class=mindspore.dataset.Dataset) if file_format == 'GEIR': logger.warning(f"Format 'GEIR' is deprecated, it would be removed in future release, use 'AIR' instead.") @@ -808,6 +814,10 @@ def _save_mindir(net, file_name, *inputs, **kwargs): net_dict = net.parameters_dict() model.ParseFromString(mindir_stream) + if 'dataset' in kwargs.keys() and kwargs['dataset'] is not None: + dataset = kwargs['dataset'] + model.preprocessor = json.dumps(dataset.to_json(), indent=2) + save_together = _save_together(net_dict, model) is_encrypt = lambda: 'enc_key' in kwargs.keys() and 'enc_mode' in kwargs.keys() if save_together: diff --git a/model_zoo/official/cv/resnext/ascend310_infer/src/CMakeLists.txt b/model_zoo/official/cv/resnext/ascend310_infer/src/CMakeLists.txt index 0397995b0e0..a1eb6819e39 100644 --- a/model_zoo/official/cv/resnext/ascend310_infer/src/CMakeLists.txt +++ b/model_zoo/official/cv/resnext/ascend310_infer/src/CMakeLists.txt @@ -12,3 +12,5 @@ file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*) add_executable(main main.cc utils.cc) target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags) +add_executable(main_preprocess main_preprocess.cc utils.cc) +target_link_libraries(main_preprocess ${MS_LIB} ${MD_LIB} gflags) diff --git a/model_zoo/official/cv/resnext/ascend310_infer/src/main_preprocess.cc b/model_zoo/official/cv/resnext/ascend310_infer/src/main_preprocess.cc new file mode 100644 index 00000000000..f1449c3834d --- /dev/null +++ b/model_zoo/official/cv/resnext/ascend310_infer/src/main_preprocess.cc @@ -0,0 +1,81 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "include/api/model.h" +#include "include/api/context.h" +#include "include/api/types.h" +#include "include/api/serialization.h" +#include "inc/utils.h" + +using mindspore::Context; +using mindspore::GraphCell; +using mindspore::Model; +using mindspore::ModelType; +using mindspore::MSTensor; +using mindspore::Serialization; +using mindspore::Status; + +DEFINE_string(mindir_path, "", "mindir path"); +DEFINE_string(dataset_path, ".", "dataset path"); +DEFINE_string(image_path, ".", "image path"); +DEFINE_int32(device_id, 0, "device id"); + +int main(int argc, char **argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + if (RealPath(FLAGS_mindir_path).empty()) { + std::cout << "Invalid mindir" << std::endl; + return 1; + } + + auto context = std::make_shared(); + auto ascend310 = std::make_shared(); + ascend310->SetDeviceID(FLAGS_device_id); + context->MutableDeviceInfo().push_back(ascend310); + mindspore::Graph graph; + Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph); + Model model; + Status ret = model.Build(GraphCell(graph), context); + if (ret.IsError()) { + std::cout << "ERROR: Build failed." << std::endl; + return 1; + } + + std::vector outputs; + ret = model.Predict(FLAGS_image_path, &outputs); + if (ret.IsError()) { + std::cout << "ERROR: Predict failed." << std::endl; + return 1; + } + + auto shape = outputs[0].Shape(); + std::cout << "Output Shape: " << std::endl; + for (auto s : shape) { + std::cout << s << ", "; + } + std::cout << std::endl; + + return 0; +} diff --git a/model_zoo/official/cv/resnext/export_datagraph.py b/model_zoo/official/cv/resnext/export_datagraph.py new file mode 100644 index 00000000000..0da8e6cde7b --- /dev/null +++ b/model_zoo/official/cv/resnext/export_datagraph.py @@ -0,0 +1,57 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +resnext export mindir. +""" +import os +import numpy as np +from mindspore.common import dtype as mstype +from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export +from src.model_utils.config import config +from src.model_utils.moxing_adapter import moxing_wrapper +from src.image_classification import get_network +from src.utils.auto_mixed_precision import auto_mixed_precision +from src.dataset import classification_dataset + + +context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target) +if config.device_target == "Ascend": + context.set_context(device_id=config.device_id) + +def modelarts_pre_process(): + '''modelarts pre process function.''' + config.file_name = os.path.join(config.output_path, config.file_name) + +@moxing_wrapper(pre_process=modelarts_pre_process) +def run_export(): + """run export.""" + network = get_network(network=config.network, num_classes=config.num_classes, platform=config.device_target) + + param_dict = load_checkpoint(config.checkpoint_file_path) + load_param_into_net(network, param_dict) + if config.device_target == "Ascend": + network.to_float(mstype.float16) + else: + auto_mixed_precision(network) + network.set_train(False) + input_shp = [config.batch_size, 3, config.height, config.width] + + de_dataset = classification_dataset("src/", config.image_size, config.per_batch_size, 1, 0, 1, mode="eval") + + input_array = Tensor(np.random.uniform(-1.0, 1.0, size=input_shp).astype(np.float32)) + export(network, input_array, file_name=config.file_name, file_format=config.file_format, dataset=de_dataset) + +if __name__ == '__main__': + run_export()