diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs index 0cbdf18ff45..467b43a67af 100644 --- a/mindspore/lite/schema/model.fbs +++ b/mindspore/lite/schema/model.fbs @@ -54,6 +54,7 @@ table Tensor { data: [ubyte]; quantParams: [QuantParam]; quantClusters: [float]; + name: string; } union PrimitiveType { diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc index b7dcf60cfd1..ee2349d090d 100644 --- a/mindspore/lite/tools/converter/graphdef_transform.cc +++ b/mindspore/lite/tools/converter/graphdef_transform.cc @@ -34,6 +34,7 @@ #include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h" #include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h" #include "tools/converter/legacy_optimizer/graph/tensor_quant_pass.h" +#include "tools/converter/legacy_optimizer/graph/tensor_name_pass.h" #include "tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h" #include "tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.h" @@ -183,6 +184,17 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { return status; } } + + // tensor name + { + Optimizer nameOptimizer; + nameOptimizer.AddPass(new (std::nothrow) TensorNamePass()); + status = nameOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run nameOptimizer graphPasses Failed"; + return status; + } + } return RET_OK; } } // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt b/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt index 060cfe0cb43..4bfd3c5376e 100755 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt @@ -14,6 +14,7 @@ file(GLOB GRAPH_PASS ${CMAKE_CURRENT_SOURCE_DIR}/infer_quant_param_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/global_format_transform_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/set_unused_quant_param_to_default_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/tensor_name_pass.cc ) set_property(SOURCE ${GRAPH_PASS} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) add_library(graph_pass_mid OBJECT ${GRAPH_PASS}) diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_name_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_name_pass.cc new file mode 100644 index 00000000000..a203b3c5152 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_name_pass.cc @@ -0,0 +1,77 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/legacy_optimizer/graph/tensor_name_pass.h" +#include "tools/converter/converter_context.h" +#include "tools/converter/quantizer/quantize_util.h" +#include "tools/common/tensor_util.h" + +namespace mindspore::lite { +STATUS TensorNamePass::Run(schema::MetaGraphT *graph) { + MS_ASSERT(graph != nullptr); + + for (int i = 0; i < static_cast(graph->inputIndex.size()); i++) { + auto tensor_id = graph->inputIndex.at(i); + auto &tensor = graph->allTensors.at(tensor_id); + tensor->name = "graph_input-" + std::to_string(i); + } + + for (auto &node : graph->nodes) { + if (node == nullptr || node->primitive == nullptr) { + MS_LOG(ERROR) << " node or node->primitive is nullptr"; + return RET_ERROR; + } + + for (int i = 0; i < static_cast(node->outputIndex.size()); i++) { + auto tensor_id = node->outputIndex.at(i); + auto &tensor = graph->allTensors.at(tensor_id); + if (tensor->name.empty()) { + tensor->name = node->name + "/output-" + std::to_string(i); + } + } + + auto type = node->primitive->value.type; + if (type == PrimitiveType_Conv2D || type == PrimitiveType_DeConv2D || type == PrimitiveType_DepthwiseConv2D || + type == PrimitiveType_DeDepthwiseConv2D || type == PrimitiveType_FullConnection) { + auto input_size = node->inputIndex.size(); + if (input_size > 1) { + auto weight_tensor_id = node->inputIndex.at(1); + auto &weight_tensor = graph->allTensors.at(weight_tensor_id); + if (weight_tensor->name.empty()) { + weight_tensor->name = node->name + "/weight"; + } + + if (input_size > 2) { + auto bias_tensor_id = node->inputIndex.at(2); + auto &bias_tensor = graph->allTensors.at(bias_tensor_id); + if (bias_tensor->name.empty()) { + bias_tensor->name = node->name + "/bias"; + } + } + } + } else { + for (int i = 0; i < static_cast(node->inputIndex.size()); i++) { + auto tensor_id = node->inputIndex.at(i); + auto &tensor = graph->allTensors.at(tensor_id); + if (tensor->name.empty()) { + tensor->name = node->name + "/input-" + std::to_string(i); + } + } + } + } + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_name_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_name_pass.h new file mode 100644 index 00000000000..b564a33b037 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_name_pass.h @@ -0,0 +1,35 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef LITE_NAME_TENSOR_PASS_H +#define LITE_NAME_TENSOR_PASS_H + +#include +#include "tools/converter/optimizer.h" +#include "tools/common/graph_util.h" + +namespace mindspore { +namespace lite { +class TensorNamePass : public GraphPass { + public: + TensorNamePass() {} + + ~TensorNamePass() override = default; + + STATUS Run(schema::MetaGraphT *graph) override; +}; +} // namespace lite +} // namespace mindspore +#endif // LITE_NAME_TENSOR_PASS_H