diff --git a/mindspore/lite/src/train/train_export.cc b/mindspore/lite/src/train/train_export.cc
index 62eab151512..fe1dbef810d 100644
--- a/mindspore/lite/src/train/train_export.cc
+++ b/mindspore/lite/src/train/train_export.cc
@@ -51,12 +51,14 @@ schema::QuantType TrainExport::GetNodeQuantType(const kernel::LiteKernel *kernel
 }
 
 void TrainExport::TagQuantizedNodes() {
-  for (auto &node : meta_graph_->nodes) {
-    if (node->quantType != schema::QuantType_QUANT_WEIGHT) {
-      for (auto t_idx : node->inputIndex) {
-        if ((meta_graph_->allTensors.at(t_idx)->nodeType == NodeType_ValueNode) &&
-            (meta_graph_->allTensors.at(t_idx)->quantParams.size() > 0)) {
-          node->quantType = schema::QuantType_QUANT_WEIGHT;
+  if (quant_type_ == QT_WEIGHT) {
+    for (auto &node : meta_graph_->nodes) {
+      if (node->quantType != schema::QuantType_QUANT_WEIGHT) {
+        for (auto t_idx : node->inputIndex) {
+          if ((meta_graph_->allTensors.at(t_idx)->nodeType == NodeType_ValueNode) &&
+              (meta_graph_->allTensors.at(t_idx)->quantParams.size() > 0)) {
+            node->quantType = schema::QuantType_QUANT_WEIGHT;
+          }
         }
       }
     }
@@ -144,6 +146,20 @@ Model::Node *TrainExport::FindNode(const mindspore::kernel::LiteKernel *kernel,
   return *it;
 }
 
+int TrainExport::CreateAndAddCNode(const mindspore::kernel::LiteKernel *kernel, std::vector<uint32_t> inputIndex,
+                                   std::vector<uint32_t> outputIndex, const Model *model) {
+  auto cnode = CreateCNode(kernel, inputIndex, outputIndex, model);
+  if (cnode == nullptr) {
+    MS_LOG(ERROR) << "failed to create cnode";
+    return RET_ERROR;
+  }
+  meta_graph_->nodes.emplace_back(std::move(cnode));
+  if (!meta_graph_->subGraph.empty()) {
+    meta_graph_->subGraph[0]->nodeIndices.push_back(meta_graph_->nodes.size() - 1);
+  }
+  return RET_OK;
+}
+
 std::unique_ptr<schema::CNodeT> TrainExport::CreateCNode(const mindspore::kernel::LiteKernel *kernel,
                                                          std::vector<uint32_t> inputIndex,
                                                          std::vector<uint32_t> outputIndex, const Model *model) {
@@ -255,12 +271,18 @@ int TrainExport::AddTransformNode() {
       return RET_ERROR;
     }
     meta_graph_->allTensors.emplace_back(std::move(tensorConst));  // last_id
+    if (!meta_graph_->subGraph.empty()) {
+      meta_graph_->subGraph[0]->tensorIndices.push_back(meta_graph_->allTensors.size() - 1);
+    }
     auto tensorT = CreateTransformTensor(it.second);
     if (tensorT == nullptr) {
       MS_LOG(ERROR) << "error in create tensor";
       return RET_ERROR;
     }
     meta_graph_->allTensors.emplace_back(std::move(tensorT));  // last_id + 1
+    if (!meta_graph_->subGraph.empty()) {
+      meta_graph_->subGraph[0]->tensorIndices.push_back(meta_graph_->allTensors.size() - 1);
+    }
     std::vector<uint32_t> in_idx = {static_cast<uint32_t>(it.second), static_cast<uint32_t>(last_id)};
     std::vector<uint32_t> out_idx = {static_cast<uint32_t>(last_id + 1)};
     reconnect[it.first] = last_id + 1;
@@ -270,11 +292,20 @@ int TrainExport::AddTransformNode() {
       return RET_ERROR;
     }
     meta_graph_->nodes.emplace_back(std::move(cnode));
+    if (!meta_graph_->subGraph.empty()) {
+      meta_graph_->subGraph[0]->nodeIndices.push_back(meta_graph_->nodes.size() - 1);
+    }
   }
   connect_ = reconnect;
   return RET_OK;
 }
 
+void TrainExport::PrepareRemap(int offset) {
+  for (auto it : connect_) {
+    remap_[it.first + offset] = it.second;
+  }
+}
+
 int TrainExport::ExportNet(const std::vector<mindspore::kernel::LiteKernel *> &kernels,
                            const std::vector<mindspore::lite::Tensor *> &tensors,
                            const std::vector<std::string> &output_names, const Model *model,
@@ -284,17 +315,13 @@ int TrainExport::ExportNet(const std::vector<mindspore::kernel::LiteKernel *> &k
   int offset = meta_graph_->allTensors.size();
   int tensor_idx = offset;
   quant_type_ = quant_type;
-
   if (meta_graph_ == nullptr) {
     int status = ExportInit(model->name_, model->version_);
     if (status != RET_OK) {
      return status;
     }
   }
-  // prepare mapping for connection
-  for (auto it : connect_) {
-    remap_[it.first + offset] = it.second;
-  }
+  PrepareRemap(offset);
 
   for (const auto kernel : kernels) {
     std::vector<uint32_t> in_idx, out_idx;
@@ -332,12 +359,11 @@ int TrainExport::ExportNet(const std::vector<mindspore::kernel::LiteKernel *> &k
         out_set.insert(it->second);
       }
     }
-    auto cnode = CreateCNode(kernel, in_idx, out_idx, model);
-    if (cnode == nullptr) {
+    auto ret = CreateAndAddCNode(kernel, in_idx, out_idx, model);
+    if (ret != RET_OK) {
       MS_LOG(ERROR) << "failed to create cnode";
-      return RET_ERROR;
+      return ret;
     }
-    meta_graph_->nodes.emplace_back(std::move(cnode));
   }
   for (auto id : map_index) {
     size_t pid = id - offset;
@@ -351,6 +377,9 @@ int TrainExport::ExportNet(const std::vector<mindspore::kernel::LiteKernel *> &k
     if (out_set.find(remap_[id]) == out_set.end()) {
       if ((tensorT->nodeType == NodeType_ValueNode) && (tensorT->data.size() == 0)) {
         meta_graph_->inputIndex.push_back(remap_[id]);
+        if (!meta_graph_->subGraph.empty()) {
+          meta_graph_->subGraph[0]->inputIndices.push_back(remap_[id]);
+        }
       }
     }
     // find output tensor
@@ -361,11 +390,12 @@ int TrainExport::ExportNet(const std::vector<mindspore::kernel::LiteKernel *> &k
       }
     }
     meta_graph_->allTensors.emplace_back(std::move(tensorT));
+    if (!meta_graph_->subGraph.empty()) {
+      meta_graph_->subGraph[0]->tensorIndices.push_back(meta_graph_->allTensors.size() - 1);
+    }
   }
+  TagQuantizedNodes();  // do another loop to mark QUANT_WEIGHT_NODES
 
-  if (quant_type_ == QT_WEIGHT) {  // do another loop to mark QUANT_WEIGHT_NODES
-    TagQuantizedNodes();
-  }
   return RET_OK;
 }
 
diff --git a/mindspore/lite/src/train/train_export.h b/mindspore/lite/src/train/train_export.h
index 8a624918681..cb19d766fd6 100644
--- a/mindspore/lite/src/train/train_export.h
+++ b/mindspore/lite/src/train/train_export.h
@@ -57,11 +57,14 @@ class TrainExport {
   std::vector<size_t> out_idx_;
   std::map<size_t, size_t> remap_;
   std::unordered_map<size_t, size_t> connect_;  // connection map (backbone tenor id-> head tensor id)
+  void PrepareRemap(int offset);
   Model::Node *FindNode(const mindspore::kernel::LiteKernel *kernel, const Model *model);
   std::unique_ptr<schema::TensorT> CreateTensor(const Tensor *tensor, schema::Tensor *scTensor);
   std::unique_ptr<schema::CNodeT> CreateCNode(const mindspore::kernel::LiteKernel *kernel,
                                               std::vector<uint32_t> inputIndex, std::vector<uint32_t> outputIndex,
                                               const Model *model);
+  int CreateAndAddCNode(const mindspore::kernel::LiteKernel *kernel, std::vector<uint32_t> inputIndex,
+                        std::vector<uint32_t> outputIndex, const Model *model);
   std::unique_ptr<schema::CNodeT> CreateTransformNode(std::vector<uint32_t> inputIndex,
                                                       std::vector<uint32_t> outputIndex, size_t id);
   std::unique_ptr<schema::TensorT> CreateTransformTensor(size_t id);
diff --git a/mindspore/lite/test/st/scripts/run_net_train.sh b/mindspore/lite/test/st/scripts/run_net_train.sh
index 420fa216fae..a98cc67e087 100755
--- a/mindspore/lite/test/st/scripts/run_net_train.sh
+++ b/mindspore/lite/test/st/scripts/run_net_train.sh
@@ -135,7 +135,7 @@ function Run_x86() {
             bb_model_file="${ms_models_path}/${model_name}_bb.ms"
             suffix_print="_transfer"
             export_file="${ms_models_path}/${model_name}_tod_head"
-            inference_file=""
+            inference_file="${ms_models_path}/${model_name}_infer"
         fi
         if [ ! -z "$inference_file" ]; then
            rm -f ${inference_file}"*"
@@ -267,8 +267,8 @@ function Run_arm() {
             model_file="${model_name}_head.ms"
             bb_model_file="${model_name}_bb.ms"
             suffix_print="_transfer"
-            export_file="${tmp_dir}/${model_name}_tod_head.ms"
-            inference_file=""
+            export_file="${tmp_dir}/${model_name}_tod_head"
+            inference_file="${tmp_dir}/${model_name}_infer"
         fi
         # run benchmark_train test without clib data
         echo ${model_name} >> "${run_arm_log_file}"