Save data graph into MindIR and Load

This commit is contained in:
luoyang 2021-08-30 16:29:03 +08:00
parent 6e5637e9e1
commit c799582bbd
23 changed files with 420 additions and 15 deletions

View File

@ -75,6 +75,17 @@ class MS_API Model {
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
/// \brief Inference model, only for cv model inference.
///
/// \param[in] inputs A string represents the file path of input image.
/// \param[out] outputs Which is a pointer to a vector. The model outputs are filled in the container in sequence.
/// \param[in] before CallBack before predict.
/// \param[in] after CallBack after predict.
///
/// \return Status.
inline Status Predict(const std::string &input, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
/// \brief Load config file.
///
/// \param[in] config_path config file path.
@ -190,6 +201,8 @@ class MS_API Model {
std::vector<std::vector<char>> GetOutputTensorNamesChar();
MSTensor GetOutputByTensorName(const std::vector<char> &tensor_name);
std::vector<MSTensor> GetOutputsByNodeName(const std::vector<char> &node_name);
Status Predict(const std::vector<char> &input, std::vector<MSTensor> *outputs, const MSKernelCallBack &before,
const MSKernelCallBack &after);
std::shared_ptr<ModelImpl> impl_;
};
@ -207,5 +220,10 @@ MSTensor Model::GetOutputByTensorName(const std::string &tensor_name) {
std::vector<MSTensor> Model::GetOutputsByNodeName(const std::string &node_name) {
return GetOutputsByNodeName(StringToChar(node_name));
}
Status Model::Predict(const std::string &input, std::vector<MSTensor> *outputs, const MSKernelCallBack &before,
const MSKernelCallBack &after) {
return Predict(StringToChar(input), outputs, before, after);
}
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_MODEL_H

View File

@ -1,4 +1,6 @@
# build mindspore_shared_lib
include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc)
include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset)
if(NOT(BUILD_LITE))
set(LOAD_MINDIR_SRC
${CMAKE_SOURCE_DIR}/mindspore/core/load_mindir/load_model.cc

View File

@ -21,7 +21,7 @@
namespace mindspore {
Graph::GraphData::GraphData(const FuncGraphPtr &func_graph, enum ModelType model_type)
: func_graph_(nullptr), om_data_(), model_type_(ModelType::kUnknownType) {
: func_graph_(nullptr), om_data_(), model_type_(ModelType::kUnknownType), data_graph_({}) {
if (model_type != ModelType::kMindIR) {
MS_LOG(EXCEPTION) << "Invalid ModelType " << model_type;
}
@ -30,7 +30,7 @@ Graph::GraphData::GraphData(const FuncGraphPtr &func_graph, enum ModelType model
}
Graph::GraphData::GraphData(const Buffer &om_data, enum ModelType model_type)
: func_graph_(nullptr), om_data_(om_data), model_type_(model_type) {
: func_graph_(nullptr), om_data_(om_data), model_type_(model_type), data_graph_({}) {
if (model_type_ != ModelType::kOM) {
MS_LOG(EXCEPTION) << "Invalid ModelType " << model_type_;
}
@ -70,4 +70,8 @@ Buffer Graph::GraphData::GetOMData() const {
return om_data_;
}
void Graph::GraphData::SetPreprocess(const std::vector<std::shared_ptr<dataset::Execute>> &data_graph) {
data_graph_ = data_graph;
}
} // namespace mindspore

View File

@ -22,6 +22,7 @@
#include <memory>
#include "include/api/graph.h"
#include "include/api/types.h"
#include "include/dataset/execute.h"
#include "ir/func_graph.h"
namespace mindspore {
@ -41,10 +42,15 @@ class Graph::GraphData {
Buffer GetOMData() const;
void SetPreprocess(const std::vector<std::shared_ptr<dataset::Execute>> &data_graph);
std::vector<std::shared_ptr<dataset::Execute>> GetPreprocess() { return data_graph_; }
private:
FuncGraphPtr func_graph_;
Buffer om_data_;
enum ModelType model_type_;
std::vector<std::shared_ptr<dataset::Execute>> data_graph_;
};
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_DATA_H

View File

@ -94,6 +94,15 @@ Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor>
return impl_->Predict(inputs, outputs);
}
Status Model::Predict(const std::vector<char> &input, std::vector<MSTensor> *outputs, const MSKernelCallBack &before,
const MSKernelCallBack &after) {
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Failed because this model has not been built.";
return kMCFailed;
}
return impl_->Predict(CharToString(input), outputs);
}
std::vector<MSTensor> Model::GetInputs() {
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Failed because this model has not been built.";

View File

@ -15,6 +15,9 @@
*/
#include "cxx_api/model/model_impl.h"
#include <fstream>
#include "debug/common.h"
namespace mindspore {
Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
MS_EXCEPTION_IF_NULL(outputs);
@ -41,4 +44,85 @@ Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTen
return kSuccess;
}
Status ModelImpl::Predict(const std::string &input, std::vector<MSTensor> *outputs) {
#if !defined(_WIN32) && !defined(_WIN64)
auto realpath = Common::GetRealPath(input);
if (!realpath.has_value()) {
MS_LOG(ERROR) << "Get real path failed, path=" << input;
return Status(kMEInvalidInput, "Get real path failed, path=" + input);
}
MS_EXCEPTION_IF_NULL(outputs);
// Read image file
auto file = realpath.value();
if (file.empty()) {
return Status(kMEInvalidInput, "can not find any input file.");
}
std::ifstream ifs(file, std::ios::in | std::ios::binary);
if (!ifs.good()) {
return Status(kMEInvalidInput, "File: " + file + " does not exist.");
}
if (!ifs.is_open()) {
return Status(kMEInvalidInput, "File: " + file + " open failed.");
}
auto &io_seekg1 = ifs.seekg(0, std::ios::end);
if (!io_seekg1.good() || io_seekg1.fail() || io_seekg1.bad()) {
ifs.close();
return Status(kMEInvalidInput, "Failed to seekg file: " + file);
}
size_t size = ifs.tellg();
MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);
auto &io_seekg2 = ifs.seekg(0, std::ios::beg);
if (!io_seekg2.good() || io_seekg2.fail() || io_seekg2.bad()) {
ifs.close();
return Status(kMEInvalidInput, "Failed to seekg file: " + file);
}
auto &io_read = ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
if (!io_read.good() || io_read.fail() || io_read.bad()) {
ifs.close();
return Status(kMEInvalidInput, "Failed to read file: " + file);
}
ifs.close();
// Run preprocess
std::vector<MSTensor> transform_inputs;
std::vector<MSTensor> transform_outputs;
transform_inputs.emplace_back(std::move(buffer));
MS_LOG(DEBUG) << "transform_inputs[0].Shape: " << transform_inputs[0].Shape();
auto preprocessor = graph_->graph_data_->GetPreprocess();
if (!preprocessor.empty()) {
for (auto exes : preprocessor) {
MS_EXCEPTION_IF_NULL(exes);
Status ret = exes->operator()(transform_inputs, &transform_outputs);
if (ret != kSuccess) {
MS_LOG(ERROR) << "Run preprocess failed.";
return ret;
}
MS_LOG(DEBUG) << "transform_outputs[0].Shape: " << transform_outputs[0].Shape();
transform_inputs = transform_outputs;
}
} else {
std::string msg = "Attempt to predict with data preprocess, but no preprocess operation is defined in MindIR.";
MS_LOG(ERROR) << msg;
return Status(kMEFailed, msg);
}
// Run prediction
Status ret = Predict(transform_outputs, outputs);
if (ret != kSuccess) {
MS_LOG(ERROR) << ret.GetErrDescription();
return ret;
}
return kSuccess;
#else
MS_LOG(ERROR) << "Predict with data preprocess is not supported on Windows yet.";
return Status(kMEFailed, "Predict with data preprocess is not supported on Windows yet.");
#endif
}
} // namespace mindspore

View File

@ -39,6 +39,8 @@ class ModelImpl {
virtual Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);
virtual Status Predict(const std::string &input, std::vector<MSTensor> *outputs);
virtual std::vector<MSTensor> GetInputs() = 0;
virtual std::vector<MSTensor> GetOutputs() = 0;

View File

@ -19,6 +19,10 @@
#include "cxx_api/graph/graph_data.h"
#include "utils/log_adapter.h"
#include "mindspore/core/load_mindir/load_model.h"
#if !defined(_WIN32) && !defined(_WIN64)
#include "minddata/dataset/engine/serdes.h"
#include "minddata/dataset/include/dataset/execute.h"
#endif
#include "utils/crypto.h"
namespace mindspore {
@ -187,7 +191,24 @@ Status Serialization::Load(const std::vector<char> &file, ModelType model_type,
MS_LOG(ERROR) << err_msg.str();
return Status(kMEInvalidInput, err_msg.str());
}
*graph = Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
auto graph_data = std::make_shared<Graph::GraphData>(anf_graph, kMindIR);
#if !defined(_WIN32) && !defined(_WIN64)
std::string preprocessor = LoadPreprocess(file_path);
if (!preprocessor.empty()) {
std::vector<std::shared_ptr<dataset::Execute>> data_graph;
status = dataset::Serdes::ParseMindIRPreprocess(preprocessor, "image", &data_graph);
if (status != kSuccess) {
MS_LOG(ERROR) << status.GetErrDescription();
return status;
}
if (!data_graph.empty()) {
graph_data->SetPreprocess(data_graph);
} else {
MS_LOG(WARNING) << "Load preprocess failed, no data preprocess operations found in MindIR.";
}
}
#endif
*graph = Graph(graph_data);
return kSuccess;
} else if (model_type == kOM) {
Buffer data = ReadFile(file_path);
@ -256,7 +277,24 @@ Status Serialization::Load(const std::vector<std::vector<char>> &files, ModelTyp
MS_LOG(ERROR) << err_msg.str();
return Status(kMEInvalidInput, err_msg.str());
}
results.emplace_back(std::make_shared<Graph::GraphData>(anf_graphs[i], kMindIR));
auto graph_data = std::make_shared<Graph::GraphData>(anf_graphs[i], kMindIR);
#if !defined(_WIN32) && !defined(_WIN64)
std::string preprocessor = LoadPreprocess(files_path[i]);
if (!preprocessor.empty()) {
std::vector<std::shared_ptr<dataset::Execute>> data_graph;
auto status = dataset::Serdes::ParseMindIRPreprocess(preprocessor, "image", &data_graph);
if (status != kSuccess) {
MS_LOG(ERROR) << status.GetErrDescription();
return status;
}
if (!data_graph.empty()) {
graph_data->SetPreprocess(data_graph);
} else {
MS_LOG(WARNING) << "Load preprocess failed, no data preprocess operations found in MindIR.";
}
}
#endif
results.emplace_back(graph_data);
}
*graphs = std::move(results);

View File

@ -159,7 +159,7 @@ class Tensor {
template <typename T>
static Status CreateFromVector(const std::vector<T> &items, const TensorShape &shape, TensorPtr *out) {
CHECK_FAIL_RETURN_UNEXPECTED(
items.size() == shape.NumOfElements(),
static_cast<dsize_t>(items.size()) == shape.NumOfElements(),
"Number of elements in the vector does not match the number of elements of the shape required");
DataType type = DataType::FromCType<T>();
// if items is empty, items_ptr would be nullptr. CreateFromMemory will handle this case.
@ -419,7 +419,7 @@ class Tensor {
return {};
}
std::vector<dsize_t> indices(index_vector.size(), 0);
for (int i = 0; i < index_vector.size(); i++) {
for (size_t i = 0; i < index_vector.size(); i++) {
indices[i] = HandleNeg(index_vector[i], length[i]);
}
return indices;
@ -786,7 +786,7 @@ inline Status Tensor::CreateFromVector<std::string>(const std::vector<std::strin
TensorPtr *out) {
RETURN_UNEXPECTED_IF_NULL(out);
CHECK_FAIL_RETURN_UNEXPECTED(
items.size() == shape.NumOfElements(),
static_cast<dsize_t>(items.size()) == shape.NumOfElements(),
"Number of elements in the vector does not match the number of elements of the shape required");
const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator();
*out = std::allocate_shared<Tensor>(*alloc, TensorShape({static_cast<dsize_t>(items.size())}),

View File

@ -19,7 +19,7 @@
#include <memory>
#include <vector>
#include "mindspore/ccsrc/minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/include/dataset/constants.h"
namespace mindspore {

View File

@ -160,7 +160,7 @@ class Connector {
// Get current size of connector.
int32_t size() const {
int32_t size = 0;
for (int32_t i = 0; i < queues_.size(); ++i) {
for (size_t i = 0; i < queues_.size(); ++i) {
size += queues_[i]->size();
}
return size;
@ -168,7 +168,7 @@ class Connector {
int32_t capacity() const {
int32_t capacity = 0;
for (int32_t i = 0; i < queues_.size(); ++i) {
for (size_t i = 0; i < queues_.size(); ++i) {
capacity += queues_[i]->capacity();
}
return capacity;

View File

@ -13,6 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <fstream>
#include <stack>
#include "minddata/dataset/engine/serdes.h"
#include "debug/common.h"
@ -307,5 +309,54 @@ Serdes::InitializeFuncPtr() {
return ops_ptr;
}
Status Serdes::ParseMindIRPreprocess(const std::string &dataset_json, const std::string &process_column,
std::vector<std::shared_ptr<mindspore::dataset::Execute>> *data_graph) {
CHECK_FAIL_RETURN_UNEXPECTED(!dataset_json.empty(), "Invalid data, no json data in dataset_json.");
nlohmann::json dataset_js;
try {
dataset_js = nlohmann::json::parse(dataset_json);
} catch (const std::exception &err) {
MS_LOG(ERROR) << "Invalid json content, failed to parse JSON data.";
RETURN_STATUS_UNEXPECTED("Invalid json content, failed to parse JSON data.");
}
// Note1: We have to consider if pipeline has multibranch, how to deal with this situation?
// op1 - map - |
// op2 - map - concat - map - ...
std::stack<nlohmann::json> reverse_traversal;
nlohmann::json dataset_nodes = dataset_js;
while (dataset_nodes != nullptr) {
reverse_traversal.push(dataset_nodes);
if (dataset_nodes["children"].size() > 1) {
MS_LOG(WARNING) << "Need to support dataset_node with more than one child.";
}
dataset_nodes = dataset_nodes["children"][0];
}
// Note2: We have to consider if the "image" column does not named with "image", how to select its map ops?
// In MindRecord, TFRecord, GeneratorDataset or RenameDataset, it seems that the column names are not fixed.
while (!reverse_traversal.empty()) {
nlohmann::json node = reverse_traversal.top();
reverse_traversal.pop();
if (node["op_type"] == "Map") {
std::vector<std::shared_ptr<TensorOperation>> tensor_ops;
RETURN_IF_NOT_OK(ConstructTensorOps(node["operations"], &tensor_ops));
if (node["input_columns"][0] == process_column) {
std::vector<std::string> op_names;
std::transform(tensor_ops.begin(), tensor_ops.end(), std::back_inserter(op_names),
[](const auto &op) { return op->Name(); });
MS_LOG(INFO) << "Find valid preprocess operations: " << op_names;
data_graph->push_back(std::make_shared<Execute>(tensor_ops));
}
}
}
if (!data_graph->size()) {
MS_LOG(WARNING) << "Can not find any valid preprocess operation.";
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -71,6 +71,7 @@
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/include/dataset/datasets.h"
#include "minddata/dataset/include/dataset/execute.h"
#include "minddata/dataset/include/dataset/iterator.h"
#include "minddata/dataset/include/dataset/samplers.h"
#include "minddata/dataset/include/dataset/transforms.h"
@ -176,6 +177,14 @@ class Serdes {
/// \return Status The status code returned
static Status ConstructTensorOps(nlohmann::json json_obj, std::vector<std::shared_ptr<TensorOperation>> *result);
/// \brief helper function to load tensor operations from dataset JSON and construct Execute object.
/// \param[in] dataset_json JSON string of dataset.
/// \param[in] process_column Select all map operations which process this column.
/// \param[out] data_graph Execute object contains tensor operations of map.
/// \return Status The status code returned.
static Status ParseMindIRPreprocess(const std::string &dataset_json, const std::string &process_column,
std::vector<std::shared_ptr<mindspore::dataset::Execute>> *data_graph);
protected:
/// \brief Helper function to save JSON to a file
/// \param[in] json_string The JSON string to be saved to the file

View File

@ -101,13 +101,13 @@ Status MakeUnique(std::unique_ptr<T[], std::function<void(T *)>> *out, C alloc,
return Status(StatusCode::kMDOutOfMemory);
}
if (!std::is_arithmetic<T>::value) {
for (auto i = 0; i < n; i++) {
for (size_t i = 0; i < n; i++) {
std::allocator_traits<C>::construct(alloc, &(data[i]), std::forward<Args>(args)...);
}
}
auto deleter = [](T *p, C f_alloc, size_t f_n) {
if (!std::is_arithmetic<T>::value && std::is_destructible<T>::value) {
for (auto i = 0; i < f_n; ++i) {
for (size_t i = 0; i < f_n; ++i) {
std::allocator_traits<C>::destroy(f_alloc, &p[i]);
}
}

View File

@ -1,6 +1,8 @@
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
include_directories(${CMAKE_BINARY_DIR})
include_directories(${CMAKE_SOURCE_DIR}/mindspore/core)
include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc)
include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset)
add_subdirectory(gvar)
if("${ENABLE_HIDDEN}" STREQUAL "OFF" AND NOT MSVC)

View File

@ -169,6 +169,33 @@ bool ParseGraphProto(mind_ir::GraphProto *graph, const std::string &path, const
return true;
}
std::string LoadPreprocess(const std::string &file_name) {
if (file_name.length() > PATH_MAX) {
MS_LOG(ERROR) << "The length of the file name exceeds the limit.";
return nullptr;
}
const char *file_path = file_name.c_str();
char abs_path_buff[PATH_MAX];
#ifdef _WIN32
_fullpath(abs_path_buff, file_path, PATH_MAX);
#else
if (!realpath(file_path, abs_path_buff)) {
MS_LOG(ERROR) << "Load MindIR get absolute path failed";
}
#endif
// Read graph
mind_ir::ModelProto origin_model;
std::fstream mindir_stream(std::string(std::string(abs_path_buff)), std::ios::in | std::ios::binary);
if (!mindir_stream || !origin_model.ParseFromIstream(&mindir_stream)) {
MS_LOG(ERROR) << "Load MindIR file failed, please check the correctness of the file.";
return std::string();
}
return origin_model.preprocessor();
}
std::vector<std::shared_ptr<FuncGraph>> LoadMindIRs(std::vector<std::string> file_names, bool is_lite,
const unsigned char *dec_key, const size_t key_len,
const std::string &dec_mode, bool inc_load) {

View File

@ -30,6 +30,7 @@ std::vector<std::shared_ptr<FuncGraph>> LoadMindIRs(const std::vector<std::strin
const unsigned char *dec_key = nullptr, const size_t key_len = 0,
const std::string &dec_mode = std::string("AES-GCM"),
bool inc_load = true);
std::string LoadPreprocess(const std::string &file_name);
std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file);
std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite = false);
} // namespace mindspore

View File

@ -76,6 +76,7 @@ message ModelProto {
optional string doc_string = 6;
optional GraphProto graph = 7;
repeated GraphProto functions = 8; // all the graphs without the main graph.
optional string preprocessor = 9; // data graph from MindData.
}

View File

@ -43,7 +43,8 @@ HEADER_LOCATION="-I${MINDSPORE_HOME}
-I${FLATBUFFERS}
-I${MINDSPORE_HOME}/mindspore/lite/build/schema
-I${MINDSPORE_HOME}/mindspore/lite/build/schema/inner
-I${MINDSPORE_HOME}/mindspore/ccsrc/backend/kernel_compiler/cpu"
-I${MINDSPORE_HOME}/mindspore/ccsrc/backend/kernel_compiler/cpu
-I${MINDSPORE_HOME}/mindspore/ccsrc/minddata/dataset"
REMOVE_LISTS_STR=""
getDeep() {

View File

@ -20,12 +20,14 @@ import math
import shutil
import time
import copy
import json
import threading
from threading import Thread, Lock
from collections import defaultdict
import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import context
from mindspore import log as logger
@ -715,6 +717,7 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):
- enc_key (byte): Byte type key used for encryption. Tha valid length is 16, 24, or 32.
- enc_mode (str): Specifies the encryption mode, take effect when enc_key is set.
Option: 'AES-GCM' | 'AES-CBC'. Default: 'AES-GCM'.
- dataset (str): Specifies the preprocess methods of network.
Examples:
>>> import numpy as np
@ -737,9 +740,10 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):
enc_mode = 'AES-GCM'
if 'enc_mode' in kwargs.keys():
enc_mode = Validator.check_isinstance('enc_mode', kwargs['enc_mode'], str)
_export(net, file_name, file_format, *inputs, enc_key=enc_key, enc_mode=enc_mode)
dataset = kwargs['dataset'] if 'dataset' in kwargs.keys() else None
_export(net, file_name, file_format, *inputs, enc_key=enc_key, enc_mode=enc_mode, dataset=dataset)
else:
_export(net, file_name, file_format, *inputs)
_export(net, file_name, file_format, *inputs, **kwargs)
def _export(net, file_name, file_format, *inputs, **kwargs):
@ -748,6 +752,8 @@ def _export(net, file_name, file_format, *inputs, **kwargs):
"""
logger.info("exporting model file:%s format:%s.", file_name, file_format)
check_input_data(*inputs, data_class=Tensor)
if 'dataset' in kwargs.keys() and kwargs['dataset'] is not None:
check_input_data(kwargs['dataset'], data_class=mindspore.dataset.Dataset)
if file_format == 'GEIR':
logger.warning(f"Format 'GEIR' is deprecated, it would be removed in future release, use 'AIR' instead.")
@ -808,6 +814,10 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
net_dict = net.parameters_dict()
model.ParseFromString(mindir_stream)
if 'dataset' in kwargs.keys() and kwargs['dataset'] is not None:
dataset = kwargs['dataset']
model.preprocessor = json.dumps(dataset.to_json(), indent=2)
save_together = _save_together(net_dict, model)
is_encrypt = lambda: 'enc_key' in kwargs.keys() and 'enc_mode' in kwargs.keys()
if save_together:

View File

@ -12,3 +12,5 @@ file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
add_executable(main main.cc utils.cc)
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)
add_executable(main_preprocess main_preprocess.cc utils.cc)
target_link_libraries(main_preprocess ${MS_LIB} ${MD_LIB} gflags)

View File

@ -0,0 +1,81 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <sys/time.h>
#include <gflags/gflags.h>
#include <dirent.h>
#include <iostream>
#include <string>
#include <algorithm>
#include <iosfwd>
#include <vector>
#include <fstream>
#include <sstream>
#include "include/api/model.h"
#include "include/api/context.h"
#include "include/api/types.h"
#include "include/api/serialization.h"
#include "inc/utils.h"
using mindspore::Context;
using mindspore::GraphCell;
using mindspore::Model;
using mindspore::ModelType;
using mindspore::MSTensor;
using mindspore::Serialization;
using mindspore::Status;
DEFINE_string(mindir_path, "", "mindir path");
DEFINE_string(dataset_path, ".", "dataset path");
DEFINE_string(image_path, ".", "image path");
DEFINE_int32(device_id, 0, "device id");
int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
if (RealPath(FLAGS_mindir_path).empty()) {
std::cout << "Invalid mindir" << std::endl;
return 1;
}
auto context = std::make_shared<Context>();
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
ascend310->SetDeviceID(FLAGS_device_id);
context->MutableDeviceInfo().push_back(ascend310);
mindspore::Graph graph;
Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
Model model;
Status ret = model.Build(GraphCell(graph), context);
if (ret.IsError()) {
std::cout << "ERROR: Build failed." << std::endl;
return 1;
}
std::vector<MSTensor> outputs;
ret = model.Predict(FLAGS_image_path, &outputs);
if (ret.IsError()) {
std::cout << "ERROR: Predict failed." << std::endl;
return 1;
}
auto shape = outputs[0].Shape();
std::cout << "Output Shape: " << std::endl;
for (auto s : shape) {
std::cout << s << ", ";
}
std::cout << std::endl;
return 0;
}

View File

@ -0,0 +1,57 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
resnext export mindir.
"""
import os
import numpy as np
from mindspore.common import dtype as mstype
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.image_classification import get_network
from src.utils.auto_mixed_precision import auto_mixed_precision
from src.dataset import classification_dataset
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
if config.device_target == "Ascend":
context.set_context(device_id=config.device_id)
def modelarts_pre_process():
'''modelarts pre process function.'''
config.file_name = os.path.join(config.output_path, config.file_name)
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_export():
"""run export."""
network = get_network(network=config.network, num_classes=config.num_classes, platform=config.device_target)
param_dict = load_checkpoint(config.checkpoint_file_path)
load_param_into_net(network, param_dict)
if config.device_target == "Ascend":
network.to_float(mstype.float16)
else:
auto_mixed_precision(network)
network.set_train(False)
input_shp = [config.batch_size, 3, config.height, config.width]
de_dataset = classification_dataset("src/", config.image_size, config.per_batch_size, 1, 0, 1, mode="eval")
input_array = Tensor(np.random.uniform(-1.0, 1.0, size=input_shp).astype(np.float32))
export(network, input_array, file_name=config.file_name, file_format=config.file_format, dataset=de_dataset)
if __name__ == '__main__':
run_export()