avoid same name

This commit is contained in:
liuyu 2021-05-10 17:26:31 +08:00
parent ce57912f7f
commit e2a0893db9
8 changed files with 49 additions and 43 deletions

View File

@ -64,6 +64,7 @@ OpParameter *PopulateSpliceParameter(const void *prim) {
auto forward_indexes = value->forward_indexes();
if (forward_indexes == nullptr) {
MS_LOG(ERROR) << "forward_indexes is nullptr";
free(param->context_);
free(param);
return nullptr;
}

View File

@ -44,6 +44,7 @@ OpParameter *PopulateBatchToSpaceParameter(const void *prim) {
auto block_shape = batch_to_space_prim->blockShape();
if (block_shape == nullptr) {
MS_LOG(ERROR) << "block_shape is nullptr";
free(batch_space_param);
return nullptr;
}
if (block_shape->size() != BATCH_TO_SPACE_BLOCK_SHAPE_SIZE) {
@ -55,6 +56,7 @@ OpParameter *PopulateBatchToSpaceParameter(const void *prim) {
auto crops = batch_to_space_prim->crops();
if (crops == nullptr) {
MS_LOG(ERROR) << "crops is nullptr";
free(batch_space_param);
return nullptr;
}
if (crops->size() != COMM_SHAPE_SIZE) {

View File

@ -36,6 +36,11 @@ OpParameter *PopulateConstantOfShapeParameter(const void *prim) {
memset(param, 0, sizeof(ConstantOfShapeParameter));
param->op_parameter_.type_ = schema::PrimitiveType_ConstantOfShape;
auto value = constant_of_shape_prim->value();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
free(param);
return nullptr;
}
param->data_type_ = constant_of_shape_prim->dataType();
if (value->size() == 0 || value->size() > 1) {
MS_LOG(ERROR) << "The value of constant of shape is empty or more than 1.";

View File

@ -39,6 +39,7 @@ OpParameter *PopulateSpaceToBatchNDParameter(const void *prim) {
auto block_sizes = space_to_batch_nd_prim->blockShape();
if (block_sizes == nullptr) {
MS_LOG(ERROR) << "block_sizes is nullptr";
free(space_batch_param_nd);
return nullptr;
}
space_batch_param_nd->m_ = block_sizes->size();
@ -51,6 +52,7 @@ OpParameter *PopulateSpaceToBatchNDParameter(const void *prim) {
auto paddings = space_to_batch_nd_prim->paddings();
if (paddings == nullptr) {
MS_LOG(ERROR) << "paddings is nullptr";
free(space_batch_param_nd);
return nullptr;
}
if (((size_t)paddings->size()) > std::numeric_limits<size_t>::max() / sizeof(int)) {

View File

@ -39,6 +39,7 @@ OpParameter *PopulateSpaceToBatchParameter(const void *prim) {
auto block_sizes = space_to_batch_prim->blockShape(); // maybe error
if (block_sizes == nullptr) {
MS_LOG(ERROR) << "block_sizes is nullptr";
free(space_batch_param);
return nullptr;
}
space_batch_param->m_ = block_sizes->size();
@ -51,6 +52,7 @@ OpParameter *PopulateSpaceToBatchParameter(const void *prim) {
auto paddings = space_to_batch_prim->paddings();
if (paddings == nullptr) {
MS_LOG(ERROR) << "paddings is nullptr";
free(space_batch_param);
return nullptr;
}
if (((size_t)paddings->size()) > std::numeric_limits<size_t>::max() / sizeof(int)) {

View File

@ -39,6 +39,7 @@ OpParameter *PopulateUnsqueezeParameter(const void *prim) {
auto flatAxis = unsqueeze_prim->axis();
if (flatAxis == nullptr) {
MS_LOG(ERROR) << "flatAxis is nullptr";
free(unsqueeze_param);
return nullptr;
}
unsqueeze_param->num_dim_ = flatAxis->size();

View File

@ -21,6 +21,7 @@
#include <functional>
#include <utility>
#include <vector>
#include <algorithm>
#include "tools/converter/converter_flags.h"
#include "abstract/abstract_value.h"
#include "mindspore/core/ir/primitive.h"
@ -421,18 +422,10 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee
int AnfExporter::ConvertInputCNodeCommonOp(const AnfNodePtr &input_anode, schema::CNodeT *output_cnode) {
MS_ASSERT(input_anode != nullptr && output_cnode != nullptr);
auto input_name = input_anode->fullname_with_scope();
if (this->train_flag_) {
bool found = false;
if (node_id_map_.find(input_name) != node_id_map_.end()) {
output_cnode->inputIndex.emplace_back(node_id_map_[input_name]);
found = true;
}
if (!found) {
auto input_index_key = input_name + "_o:" + std::to_string(0);
if (node_id_map_.find(input_index_key) != node_id_map_.end()) {
output_cnode->inputIndex.emplace_back(node_id_map_[input_index_key]);
}
auto key = std::make_pair(input_anode, 0);
if (node_id_map_.find(key) != node_id_map_.end()) {
output_cnode->inputIndex.emplace_back(node_id_map_[key]);
}
return RET_OK;
}
@ -444,20 +437,15 @@ int AnfExporter::ConvertInputCNodeCommonOp(const AnfNodePtr &input_anode, schema
}
auto elements = tuple->elements();
for (size_t i = 0; i < elements.size(); i++) {
if (elements.size() == 1) {
if (node_id_map_.find(input_name) != node_id_map_.end()) {
output_cnode->inputIndex.emplace_back(node_id_map_[input_name]);
}
} else {
std::string name = input_name + "_o:" + std::to_string(i);
if (node_id_map_.find(name) != node_id_map_.end()) {
output_cnode->inputIndex.emplace_back(node_id_map_[name]);
}
auto key = std::make_pair(input_anode, i);
if (node_id_map_.find(key) != node_id_map_.end()) {
output_cnode->inputIndex.emplace_back(node_id_map_[key]);
}
}
} else {
if (node_id_map_.find(input_name) != node_id_map_.end()) {
output_cnode->inputIndex.emplace_back(node_id_map_[input_name]);
auto key = std::make_pair(input_anode, 0);
if (node_id_map_.find(key) != node_id_map_.end()) {
output_cnode->inputIndex.emplace_back(node_id_map_[key]);
}
}
return RET_OK;
@ -490,16 +478,16 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode,
MS_LOG(ERROR) << "cast to ValueNode failed";
return RET_ERROR;
}
auto input_index_key = get_item_input_cnode->fullname_with_scope() + "_o:" +
std::to_string(value_node->value()->type()->number_type() == kNumberTypeInt64
? GetValue<int64_t>(value_node->value())
: GetValue<int>(value_node->value()));
auto iter = node_id_map_.find(input_index_key);
auto idx = value_node->value()->type()->number_type() == kNumberTypeInt64 ? GetValue<int64_t>(value_node->value())
: GetValue<int>(value_node->value());
auto key = std::make_pair(get_item_input_cnode, idx);
auto iter = node_id_map_.find(key);
if (iter == node_id_map_.end()) {
input_index_key = get_item_input_cnode->fullname_with_scope() + "_o:" + std::to_string(0); // try name with 0
iter = node_id_map_.find(input_index_key);
key = std::make_pair(get_item_input_cnode, 0); // try name with 0
iter = node_id_map_.find(key);
if (iter == node_id_map_.end()) {
MS_LOG(ERROR) << "Can not find get_item output tensor " << input_index_key;
MS_LOG(ERROR) << "Can not find get_item output tensor "
<< get_item_input_cnode->fullname_with_scope() + "_o:" + std::to_string(idx);
return RET_ERROR;
}
}
@ -513,9 +501,9 @@ int AnfExporter::ConvertInputParameter(const CNodePtr &cnode, size_t index, cons
schema::CNodeT *op_node) {
auto param_node = cnode->input(index)->cast<ParameterPtr>();
MS_ASSERT(param_node != nullptr);
std::string input_name = param_node->fullname_with_scope();
if (node_id_map_.find(input_name) != node_id_map_.end()) {
op_node->inputIndex.emplace_back(node_id_map_[param_node->name()]);
auto key = std::make_pair(param_node, 0);
if (node_id_map_.find(key) != node_id_map_.end()) {
op_node->inputIndex.emplace_back(node_id_map_[key]);
return RET_OK;
}
DataInfo data_info;
@ -532,7 +520,7 @@ int AnfExporter::ConvertInputParameter(const CNodePtr &cnode, size_t index, cons
schema_tensor->data = data_info.data_;
schema_tensor->enableHuffmanCode = data_info.enable_huffman_code_;
node_id_map_[input_name] = meta_graphT->allTensors.size();
node_id_map_[key] = meta_graphT->allTensors.size();
op_node->inputIndex.emplace_back(meta_graphT->allTensors.size());
meta_graphT->allTensors.emplace_back(std::move(schema_tensor));
return RET_OK;
@ -556,7 +544,9 @@ int AnfExporter::ConvertInputValueNode(const CNodePtr &cnode, size_t index, cons
schema_tensor->dataType = data_info.data_type_;
schema_tensor->dims = data_info.shape_;
schema_tensor->data = data_info.data_;
node_id_map_[cnode->input(index)->fullname_with_scope()] = meta_graphT->allTensors.size();
auto key = std::make_pair(cnode->input(index), 0);
node_id_map_[key] = meta_graphT->allTensors.size();
op_node->inputIndex.emplace_back(meta_graphT->allTensors.size());
meta_graphT->allTensors.emplace_back(std::move(schema_tensor));
return RET_OK;
@ -628,18 +618,18 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
}
ms_tensor->nodeType = NodeType_CNode;
fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size());
auto key = std::make_pair(cnode, i);
if (train_flag_) {
std::string name = cnode_name + "_o:" + std::to_string(i);
node_id_map_[name] = meta_graphT->allTensors.size();
node_id_map_[key] = meta_graphT->allTensors.size();
meta_graphT->allTensors.emplace_back(ms_tensor);
} else {
if (elements.size() == 1) {
node_id_map_[cnode_name] = meta_graphT->allTensors.size();
key = std::make_pair(cnode, 0);
node_id_map_[key] = meta_graphT->allTensors.size();
ms_tensor->name = cnode_name;
} else {
std::string name = cnode_name + "_o:" + std::to_string(i);
node_id_map_[name] = meta_graphT->allTensors.size();
ms_tensor->name = name;
node_id_map_[key] = meta_graphT->allTensors.size();
ms_tensor->name = cnode_name + "_o:" + std::to_string(i);
}
if (!utils::isa<abstract::AbstractTensorPtr>(elements[i])) {
@ -673,7 +663,9 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
ms_tensor->nodeType = NodeType_CNode;
ms_tensor->name = cnode_name;
fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size());
node_id_map_[cnode_name] = meta_graphT->allTensors.size();
auto key = std::make_pair(cnode, 0);
node_id_map_[key] = meta_graphT->allTensors.size();
meta_graphT->allTensors.emplace_back(ms_tensor);
}
}

View File

@ -21,6 +21,7 @@
#include <string>
#include <vector>
#include <memory>
#include <utility>
#include "schema/inner/model_generated.h"
#include "ops/primitive_c.h"
#include "ir/func_graph.h"
@ -74,7 +75,7 @@ class AnfExporter {
bool HasExported(const FuncGraphPtr &func_graph);
private:
std::map<std::string, int> node_id_map_;
std::map<std::pair<AnfNodePtr, int>, int> node_id_map_;
std::vector<schema::CNodeT *> graph_input_nodes_;
// The first item is FuncGraph which has been exported, the second item is the subgraph index in meta_graph
std::map<FuncGraphPtr, int> fg_subgraph_map_;