forked from mindspore-Ecosystem/mindspore
!26949 save dataset into mindir - phase 2
Merge pull request !26949 from luoyang/mindir-stage2-2
Commit: a618bb8ab5
@@ -92,7 +92,7 @@ class MS_API Model {
   /// \param[in] after CallBack after predict.
   ///
   /// \return Status.
-  Status PredictWithPreprocess(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
+  Status PredictWithPreprocess(const std::vector<std::vector<MSTensor>> &inputs, std::vector<MSTensor> *outputs,
                                const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);

   /// \brief Apply data preprocess if it exits in model.
@@ -101,7 +101,7 @@ class MS_API Model {
   /// \param[out] outputs Which is a pointer to a vector. The model outputs are filled in the container in sequence.
   ///
   /// \return Status.
-  Status Preprocess(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);
+  Status Preprocess(const std::vector<std::vector<MSTensor>> &inputs, std::vector<MSTensor> *outputs);

   /// \brief Check if data preprocess exists in model.
   /// \return true if data preprocess exists.

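This signature change is the core of phase 2: preprocess inputs now arrive grouped per sample (outer vector) rather than as one flat tensor list, so the dataset pipeline can run on each sample before the results are batched. An illustrative numpy sketch of the expected nesting (Python for illustration only; the C++ API takes `std::vector<std::vector<MSTensor>>`):

```python
import numpy as np

# Outer list: one entry per sample; inner list: the tensors of that sample
# (a single assumed "image" tensor each, for a hypothetical 3-sample batch).
inputs = [[np.random.rand(32, 32, 3).astype(np.float32)] for _ in range(3)]

# After per-sample preprocess, outputs are stacked along a new batch axis.
batched = np.stack([sample[0] for sample in inputs])
print(batched.shape)  # (3, 32, 32, 3)
```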
@@ -104,12 +104,15 @@ class MS_API MSTensor {
   static inline MSTensor *CreateDevTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
                                           const void *data, size_t data_len) noexcept;

-  /// \brief Creates a MSTensor object from local image file, must be used in pairs with DestroyTensorPtr.
+  /// \brief Creates a MSTensor object from local file, must be used in pairs with DestroyTensorPtr.
   ///
-  /// \param[in] image_file Path of image file.
+  /// \param[in] file Path of file to be read.
+  /// \param[in] type The data type of the MSTensor.
+  /// \param[in] shape The shape of the MSTensor.
   ///
   /// \return A pointer of MSTensor.
-  static inline MSTensor *CreateImageTensor(const std::string &image_file) noexcept;
+  static inline MSTensor *CreateTensorFromFile(const std::string &file, DataType type = DataType::kNumberTypeUInt8,
+                                               const std::vector<int64_t> &shape = {}) noexcept;

   /// \brief Create a string type MSTensor object whose data can be accessed by Model only after being copied, must be
   /// used in pair with DestroyTensorPtr.
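`CreateTensorFromFile` generalizes the old `CreateImageTensor`: it reads raw bytes, defaulting to a flat uint8 tensor the size of the file, and validates an explicit type/shape against the byte count. A rough Python analogue of that contract (a sketch, not the C++ API):

```python
import numpy as np

def create_tensor_from_file(path, dtype=np.uint8, shape=None):
    """Read raw bytes; default to a flat uint8 tensor of the file size."""
    data = np.fromfile(path, dtype=np.uint8)
    if shape is None:
        return data  # shape = [len(data)], dtype = uint8
    out = np.frombuffer(data.tobytes(), dtype=dtype)
    if out.size != int(np.prod(shape)):
        raise ValueError("tensor size does not match file length")
    return out.reshape(shape)
```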
@@ -268,7 +271,8 @@ class MS_API MSTensor {
                                    const void *data, size_t data_len) noexcept;
   static MSTensor *CreateDevTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
                                    const void *data, size_t data_len) noexcept;
-  static MSTensor *CreateImageTensor(const std::vector<char> &image_file) noexcept;
+  static MSTensor *CreateTensorFromFile(const std::vector<char> &file, enum DataType type,
+                                        const std::vector<int64_t> &shape) noexcept;
   static MSTensor *CharStringsToTensor(const std::vector<char> &name, const std::vector<std::vector<char>> &str);
   static std::vector<std::vector<char>> TensorToStringChars(const MSTensor &tensor);

@@ -316,8 +320,9 @@ MSTensor *MSTensor::CreateDevTensor(const std::string &name, enum DataType type,
   return CreateDevTensor(StringToChar(name), type, shape, data, data_len);
 }

-MSTensor *MSTensor::CreateImageTensor(const std::string &image_file) noexcept {
-  return CreateImageTensor(StringToChar(image_file));
+MSTensor *MSTensor::CreateTensorFromFile(const std::string &file, enum DataType type,
+                                         const std::vector<int64_t> &shape) noexcept {
+  return CreateTensorFromFile(StringToChar(file), type, shape);
 }

 MSTensor *MSTensor::StringsToTensor(const std::string &name, const std::vector<std::string> &str) {
@@ -334,9 +339,7 @@ MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vecto

 std::string MSTensor::Name() const { return CharToString(CharName()); }

-void MSTensor::SetTensorName(const std::string &name) {
-  return SetTensorName(StringToChar(name));
-}
+void MSTensor::SetTensorName(const std::string &name) { return SetTensorName(StringToChar(name)); }

 using Key = struct Key {
   const size_t max_key_len = 32;

@@ -81,7 +81,7 @@ Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor>
   return impl_->Predict(inputs, outputs);
 }

-Status Model::PredictWithPreprocess(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
+Status Model::PredictWithPreprocess(const std::vector<std::vector<MSTensor>> &inputs, std::vector<MSTensor> *outputs,
                                     const MSKernelCallBack &before, const MSKernelCallBack &after) {
   if (impl_ == nullptr) {
     MS_LOG(ERROR) << "Failed because this model has not been built.";
@@ -90,7 +90,7 @@ Status Model::PredictWithPreprocess(const std::vector<MSTensor> &inputs, std::ve
   return impl_->PredictWithPreprocess(inputs, outputs);
 }

-Status Model::Preprocess(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
+Status Model::Preprocess(const std::vector<std::vector<MSTensor>> &inputs, std::vector<MSTensor> *outputs) {
   if (impl_ == nullptr) {
     MS_LOG(ERROR) << "Failed because this model has not been built.";
     return kMCFailed;

@@ -45,7 +45,7 @@ Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTen

 bool ModelImpl::HasPreprocess() { return graph_->graph_data_->GetPreprocess().empty() ? false : true; }

-Status ModelImpl::Preprocess(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
+Status ModelImpl::Preprocess(const std::vector<std::vector<MSTensor>> &inputs, std::vector<MSTensor> *outputs) {
 #if !defined(_WIN32) && !defined(_WIN64)
   // Config preprocessor, temporary way to let mindspore.so depends on _c_dataengine
   std::string dataengine_so_path;
@@ -54,20 +54,53 @@ Status ModelImpl::Preprocess(const std::vector<MSTensor> &inputs, std::vector<MS

   // Run preprocess
   if (!HasPreprocess()) {
     MS_LOG(ERROR) << "Attempt to predict with data preprocessor, but no preprocessor is defined in MindIR.";
     return Status(kMEFailed, "Attempt to predict with data preprocessor, but no preprocessor is defined in MindIR.");
   }
-  std::vector<std::shared_ptr<dataset::Execute>> preprocessor = graph_->graph_data_->GetPreprocess();

   void *handle = nullptr;
   void *function = nullptr;
   dlret = DLSoOpen(dataengine_so_path, "ExecuteRun_C", &handle, &function);
   CHECK_FAIL_AND_RELEASE(dlret, handle, "Parse ExecuteRun_C failed: " + dlret.GetErrDescription());

   auto ExecuteRun =
     (void (*)(const std::vector<std::shared_ptr<dataset::Execute>> &, const std::vector<mindspore::MSTensor> &,
               std::vector<mindspore::MSTensor> *, Status *))(function);
-  ExecuteRun(preprocessor, inputs, outputs, &dlret);
-  CHECK_FAIL_AND_RELEASE(dlret, handle, "Run preprocess failed: " + dlret.GetErrDescription());
+
+  // perform preprocess on each tensor separately
+  std::vector<std::shared_ptr<dataset::Execute>> preprocessor = graph_->graph_data_->GetPreprocess();
+  std::vector<std::vector<MSTensor>> output_unbatch;
+  std::vector<MSTensor> output_batched;
+  for (auto tensor : inputs) {
+    std::vector<MSTensor> temp;
+    ExecuteRun(preprocessor, tensor, &temp, &dlret);
+    CHECK_FAIL_AND_RELEASE(dlret, handle, "Run preprocess failed: " + dlret.GetErrDescription());
+    output_unbatch.push_back(temp);
+  }
+
+  // Construct a tensor with batch dim
+  output_batched.resize(output_unbatch[0].size());
+  for (size_t i = 0; i < output_batched.size(); i++) {
+    std::vector<int64_t> ori_shape = output_unbatch[0][i].Shape();
+    ori_shape.insert(ori_shape.begin(), output_unbatch.size());
+    output_batched[i] = mindspore::MSTensor("outputs", output_unbatch[0][i].DataType(), ori_shape, nullptr,
+                                            output_unbatch[0][i].DataSize() * output_unbatch.size());
+  }
+
+  // Copy unbatch data into tensor
+  for (size_t i = 0; i < output_unbatch[0].size(); i++) {
+    size_t offset = 0;
+    for (size_t j = 0; j < output_unbatch.size(); j++) {
+      auto ret =
+        memcpy_s(reinterpret_cast<unsigned char *>(output_batched[i].MutableData()) + offset,
+                 output_unbatch[j][i].DataSize(), output_unbatch[j][i].MutableData(), output_unbatch[j][i].DataSize());
+      if (ret) {
+        MS_LOG(ERROR) << "Memory copy failed to construct High-Dim Tensor.";
+        return Status(kMEFailed, "Memory copy failed to construct High-Dim Tensor.");
+      }
+      offset += output_unbatch[j][i].DataSize();
+    }
+  }
+  *outputs = output_batched;
   DLSoClose(handle);
   return kSuccess;
 #else
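The new per-sample loop replaces the single `ExecuteRun` call: each input sample is preprocessed on its own, then the per-sample outputs are concatenated along a new leading batch dimension. A numpy sketch of that batching step (illustrative; the C++ above does the same with `shape.insert` plus `memcpy_s` at running offsets into a preallocated tensor):

```python
import numpy as np

def batch_preprocessed(output_unbatch):
    """output_unbatch[j][i]: output i of sample j; returns one batched array per output."""
    num_outputs = len(output_unbatch[0])
    batched = []
    for i in range(num_outputs):
        # Prepend the batch dimension, then lay each sample's data out in order,
        # mirroring the shape.insert + memcpy_s offset loop in the diff.
        batched.append(np.stack([sample[i] for sample in output_unbatch], axis=0))
    return batched

samples = [[np.ones((32, 32), np.float32)] for _ in range(4)]
print(batch_preprocessed(samples)[0].shape)  # (4, 32, 32)
```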
@@ -76,7 +109,8 @@ Status ModelImpl::Preprocess(const std::vector<MSTensor> &inputs, std::vector<MS
 #endif
 }

-Status ModelImpl::PredictWithPreprocess(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
+Status ModelImpl::PredictWithPreprocess(const std::vector<std::vector<MSTensor>> &inputs,
+                                        std::vector<MSTensor> *outputs) {
 #if !defined(_WIN32) && !defined(_WIN64)
   // Run preprocess
   std::vector<MSTensor> preprocess_outputs;

@@ -39,7 +39,8 @@ class ModelImpl {

   virtual Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);

-  virtual Status PredictWithPreprocess(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);
+  virtual Status PredictWithPreprocess(const std::vector<std::vector<MSTensor>> &inputs,
+                                       std::vector<MSTensor> *outputs);

   virtual std::vector<MSTensor> GetInputs() = 0;
   virtual std::vector<MSTensor> GetOutputs() = 0;
@@ -47,7 +48,7 @@ class ModelImpl {
   virtual bool CheckDeviceSupport(mindspore::DeviceType device_type) = 0;
   virtual bool CheckModelSupport(enum ModelType model_type) = 0;

-  virtual Status Preprocess(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);
+  virtual Status Preprocess(const std::vector<std::vector<MSTensor>> &inputs, std::vector<MSTensor> *outputs);

   virtual bool HasPreprocess();

@@ -182,11 +182,10 @@ Status Serialization::Load(const std::vector<char> &file, ModelType model_type,
     err_msg << "Load model failed. The file may be encrypted, please pass in correct key.";
     MS_LOG(ERROR) << err_msg.str();
     return Status(kMEInvalidInput, err_msg.str());
-  } else {
-    MindIRLoader mindir_loader(false, dec_key.len == 0 ? nullptr : dec_key.key, dec_key.len, CharToString(dec_mode),
-                               false);
-    anf_graph = mindir_loader.LoadMindIR(file_path);
   }
+  MindIRLoader mindir_loader(false, dec_key.len == 0 ? nullptr : dec_key.key, dec_key.len, CharToString(dec_mode),
+                             false);
+  anf_graph = mindir_loader.LoadMindIR(file_path);
   if (anf_graph == nullptr) {
     err_msg << "Load model failed. Please check the valid of dec_key and dec_mode";
     MS_LOG(ERROR) << err_msg.str();
@@ -195,7 +194,7 @@ Status Serialization::Load(const std::vector<char> &file, ModelType model_type,
   auto graph_data = std::make_shared<Graph::GraphData>(anf_graph, kMindIR);
 #if !defined(_WIN32) && !defined(_WIN64)
   // Config preprocessor, temporary way to let mindspore.so depends on _c_dataengine
-  std::string preprocessor = LoadPreprocess(file_path);
+  std::vector<std::string> preprocessor = mindir_loader.LoadPreprocess(file_path);
   if (!preprocessor.empty()) {
     std::string dataengine_so_path;
     Status dlret = DLSoPath(&dataengine_so_path);
@@ -205,13 +204,12 @@ Status Serialization::Load(const std::vector<char> &file, ModelType model_type,
     void *function = nullptr;
     dlret = DLSoOpen(dataengine_so_path, "ParseMindIRPreprocess_C", &handle, &function);
     CHECK_FAIL_AND_RELEASE(dlret, handle, "Parse ParseMindIRPreprocess_C failed: " + dlret.GetErrDescription());

     auto ParseMindIRPreprocessFun =
-      (void (*)(const std::string &, const std::string &, std::vector<std::shared_ptr<mindspore::dataset::Execute>> *,
+      (void (*)(const std::vector<std::string> &, std::vector<std::shared_ptr<mindspore::dataset::Execute>> *,
                 Status *))(function);

     std::vector<std::shared_ptr<dataset::Execute>> data_graph;
-    ParseMindIRPreprocessFun(preprocessor, "image", &data_graph, &dlret);
+    ParseMindIRPreprocessFun(preprocessor, &data_graph, &dlret);
     CHECK_FAIL_AND_RELEASE(dlret, handle, "Load preprocess failed: " + dlret.GetErrDescription());
     DLSoClose(handle);
     if (!data_graph.empty()) {
@@ -289,7 +287,7 @@ Status Serialization::Load(const std::vector<std::vector<char>> &files, ModelTyp
   CHECK_FAIL_AND_RELEASE(dlret, handle, "Parse ParseMindIRPreprocess_C failed: " + dlret.GetErrDescription());

   auto ParseMindIRPreprocessFun =
-    (void (*)(const std::string &, const std::string &, std::vector<std::shared_ptr<mindspore::dataset::Execute>> *,
+    (void (*)(const std::vector<std::string> &, std::vector<std::shared_ptr<mindspore::dataset::Execute>> *,
               Status *))(function);
 #endif
   std::vector<Graph> results;
@@ -304,13 +302,12 @@ Status Serialization::Load(const std::vector<std::vector<char>> &files, ModelTyp
       return Status(kMEInvalidInput, err_msg.str());
     }
     auto graph_data = std::make_shared<Graph::GraphData>(anf_graphs[i], kMindIR);

 #if !defined(_WIN32) && !defined(_WIN64)
     // Config preprocessor, temporary way to let mindspore.so depends on _c_dataengine
-    std::string preprocessor = LoadPreprocess(files_path[i]);
+    std::vector<std::string> preprocessor = mindir_loader.LoadPreprocess(files_path[i]);
     if (!preprocessor.empty()) {
       std::vector<std::shared_ptr<dataset::Execute>> data_graph;
-      ParseMindIRPreprocessFun(preprocessor, "image", &data_graph, &dlret);
+      ParseMindIRPreprocessFun(preprocessor, &data_graph, &dlret);
       CHECK_FAIL_AND_RELEASE(dlret, handle, "Load preprocess failed: " + dlret.GetErrDescription());
       if (!data_graph.empty()) {
         graph_data->SetPreprocess(data_graph);

@@ -186,55 +186,72 @@ MSTensor *MSTensor::CreateDevTensor(const std::vector<char> &name, enum DataType
   }
 }

-MSTensor *MSTensor::CreateImageTensor(const std::vector<char> &image_file) noexcept {
-  std::string image_file_str = CharToString(image_file);
+MSTensor *MSTensor::CreateTensorFromFile(const std::vector<char> &file, enum DataType type,
+                                         const std::vector<int64_t> &shape) noexcept {
+  std::string file_str = CharToString(file);

   try {
-    auto realpath = FileUtils::GetRealPath(image_file_str.c_str());
+    auto realpath = FileUtils::GetRealPath(file_str.c_str());
     if (!realpath.has_value()) {
-      MS_LOG(ERROR) << "Get real path failed, path=" << image_file_str;
+      MS_LOG(ERROR) << "Get real path failed, path=" << file_str;
       return nullptr;
     }

     // Read image file
-    auto file = realpath.value();
-    if (file.empty()) {
-      MS_LOG(ERROR) << "can not find any input file.";
+    auto file_path = realpath.value();
+    if (file_path.empty()) {
+      MS_LOG(ERROR) << "Can not find any input file.";
       return nullptr;
     }

-    std::ifstream ifs(file, std::ios::in | std::ios::binary);
+    std::ifstream ifs(file_path, std::ios::in | std::ios::binary);
     if (!ifs.good()) {
-      MS_LOG(ERROR) << "File: " + file + " does not exist.";
+      MS_LOG(ERROR) << "File: " + file_path + " does not exist.";
       return nullptr;
     }
     if (!ifs.is_open()) {
-      MS_LOG(ERROR) << "File: " + file + " open failed.";
+      MS_LOG(ERROR) << "File: " + file_path + " open failed.";
       return nullptr;
     }

     auto &io_seekg1 = ifs.seekg(0, std::ios::end);
     if (!io_seekg1.good() || io_seekg1.fail() || io_seekg1.bad()) {
       ifs.close();
-      MS_LOG(ERROR) << "Failed to seekg file: " + file;
+      MS_LOG(ERROR) << "Failed to seekg file: " + file_path;
       return nullptr;
     }

     size_t size = static_cast<size_t>(ifs.tellg());
-    MSTensor *ret =
-      new MSTensor(file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);
+    std::vector<int64_t> tensor_shape;
+    tensor_shape = shape.empty() ? std::vector<int64_t>{static_cast<int64_t>(size)} : shape;
+    MSTensor *ret = new MSTensor(file_path, type, tensor_shape, nullptr, size);

     auto &io_seekg2 = ifs.seekg(0, std::ios::beg);
     if (!io_seekg2.good() || io_seekg2.fail() || io_seekg2.bad()) {
       ifs.close();
-      MS_LOG(ERROR) << "Failed to seekg file: " + file;
+      MS_LOG(ERROR) << "Failed to seekg file: " + file_path;
       return nullptr;
     }

+    std::map<enum DataType, size_t> TypeByte = {
+      {DataType::kTypeUnknown, 0},       {DataType::kObjectTypeString, 0},  {DataType::kNumberTypeBool, 1},
+      {DataType::kNumberTypeInt8, 1},    {DataType::kNumberTypeInt16, 2},   {DataType::kNumberTypeInt32, 4},
+      {DataType::kNumberTypeInt64, 8},   {DataType::kNumberTypeUInt8, 1},   {DataType::kNumberTypeUInt16, 2},
+      {DataType::kNumberTypeUInt32, 4},  {DataType::kNumberTypeUInt64, 8},  {DataType::kNumberTypeFloat16, 2},
+      {DataType::kNumberTypeFloat32, 4}, {DataType::kNumberTypeFloat64, 8},
+    };
+
+    if (ret->ElementNum() * TypeByte[type] != size) {
+      ifs.close();
+      MS_LOG(ERROR) << "Tensor data size: " << ret->ElementNum() * TypeByte[type]
+                    << " not match input data length: " << size;
+      return nullptr;
+    }
+
     auto &io_read = ifs.read(reinterpret_cast<char *>(ret->MutableData()), static_cast<std::streamsize>(size));
     if (!io_read.good() || io_read.fail() || io_read.bad()) {
       ifs.close();
-      MS_LOG(ERROR) << "Failed to read file: " + file;
+      MS_LOG(ERROR) << "Failed to read file: " + file_path;
       return nullptr;
     }
     ifs.close();

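The added `TypeByte` table guards the new explicit type/shape arguments: the element count implied by the shape, times the element width, must equal the file length in bytes. A hedged Python rendering of the same check:

```python
import numpy as np

TYPE_BYTE = {"bool": 1, "int8": 1, "int16": 2, "int32": 4, "int64": 8,
             "uint8": 1, "uint16": 2, "uint32": 4, "uint64": 8,
             "float16": 2, "float32": 4, "float64": 8}

def validate(shape, dtype, file_size):
    # Mirrors: ret->ElementNum() * TypeByte[type] != size
    expected = int(np.prod(shape)) * TYPE_BYTE[dtype]
    if expected != file_size:
        raise ValueError(f"Tensor data size {expected} does not match file length {file_size}")

validate([32, 32], "float32", 4096)  # OK: 1024 elements * 4 bytes
```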
@@ -331,52 +331,50 @@ Serdes::InitializeFuncPtr() {
   return ops_ptr;
 }

-Status Serdes::ParseMindIRPreprocess(const std::string &dataset_json, const std::string &process_column,
+Status Serdes::ParseMindIRPreprocess(const std::vector<std::string> &map_json_string,
                                      std::vector<std::shared_ptr<mindspore::dataset::Execute>> *data_graph) {
-  CHECK_FAIL_RETURN_UNEXPECTED(!dataset_json.empty(), "Invalid data, no json data in dataset_json.");
+  CHECK_FAIL_RETURN_UNEXPECTED(!map_json_string.empty(), "Invalid data, no json data in map_json_string.");

-  nlohmann::json dataset_js;
+  const std::string process_column = "[\"image\"]";
+  MS_LOG(WARNING) << "Only supports parse \"image\" column from dataset object.";
+
+  nlohmann::json map_json;
   try {
-    dataset_js = nlohmann::json::parse(dataset_json);
+    for (auto &json : map_json_string) {
+      map_json = nlohmann::json::parse(json);
+      if (map_json["input_columns"].dump() == process_column) {
+        break;
+      }
+    }
   } catch (const std::exception &err) {
     MS_LOG(ERROR) << "Invalid json content, failed to parse JSON data, error message: " << err.what();
     RETURN_STATUS_UNEXPECTED("Invalid json content, failed to parse JSON data.");
   }

-  // Note1: We have to consider if pipeline has multibranch, how to deal with this situation?
-  // op1 - map - |
-  // op2 - map - concat - map - ...
-  std::stack<nlohmann::json> reverse_traversal;
-  nlohmann::json dataset_nodes = dataset_js;
-  while (dataset_nodes != nullptr) {
-    reverse_traversal.push(dataset_nodes);
-    if (dataset_nodes["children"].size() > 1) {
-      MS_LOG(WARNING) << "Need to support dataset_node with more than one child.";
-    }
-    dataset_nodes = dataset_nodes["children"][0];
+  if (map_json.empty()) {
+    MS_LOG(ERROR) << "Invalid json content, no JSON data found for given input column: " + process_column;
+    RETURN_STATUS_UNEXPECTED("Invalid json content, no JSON data found for given input column: " + process_column);
   }

-  // Note2: We have to consider if the "image" column does not named with "image", how to select its map ops?
-  // In MindRecord, TFRecord, GeneratorDataset or RenameDataset, it seems that the column names are not fixed.
-  while (!reverse_traversal.empty()) {
-    nlohmann::json node = reverse_traversal.top();
-    reverse_traversal.pop();
-    if (node["op_type"] == "Map") {
-      std::vector<std::shared_ptr<TensorOperation>> tensor_ops;
-      RETURN_IF_NOT_OK(ConstructTensorOps(node["operations"], &tensor_ops));
-      if (node["input_columns"][0] == process_column) {
-        std::vector<std::string> op_names;
-        std::transform(tensor_ops.begin(), tensor_ops.end(), std::back_inserter(op_names),
-                       [](const auto &op) { return op->Name(); });
-        MS_LOG(INFO) << "Find valid preprocess operations: " << op_names;
-        data_graph->push_back(std::make_shared<Execute>(tensor_ops));
-      }
-    }
+  while (map_json != nullptr) {
+    CHECK_FAIL_RETURN_UNEXPECTED(map_json["op_type"] == "Map", "Invalid json content, this is not a MapOp.");
+
+    std::vector<std::shared_ptr<TensorOperation>> tensor_ops;
+    RETURN_IF_NOT_OK(ConstructTensorOps(map_json["operations"], &tensor_ops));
+    if (map_json["input_columns"].dump() == process_column) {
+      std::vector<std::string> op_names;
+      std::transform(tensor_ops.begin(), tensor_ops.end(), std::back_inserter(op_names),
+                     [](const auto &op) { return op->Name(); });
+      MS_LOG(INFO) << "Find valid preprocess operations: " << op_names;
+      data_graph->push_back(std::make_shared<Execute>(tensor_ops));
+    }
+    map_json = map_json["children"];
   }

   if (!data_graph->size()) {
     MS_LOG(WARNING) << "Can not find any valid preprocess operation.";
   }

   return Status::OK();
 }
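The loader no longer walks the whole dataset tree: it picks the per-column JSON string whose `input_columns` serializes to `["image"]`, then follows that chain's nested `children`, requiring each node to be a Map op. A small Python sketch of the same selection logic, assuming the JSON layout produced by `LoadPreprocess` below:

```python
import json

def select_image_maps(map_json_strings, column='["image"]'):
    """Pick the chain whose input_columns dump matches, then walk its children."""
    chain = next((json.loads(s) for s in map_json_strings
                  if json.dumps(json.loads(s)["input_columns"]) == column), None)
    ops = []
    node = chain
    while node:
        assert node["op_type"] == "Map", "this is not a MapOp"
        ops.append(node["operations"])
        node = node.get("children")  # nested dict, or absent at the tail
    return ops
```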
@@ -386,9 +384,9 @@ Status Serdes::ParseMindIRPreprocess(const std::string &dataset_json, const std:
 extern "C" {
 // ParseMindIRPreprocess_C has C-linkage specified, but returns user-defined type 'mindspore::Status'
 // which is incompatible with C
-void ParseMindIRPreprocess_C(const std::string &dataset_json, const std::string &process_column,
+void ParseMindIRPreprocess_C(const std::vector<std::string> &dataset_json,
                              std::vector<std::shared_ptr<mindspore::dataset::Execute>> *data_graph, Status *s) {
-  Status ret = Serdes::ParseMindIRPreprocess(dataset_json, process_column, data_graph);
+  Status ret = Serdes::ParseMindIRPreprocess(dataset_json, data_graph);
   *s = Status(ret);
 }
 }

@@ -179,11 +179,10 @@ class Serdes {
   static Status ConstructTensorOps(nlohmann::json json_obj, std::vector<std::shared_ptr<TensorOperation>> *result);

   /// \brief helper function to load tensor operations from dataset JSON and construct Execute object.
-  /// \param[in] dataset_json JSON string of dataset.
-  /// \param[in] process_column Select all map operations which process this column.
+  /// \param[in] map_json_string JSON string of dataset.
   /// \param[out] data_graph Execute object contains tensor operations of map.
   /// \return Status The status code returned.
-  static Status ParseMindIRPreprocess(const std::string &dataset_json, const std::string &process_column,
+  static Status ParseMindIRPreprocess(const std::vector<std::string> &map_json_string,
                                       std::vector<std::shared_ptr<mindspore::dataset::Execute>> *data_graph);

  protected:

@@ -23,6 +23,7 @@
 #include <algorithm>
 #include <fstream>
 #include <iostream>
+#include <nlohmann/json.hpp>

 #include "load_mindir/load_model.h"
 #include "load_mindir/anf_model_parser.h"
@@ -123,6 +124,64 @@ bool MindIRLoader::ParseGraphProto(mind_ir::GraphProto *graph, const std::string
   return true;
 }

+std::vector<std::string> MindIRLoader::LoadPreprocess(const std::string &file_name) {
+  if (file_name.length() > PATH_MAX) {
+    MS_LOG(ERROR) << "The length of the file name exceeds the limit.";
+    return {};
+  }
+  char abs_path_buff[PATH_MAX];
+
+#ifdef _WIN32
+  _fullpath(abs_path_buff, file_name.c_str(), PATH_MAX);
+#else
+  if (!realpath(file_name.c_str(), abs_path_buff)) {
+    MS_LOG(ERROR) << "Load MindIR get absolute path failed";
+  }
+#endif
+
+  // Read graph
+  mind_ir::ModelProto origin_model;
+  std::fstream mindir_stream(std::string(std::string(abs_path_buff)), std::ios::in | std::ios::binary);
+  if (!mindir_stream || !origin_model.ParseFromIstream(&mindir_stream)) {
+    MS_LOG(ERROR) << "Load MindIR file failed, please check the correctness of the file.";
+    return {};
+  }
+
+  // Read dataset preprocessor
+  auto preprocessor = origin_model.preprocessor();
+
+  // Separate columns and parse
+  std::vector<std::string> input_columns;
+  for (auto i = 0; i < preprocessor.op_size(); i++) {
+    std::string column = preprocessor.op()[i].input_columns();
+    if (std::find(input_columns.begin(), input_columns.end(), column) == input_columns.end()) {
+      input_columns.push_back(column);
+    }
+  }
+
+  // Each column has one string to indicate its preprocess behaviour
+  std::vector<std::string> map_jsons;
+  for (std::string &column : input_columns) {
+    nlohmann::json dataset_json;
+    nlohmann::json child_dataset_json;
+    for (auto i = preprocessor.op_size() - 1; i >= 0; i--) {
+      if (preprocessor.op()[i].input_columns() == column) {
+        child_dataset_json["input_columns"] = nlohmann::json::parse(preprocessor.op()[i].input_columns());
+        child_dataset_json["op_type"] = nlohmann::json::parse(preprocessor.op()[i].op_type());
+        child_dataset_json["operations"] = nlohmann::json::parse(preprocessor.op()[i].operations());
+        child_dataset_json["output_columns"] = nlohmann::json::parse(preprocessor.op()[i].output_columns());
+        child_dataset_json["project_columns"] = nlohmann::json::parse(preprocessor.op()[i].project_columns());
+        child_dataset_json["offload"] = preprocessor.op()[i].offload();
+
+        dataset_json["children"] = child_dataset_json;
+        child_dataset_json = dataset_json;
+      }
+    }
+    map_jsons.push_back(dataset_json["children"].dump());
+  }
+  return map_jsons;
+}
+
 std::vector<FuncGraphPtr> MindIRLoader::LoadMindIRs(std::vector<std::string> file_names) {
   std::vector<FuncGraphPtr> funcgraph_vec;
   MS_LOG(DEBUG) << "Load multiple MindIR files.";
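`LoadPreprocess` regroups the flat proto op list into one nested JSON chain per input column, iterating in reverse so the last-applied op ends up deepest in the `children` nesting. A Python sketch of that regrouping, assuming each op is a dict carrying the same string fields as `PreprocessOpProto`:

```python
import json

def load_preprocess(ops):
    """ops: serialized Map ops in pipeline order; returns one JSON chain per column."""
    columns = []
    for op in ops:  # keep first-seen order, no duplicates
        if op["input_columns"] not in columns:
            columns.append(op["input_columns"])
    chains = []
    for column in columns:
        chain = None
        for op in reversed(ops):  # the last op becomes the innermost child
            if op["input_columns"] == column:
                chain = {**op, "children": chain}
        chains.append(json.dumps(chain))
    return chains
```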
@@ -282,32 +341,6 @@ std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file) {
   return buf;
 }

-std::string LoadPreprocess(const std::string &file_name) {
-  if (file_name.length() > PATH_MAX) {
-    MS_LOG(ERROR) << "The length of the file name exceeds the limit.";
-    return nullptr;
-  }
-  char abs_path_buff[PATH_MAX];
-
-#ifdef _WIN32
-  _fullpath(abs_path_buff, file_name.c_str(), PATH_MAX);
-#else
-  if (!realpath(file_name.c_str(), abs_path_buff)) {
-    MS_LOG(ERROR) << "Load MindIR get absolute path failed";
-  }
-#endif
-
-  // Read graph
-  mind_ir::ModelProto origin_model;
-  std::fstream mindir_stream(std::string(std::string(abs_path_buff)), std::ios::in | std::ios::binary);
-  if (!mindir_stream || !origin_model.ParseFromIstream(&mindir_stream)) {
-    MS_LOG(ERROR) << "Load MindIR file failed, please check the correctness of the file.";
-    return std::string();
-  }
-
-  return origin_model.preprocessor();
-}
-
 FuncGraphPtr ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite) {
   MS_EXCEPTION_IF_NULL(buf);
   std::string str(buf, buf_size);

@@ -73,6 +73,7 @@ class MindIRLoader {
   FuncGraphPtr LoadMindIR(const void *buffer, const size_t &size);
   FuncGraphPtr LoadMindIR(const std::string &file_name);
   std::vector<FuncGraphPtr> LoadMindIRs(const std::vector<std::string> file_names);
+  std::vector<std::string> LoadPreprocess(const std::string &file_name);

  private:
   bool ParseModelProto(mind_ir::ModelProto *model, const std::string &path);
@@ -88,7 +89,6 @@ class MindIRLoader {
   LayoutMap layout_map_;
 };

-std::string LoadPreprocess(const std::string &file_name);
 std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file);
 FuncGraphPtr ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite = false);
 }  // namespace mindspore

@@ -75,12 +75,27 @@ message ModelProto {
   optional string doc_string = 6;
   optional GraphProto graph = 7;
   repeated GraphProto functions = 8; // all the graphs without the main graph.
-  optional string preprocessor = 9; // data graph from MindData.
+  optional PreprocessorProto preprocessor = 9; // data graph from MindData.
   optional bool little_endian = 10; // bytes order in load device.
   optional ParallelProto parallel = 11; // information for parallel.
 }


+message PreprocessorProto {
+  repeated PreprocessOpProto op = 1;
+}
+
+
+message PreprocessOpProto {
+  optional string input_columns = 1;
+  optional string output_columns = 2;
+  optional string project_columns = 3;
+  optional string op_type = 4;
+  optional string operations = 5;
+  optional bool offload = 6;
+}
+
+
 message GraphProto {
   repeated NodeProto node = 1;
   optional string name = 2;

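Each field of `PreprocessOpProto` holds a JSON-encoded string rather than structured sub-messages, so the Python exporter and the C++ loader only need a JSON codec to agree on the payload. A hedged sketch of what one serialized Map op might carry (illustrative values only; the real payload comes from `Dataset.to_json()`, and the `operations` key names are assumptions):

```python
import json

op = {
    "input_columns": json.dumps(["image"]),
    "output_columns": json.dumps(["image"]),
    "project_columns": json.dumps([]),
    "op_type": json.dumps("Map"),
    "operations": json.dumps([{"tensor_op_name": "Resize", "size": [32, 32]}]),
    "offload": False,  # the only non-string field
}
print(json.loads(op["operations"])[0]["tensor_op_name"])  # Resize
```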
@@ -122,13 +122,13 @@ Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor>
   return impl_->Predict(inputs, outputs, before, after);
 }

-Status Model::PredictWithPreprocess(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
+Status Model::PredictWithPreprocess(const std::vector<std::vector<MSTensor>> &inputs, std::vector<MSTensor> *outputs,
                                     const MSKernelCallBack &before, const MSKernelCallBack &after) {
   MS_LOG(ERROR) << "Unsupported Feature.";
   return kLiteNotSupport;
 }

-Status Model::Preprocess(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
+Status Model::Preprocess(const std::vector<std::vector<MSTensor>> &inputs, std::vector<MSTensor> *outputs) {
   MS_LOG(ERROR) << "Unsupported Feature.";
   return kLiteNotSupport;
 }

@@ -149,7 +149,8 @@ MSTensor *MSTensor::CreateDevTensor(const std::vector<char> &name, enum DataType
   return nullptr;
 }

-MSTensor *MSTensor::CreateImageTensor(const std::vector<char> &image_file) noexcept {
+MSTensor *MSTensor::CreateTensorFromFile(const std::vector<char> &file, enum DataType type,
+                                         const std::vector<int64_t> &shape) noexcept {
   MS_LOG(ERROR) << "Unsupported Feature.";
   return nullptr;
 }

@@ -921,6 +921,14 @@ def check_input_data(*data, data_class):
                              f'but got part data type is {item if item is None else type(item).__name__}.')


+def check_input_dataset(*dataset, dataset_type):
+    """Input dataset check."""
+    for item in dataset:
+        if not isinstance(item, dataset_type):
+            return False
+    return True
+
+
 def check_output_data(data):
     """Output data check."""
     if data is None:

@@ -38,7 +38,7 @@ import mindspore
 import mindspore.nn as nn
 from mindspore import context
 from mindspore import log as logger
-from mindspore._checkparam import check_input_data, Validator
+from mindspore._checkparam import check_input_data, check_input_dataset, Validator
 from mindspore.common import dtype as mstype
 from mindspore.common.api import _cell_graph_executor as _executor
 from mindspore.common.initializer import initializer
@@ -733,7 +733,11 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):

     Args:
         net (Cell): MindSpore network.
-        inputs (Tensor): Inputs of the `net`, if the network has multiple inputs, incoming tuple(Tensor).
+        inputs (Union[Tensor, tuple(Tensor), Dataset]): While the input type is Tensor, it represents the inputs
+            of the `net`, if the network has multiple inputs, incoming tuple(Tensor). While its type is Dataset,
+            it represents the preprocess behavior of the `net`, data preprocess operations will be serialized.
+            In second situation, you should adjust batch size of dataset script manually which will impact on
+            the batch size of 'net' input. Only supports parse "image" column from dataset currently.
         file_name (str): File name of the model to be exported.
         file_format (str): MindSpore currently supports 'AIR', 'ONNX' and 'MINDIR' format for exported model.
             Default: 'AIR'.
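With this change a Dataset can be passed where example tensors used to go, and the pipeline's Map operations are serialized into the MindIR alongside the graph. A minimal usage sketch (`net` and `create_dataset()` are assumptions standing in for any Cell and any pipeline with an "image" column, like the one in the test file at the bottom of this diff):

```python
import mindspore as ms

# Assumed: `net` is a trained Cell; create_dataset() builds a pipeline
# that maps the "image" column (see the test file below).
ms.export(net, create_dataset(), file_name="lenet_preprocess", file_format="MINDIR")
# lenet_preprocess.mindir now embeds both the graph and the preprocess ops.
```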
@@ -754,7 +758,6 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):
         - enc_key (byte): Byte type key used for encryption. The valid length is 16, 24, or 32.
         - enc_mode (str): Specifies the encryption mode, to take effect when enc_key is set.
           Option: 'AES-GCM' | 'AES-CBC'. Default: 'AES-GCM'.
-        - dataset (Dataset): Specifies the preprocess methods of network.

     Examples:
         >>> import numpy as np
@@ -765,7 +768,24 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):
         >>> export(net, Tensor(input_tensor), file_name='lenet', file_format='MINDIR')
     """
     logger.info("exporting model file:%s format:%s.", file_name, file_format)
-    check_input_data(*inputs, data_class=Tensor)
+    if check_input_dataset(*inputs, dataset_type=mindspore.dataset.Dataset):
+        if len(inputs) != 1:
+            raise RuntimeError(f"You can only serialize one dataset into MindIR, got " + str(len(inputs)) + " datasets")
+        shapes, types, columns = inputs[0].output_shapes(), inputs[0].output_types(), inputs[0].get_col_names()
+        kwargs['dataset'] = inputs[0]
+        only_support_col = "image"
+
+        inputs = list()
+        for c, s, t in zip(columns, shapes, types):
+            if only_support_col != c:
+                continue
+            inputs.append(Tensor(np.random.uniform(-1.0, 1.0, size=s).astype(t)))
+        if not inputs:
+            raise RuntimeError(f"Only supports parse \"image\" column from dataset now, given dataset has columns: "
+                               + str(columns))
+        inputs = tuple(inputs)
+    else:
+        check_input_data(*inputs, data_class=Tensor)
     Validator.check_file_name_by_regular(file_name)
     file_name = os.path.realpath(file_name)
     net = _quant_export(net, *inputs, file_format=file_format, **kwargs)
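When a dataset is passed, `export` synthesizes a tracing input for the supported column from the pipeline's reported shape and dtype, since graph export still needs a concrete tensor. The same idea in plain numpy, assuming a hypothetical 32x1x32x32 "image" column:

```python
import numpy as np

shapes, types, columns = [[32, 1, 32, 32]], [np.float32], ["image"]
dummy_inputs = [np.random.uniform(-1.0, 1.0, size=s).astype(t)
                for c, s, t in zip(columns, shapes, types) if c == "image"]
print(dummy_inputs[0].shape)  # (32, 1, 32, 32)
```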
@@ -790,8 +810,6 @@ def _export(net, file_name, file_format, *inputs, **kwargs):
     """
     logger.info("exporting model file:%s format:%s.", file_name, file_format)
     check_input_data(*inputs, data_class=Tensor)
-    if 'dataset' in kwargs.keys() and kwargs['dataset'] is not None:
-        check_input_data(kwargs['dataset'], data_class=mindspore.dataset.Dataset)

     if file_format == 'GEIR':
         logger.warning(f"Format 'GEIR' is deprecated, it would be removed in future release, use 'AIR' instead.")
@@ -966,8 +984,9 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
     model.ParseFromString(mindir_stream)

     if 'dataset' in kwargs.keys() and kwargs['dataset'] is not None:
+        check_input_data(kwargs['dataset'], data_class=mindspore.dataset.Dataset)
         dataset = kwargs['dataset']
-        model.preprocessor = json.dumps(dataset.to_json(), indent=2)
+        _save_dataset_to_mindir(model, dataset)

     save_together = _save_together(net_dict, model)
     is_encrypt = lambda: 'enc_key' in kwargs.keys() and 'enc_mode' in kwargs.keys()
@@ -1019,6 +1038,27 @@ def _save_together(net_dict, model):
     return True


+def _save_dataset_to_mindir(model, dataset):
+    """Save dataset preprocess operations into mindir model."""
+    dataset_json = dataset.to_json()
+    reverse_dataset = []
+    while dataset_json:
+        reverse_dataset = [dataset_json] + reverse_dataset
+        if len(dataset_json['children']) > 1:
+            logger.warning("Need to support dataset_node with more than one child, using child 0 as default.")
+        dataset_json = dataset_json['children'][0] if dataset_json['children'] else []
+
+    for op in reverse_dataset:
+        if op['op_type'] == 'Map':
+            model.preprocessor.op.add()
+            model.preprocessor.op[-1].input_columns = json.dumps(op['input_columns'])
+            model.preprocessor.op[-1].output_columns = json.dumps(op['output_columns'])
+            model.preprocessor.op[-1].project_columns = json.dumps(op['project_columns'])
+            model.preprocessor.op[-1].op_type = json.dumps(op['op_type'])
+            model.preprocessor.op[-1].operations = json.dumps(op['operations'])
+            model.preprocessor.op[-1].offload = op['offload'] if 'offload' in op.keys() else False
+
+
 def quant_mode_manage(func):
     """
     Inherit the quant_mode in old version.
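`_save_dataset_to_mindir` walks the pipeline from sink to source (`to_json` nests upstream nodes under `children`), reverses it into execution order, and records only the Map ops. A toy run of that reversal on a hand-built, `to_json`-style dict:

```python
# Hypothetical to_json() output: Batch (sink) <- Map <- MnistDataset (source).
dataset_json = {"op_type": "Batch", "children": [
    {"op_type": "Map", "children": [
        {"op_type": "MnistDataset", "children": []}]}]}

reverse_dataset = []
while dataset_json:
    reverse_dataset = [dataset_json] + reverse_dataset
    dataset_json = dataset_json["children"][0] if dataset_json["children"] else []

print([op["op_type"] for op in reverse_dataset])
# ['MnistDataset', 'Map', 'Batch'] -> only the 'Map' entries are serialized
```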
@@ -2,7 +2,12 @@ import os
 import numpy as np

 import mindspore.nn as nn
+import mindspore.dataset as ds
+import mindspore.dataset.vision.c_transforms as CV
+import mindspore.dataset.transforms.c_transforms as CT
+from mindspore.dataset.vision import Inter
 from mindspore import context
 from mindspore.common import dtype as mstype
 from mindspore.common.tensor import Tensor
 from mindspore.common.initializer import TruncatedNormal
 from mindspore.common.parameter import ParameterTuple
@@ -28,6 +33,35 @@ def fc_with_initialize(input_channels, out_channels):
     return nn.Dense(input_channels, out_channels, weight, bias)


+def create_dataset():
+    # define dataset
+    mnist_ds = ds.MnistDataset("../data/dataset/testMnistData")
+
+    resize_height, resize_width = 32, 32
+    rescale = 1.0 / 255.0
+    shift = 0.0
+    rescale_nml = 1 / 0.3081
+    shift_nml = -1 * 0.1307 / 0.3081
+
+    # define map operations
+    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
+    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
+    rescale_op = CV.Rescale(rescale, shift)
+    hwc2chw_op = CV.HWC2CHW()
+    type_cast_op = CT.TypeCast(mstype.int32)
+
+    # apply map operations on images
+    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label")
+    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image")
+    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image")
+    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image")
+    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image")
+
+    # apply DatasetOps
+    mnist_ds = mnist_ds.batch(batch_size=32, drop_remainder=True)
+
+    return mnist_ds
+
+
 class LeNet5(nn.Cell):
     def __init__(self):
         super(LeNet5, self).__init__()
@@ -85,6 +119,11 @@ class TrainOneStepCell(nn.Cell):


 def test_export_lenet_grad_mindir():
+    """
+    Feature: Export LeNet to MindIR
+    Description: Test export API to save network into MindIR
+    Expectation: save successfully
+    """
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     network = LeNet5()
     network.set_train()
@@ -96,3 +135,21 @@ def test_export_lenet_grad_mindir():
     verify_name = file_name + ".mindir"
     assert os.path.exists(verify_name)
     os.remove(verify_name)
+
+
+def test_export_lenet_with_dataset():
+    """
+    Feature: Export LeNet with data preprocess to MindIR
+    Description: Test export API to save network and dataset into MindIR
+    Expectation: save successfully
+    """
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    network = LeNet5()
+    network.set_train()
+    dataset = create_dataset()
+    file_name = "lenet_preprocess"
+
+    export(network, dataset, file_name=file_name, file_format='MINDIR')
+    verify_name = file_name + ".mindir"
+    assert os.path.exists(verify_name)
+    os.remove(verify_name)