avoid same name

This commit is contained in:
liuyu 2021-05-10 17:26:31 +08:00
parent ce57912f7f
commit e2a0893db9
8 changed files with 49 additions and 43 deletions

View File

@@ -64,6 +64,7 @@ OpParameter *PopulateSpliceParameter(const void *prim) {
auto forward_indexes = value->forward_indexes(); auto forward_indexes = value->forward_indexes();
if (forward_indexes == nullptr) { if (forward_indexes == nullptr) {
MS_LOG(ERROR) << "forward_indexes is nullptr"; MS_LOG(ERROR) << "forward_indexes is nullptr";
free(param->context_);
free(param); free(param);
return nullptr; return nullptr;
} }

View File

@@ -44,6 +44,7 @@ OpParameter *PopulateBatchToSpaceParameter(const void *prim) {
auto block_shape = batch_to_space_prim->blockShape(); auto block_shape = batch_to_space_prim->blockShape();
if (block_shape == nullptr) { if (block_shape == nullptr) {
MS_LOG(ERROR) << "block_shape is nullptr"; MS_LOG(ERROR) << "block_shape is nullptr";
free(batch_space_param);
return nullptr; return nullptr;
} }
if (block_shape->size() != BATCH_TO_SPACE_BLOCK_SHAPE_SIZE) { if (block_shape->size() != BATCH_TO_SPACE_BLOCK_SHAPE_SIZE) {
@@ -55,6 +56,7 @@ OpParameter *PopulateBatchToSpaceParameter(const void *prim) {
auto crops = batch_to_space_prim->crops(); auto crops = batch_to_space_prim->crops();
if (crops == nullptr) { if (crops == nullptr) {
MS_LOG(ERROR) << "crops is nullptr"; MS_LOG(ERROR) << "crops is nullptr";
free(batch_space_param);
return nullptr; return nullptr;
} }
if (crops->size() != COMM_SHAPE_SIZE) { if (crops->size() != COMM_SHAPE_SIZE) {

View File

@@ -36,6 +36,11 @@ OpParameter *PopulateConstantOfShapeParameter(const void *prim) {
memset(param, 0, sizeof(ConstantOfShapeParameter)); memset(param, 0, sizeof(ConstantOfShapeParameter));
param->op_parameter_.type_ = schema::PrimitiveType_ConstantOfShape; param->op_parameter_.type_ = schema::PrimitiveType_ConstantOfShape;
auto value = constant_of_shape_prim->value(); auto value = constant_of_shape_prim->value();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
free(param);
return nullptr;
}
param->data_type_ = constant_of_shape_prim->dataType(); param->data_type_ = constant_of_shape_prim->dataType();
if (value->size() == 0 || value->size() > 1) { if (value->size() == 0 || value->size() > 1) {
MS_LOG(ERROR) << "The value of constant of shape is empty or more than 1."; MS_LOG(ERROR) << "The value of constant of shape is empty or more than 1.";

View File

@@ -39,6 +39,7 @@ OpParameter *PopulateSpaceToBatchNDParameter(const void *prim) {
auto block_sizes = space_to_batch_nd_prim->blockShape(); auto block_sizes = space_to_batch_nd_prim->blockShape();
if (block_sizes == nullptr) { if (block_sizes == nullptr) {
MS_LOG(ERROR) << "block_sizes is nullptr"; MS_LOG(ERROR) << "block_sizes is nullptr";
free(space_batch_param_nd);
return nullptr; return nullptr;
} }
space_batch_param_nd->m_ = block_sizes->size(); space_batch_param_nd->m_ = block_sizes->size();
@@ -51,6 +52,7 @@ OpParameter *PopulateSpaceToBatchNDParameter(const void *prim) {
auto paddings = space_to_batch_nd_prim->paddings(); auto paddings = space_to_batch_nd_prim->paddings();
if (paddings == nullptr) { if (paddings == nullptr) {
MS_LOG(ERROR) << "paddings is nullptr"; MS_LOG(ERROR) << "paddings is nullptr";
free(space_batch_param_nd);
return nullptr; return nullptr;
} }
if (((size_t)paddings->size()) > std::numeric_limits<size_t>::max() / sizeof(int)) { if (((size_t)paddings->size()) > std::numeric_limits<size_t>::max() / sizeof(int)) {

View File

@@ -39,6 +39,7 @@ OpParameter *PopulateSpaceToBatchParameter(const void *prim) {
auto block_sizes = space_to_batch_prim->blockShape(); // maybe error auto block_sizes = space_to_batch_prim->blockShape(); // maybe error
if (block_sizes == nullptr) { if (block_sizes == nullptr) {
MS_LOG(ERROR) << "block_sizes is nullptr"; MS_LOG(ERROR) << "block_sizes is nullptr";
free(space_batch_param);
return nullptr; return nullptr;
} }
space_batch_param->m_ = block_sizes->size(); space_batch_param->m_ = block_sizes->size();
@@ -51,6 +52,7 @@ OpParameter *PopulateSpaceToBatchParameter(const void *prim) {
auto paddings = space_to_batch_prim->paddings(); auto paddings = space_to_batch_prim->paddings();
if (paddings == nullptr) { if (paddings == nullptr) {
MS_LOG(ERROR) << "paddings is nullptr"; MS_LOG(ERROR) << "paddings is nullptr";
free(space_batch_param);
return nullptr; return nullptr;
} }
if (((size_t)paddings->size()) > std::numeric_limits<size_t>::max() / sizeof(int)) { if (((size_t)paddings->size()) > std::numeric_limits<size_t>::max() / sizeof(int)) {

View File

@@ -39,6 +39,7 @@ OpParameter *PopulateUnsqueezeParameter(const void *prim) {
auto flatAxis = unsqueeze_prim->axis(); auto flatAxis = unsqueeze_prim->axis();
if (flatAxis == nullptr) { if (flatAxis == nullptr) {
MS_LOG(ERROR) << "flatAxis is nullptr"; MS_LOG(ERROR) << "flatAxis is nullptr";
free(unsqueeze_param);
return nullptr; return nullptr;
} }
unsqueeze_param->num_dim_ = flatAxis->size(); unsqueeze_param->num_dim_ = flatAxis->size();

View File

@@ -21,6 +21,7 @@
#include <functional> #include <functional>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <algorithm>
#include "tools/converter/converter_flags.h" #include "tools/converter/converter_flags.h"
#include "abstract/abstract_value.h" #include "abstract/abstract_value.h"
#include "mindspore/core/ir/primitive.h" #include "mindspore/core/ir/primitive.h"
@@ -421,18 +422,10 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee
int AnfExporter::ConvertInputCNodeCommonOp(const AnfNodePtr &input_anode, schema::CNodeT *output_cnode) { int AnfExporter::ConvertInputCNodeCommonOp(const AnfNodePtr &input_anode, schema::CNodeT *output_cnode) {
MS_ASSERT(input_anode != nullptr && output_cnode != nullptr); MS_ASSERT(input_anode != nullptr && output_cnode != nullptr);
auto input_name = input_anode->fullname_with_scope();
if (this->train_flag_) { if (this->train_flag_) {
bool found = false; auto key = std::make_pair(input_anode, 0);
if (node_id_map_.find(input_name) != node_id_map_.end()) { if (node_id_map_.find(key) != node_id_map_.end()) {
output_cnode->inputIndex.emplace_back(node_id_map_[input_name]); output_cnode->inputIndex.emplace_back(node_id_map_[key]);
found = true;
}
if (!found) {
auto input_index_key = input_name + "_o:" + std::to_string(0);
if (node_id_map_.find(input_index_key) != node_id_map_.end()) {
output_cnode->inputIndex.emplace_back(node_id_map_[input_index_key]);
}
} }
return RET_OK; return RET_OK;
} }
@@ -444,20 +437,15 @@ int AnfExporter::ConvertInputCNodeCommonOp(const AnfNodePtr &input_anode, schema
} }
auto elements = tuple->elements(); auto elements = tuple->elements();
for (size_t i = 0; i < elements.size(); i++) { for (size_t i = 0; i < elements.size(); i++) {
if (elements.size() == 1) { auto key = std::make_pair(input_anode, i);
if (node_id_map_.find(input_name) != node_id_map_.end()) { if (node_id_map_.find(key) != node_id_map_.end()) {
output_cnode->inputIndex.emplace_back(node_id_map_[input_name]); output_cnode->inputIndex.emplace_back(node_id_map_[key]);
}
} else {
std::string name = input_name + "_o:" + std::to_string(i);
if (node_id_map_.find(name) != node_id_map_.end()) {
output_cnode->inputIndex.emplace_back(node_id_map_[name]);
}
} }
} }
} else { } else {
if (node_id_map_.find(input_name) != node_id_map_.end()) { auto key = std::make_pair(input_anode, 0);
output_cnode->inputIndex.emplace_back(node_id_map_[input_name]); if (node_id_map_.find(key) != node_id_map_.end()) {
output_cnode->inputIndex.emplace_back(node_id_map_[key]);
} }
} }
return RET_OK; return RET_OK;
@@ -490,16 +478,16 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode,
MS_LOG(ERROR) << "cast to ValueNode failed"; MS_LOG(ERROR) << "cast to ValueNode failed";
return RET_ERROR; return RET_ERROR;
} }
auto input_index_key = get_item_input_cnode->fullname_with_scope() + "_o:" + auto idx = value_node->value()->type()->number_type() == kNumberTypeInt64 ? GetValue<int64_t>(value_node->value())
std::to_string(value_node->value()->type()->number_type() == kNumberTypeInt64 : GetValue<int>(value_node->value());
? GetValue<int64_t>(value_node->value()) auto key = std::make_pair(get_item_input_cnode, idx);
: GetValue<int>(value_node->value())); auto iter = node_id_map_.find(key);
auto iter = node_id_map_.find(input_index_key);
if (iter == node_id_map_.end()) { if (iter == node_id_map_.end()) {
input_index_key = get_item_input_cnode->fullname_with_scope() + "_o:" + std::to_string(0); // try name with 0 key = std::make_pair(get_item_input_cnode, 0); // try name with 0
iter = node_id_map_.find(input_index_key); iter = node_id_map_.find(key);
if (iter == node_id_map_.end()) { if (iter == node_id_map_.end()) {
MS_LOG(ERROR) << "Can not find get_item output tensor " << input_index_key; MS_LOG(ERROR) << "Can not find get_item output tensor "
<< get_item_input_cnode->fullname_with_scope() + "_o:" + std::to_string(idx);
return RET_ERROR; return RET_ERROR;
} }
} }
@@ -513,9 +501,9 @@ int AnfExporter::ConvertInputParameter(const CNodePtr &cnode, size_t index, cons
schema::CNodeT *op_node) { schema::CNodeT *op_node) {
auto param_node = cnode->input(index)->cast<ParameterPtr>(); auto param_node = cnode->input(index)->cast<ParameterPtr>();
MS_ASSERT(param_node != nullptr); MS_ASSERT(param_node != nullptr);
std::string input_name = param_node->fullname_with_scope(); auto key = std::make_pair(param_node, 0);
if (node_id_map_.find(input_name) != node_id_map_.end()) { if (node_id_map_.find(key) != node_id_map_.end()) {
op_node->inputIndex.emplace_back(node_id_map_[param_node->name()]); op_node->inputIndex.emplace_back(node_id_map_[key]);
return RET_OK; return RET_OK;
} }
DataInfo data_info; DataInfo data_info;
@@ -532,7 +520,7 @@ int AnfExporter::ConvertInputParameter(const CNodePtr &cnode, size_t index, cons
schema_tensor->data = data_info.data_; schema_tensor->data = data_info.data_;
schema_tensor->enableHuffmanCode = data_info.enable_huffman_code_; schema_tensor->enableHuffmanCode = data_info.enable_huffman_code_;
node_id_map_[input_name] = meta_graphT->allTensors.size(); node_id_map_[key] = meta_graphT->allTensors.size();
op_node->inputIndex.emplace_back(meta_graphT->allTensors.size()); op_node->inputIndex.emplace_back(meta_graphT->allTensors.size());
meta_graphT->allTensors.emplace_back(std::move(schema_tensor)); meta_graphT->allTensors.emplace_back(std::move(schema_tensor));
return RET_OK; return RET_OK;
@@ -556,7 +544,9 @@ int AnfExporter::ConvertInputValueNode(const CNodePtr &cnode, size_t index, cons
schema_tensor->dataType = data_info.data_type_; schema_tensor->dataType = data_info.data_type_;
schema_tensor->dims = data_info.shape_; schema_tensor->dims = data_info.shape_;
schema_tensor->data = data_info.data_; schema_tensor->data = data_info.data_;
node_id_map_[cnode->input(index)->fullname_with_scope()] = meta_graphT->allTensors.size();
auto key = std::make_pair(cnode->input(index), 0);
node_id_map_[key] = meta_graphT->allTensors.size();
op_node->inputIndex.emplace_back(meta_graphT->allTensors.size()); op_node->inputIndex.emplace_back(meta_graphT->allTensors.size());
meta_graphT->allTensors.emplace_back(std::move(schema_tensor)); meta_graphT->allTensors.emplace_back(std::move(schema_tensor));
return RET_OK; return RET_OK;
@@ -628,18 +618,18 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
} }
ms_tensor->nodeType = NodeType_CNode; ms_tensor->nodeType = NodeType_CNode;
fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size()); fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size());
auto key = std::make_pair(cnode, i);
if (train_flag_) { if (train_flag_) {
std::string name = cnode_name + "_o:" + std::to_string(i); node_id_map_[key] = meta_graphT->allTensors.size();
node_id_map_[name] = meta_graphT->allTensors.size();
meta_graphT->allTensors.emplace_back(ms_tensor); meta_graphT->allTensors.emplace_back(ms_tensor);
} else { } else {
if (elements.size() == 1) { if (elements.size() == 1) {
node_id_map_[cnode_name] = meta_graphT->allTensors.size(); key = std::make_pair(cnode, 0);
node_id_map_[key] = meta_graphT->allTensors.size();
ms_tensor->name = cnode_name; ms_tensor->name = cnode_name;
} else { } else {
std::string name = cnode_name + "_o:" + std::to_string(i); node_id_map_[key] = meta_graphT->allTensors.size();
node_id_map_[name] = meta_graphT->allTensors.size(); ms_tensor->name = cnode_name + "_o:" + std::to_string(i);
ms_tensor->name = name;
} }
if (!utils::isa<abstract::AbstractTensorPtr>(elements[i])) { if (!utils::isa<abstract::AbstractTensorPtr>(elements[i])) {
@@ -673,7 +663,9 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
ms_tensor->nodeType = NodeType_CNode; ms_tensor->nodeType = NodeType_CNode;
ms_tensor->name = cnode_name; ms_tensor->name = cnode_name;
fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size()); fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size());
node_id_map_[cnode_name] = meta_graphT->allTensors.size();
auto key = std::make_pair(cnode, 0);
node_id_map_[key] = meta_graphT->allTensors.size();
meta_graphT->allTensors.emplace_back(ms_tensor); meta_graphT->allTensors.emplace_back(ms_tensor);
} }
} }

View File

@@ -21,6 +21,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <utility>
#include "schema/inner/model_generated.h" #include "schema/inner/model_generated.h"
#include "ops/primitive_c.h" #include "ops/primitive_c.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
@@ -74,7 +75,7 @@ class AnfExporter {
bool HasExported(const FuncGraphPtr &func_graph); bool HasExported(const FuncGraphPtr &func_graph);
private: private:
std::map<std::string, int> node_id_map_; std::map<std::pair<AnfNodePtr, int>, int> node_id_map_;
std::vector<schema::CNodeT *> graph_input_nodes_; std::vector<schema::CNodeT *> graph_input_nodes_;
// The first item is FuncGraph which has been exported, the second item is the subgraph index in meta_graph // The first item is FuncGraph which has been exported, the second item is the subgraph index in meta_graph
std::map<FuncGraphPtr, int> fg_subgraph_map_; std::map<FuncGraphPtr, int> fg_subgraph_map_;