add tensor name

yvette 2020-11-27 10:40:39 +08:00
parent 1321483749
commit c9c252550d
5 changed files with 126 additions and 0 deletions

@@ -54,6 +54,7 @@ table Tensor {
    data: [ubyte];
    quantParams: [QuantParam];
    quantClusters: [float];
    name: string;
}

union PrimitiveType {
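
In the generated flatbuffer object API the new field becomes a plain std::string member on schema::TensorT, which is exactly what the converter pass added below reads and writes. A minimal illustrative sketch (the helper and the include path are assumptions, not part of this commit):

#include <iostream>
#include "schema/inner/model_generated.h"  // assumed location of the generated object API

// Hypothetical helper: print a tensor's name, falling back to a placeholder while no
// pass has assigned one yet.
void DumpTensorName(const mindspore::schema::TensorT &tensor) {
  std::cout << (tensor.name.empty() ? "<unnamed>" : tensor.name) << std::endl;
}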

@@ -34,6 +34,7 @@
#include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h"
#include "tools/converter/legacy_optimizer/graph/tensor_quant_pass.h"
#include "tools/converter/legacy_optimizer/graph/tensor_name_pass.h"
#include "tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h"
#include "tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.h"
@@ -183,6 +184,17 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
      return status;
    }
  }

  // tensor name
  {
    Optimizer nameOptimizer;
    nameOptimizer.AddPass(new (std::nothrow) TensorNamePass());
    status = nameOptimizer.Run(graphDefT);
    if (status != RET_OK && status != RET_NO_CHANGE) {
      MS_LOG(ERROR) << "Run nameOptimizer graphPasses Failed";
      return status;
    }
  }
  return RET_OK;
}
}  // namespace mindspore::lite
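
The block above wires the new pass in through the same Optimizer/GraphPass pattern as the surrounding optimizer blocks: register the pass, run it over the exported MetaGraphT, and treat RET_NO_CHANGE as success. For orientation, a conceptual sketch of that pass-runner contract follows (the real Optimizer is declared in tools/converter/optimizer.h; everything in this sketch is an assumption for illustration, not the actual implementation):

#include <memory>
#include <vector>
#include "tools/converter/optimizer.h"  // assumed to provide GraphPass, STATUS and the RET_* codes

namespace mindspore::lite {
// Assumed contract only: run every registered pass in order, propagate real failures,
// and do not treat RET_NO_CHANGE (a pass that had nothing to rewrite) as an error.
class PassRunnerSketch {
 public:
  void AddPass(GraphPass *pass) {
    if (pass != nullptr) {
      passes_.emplace_back(pass);
    }
  }

  STATUS Run(schema::MetaGraphT *graph) {
    for (auto &pass : passes_) {
      auto ret = pass->Run(graph);
      if (ret != RET_OK && ret != RET_NO_CHANGE) {
        return ret;
      }
    }
    return RET_OK;
  }

 private:
  std::vector<std::unique_ptr<GraphPass>> passes_;
};
}  // namespace mindspore::lite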

@@ -14,6 +14,7 @@ file(GLOB GRAPH_PASS
        ${CMAKE_CURRENT_SOURCE_DIR}/infer_quant_param_pass.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/global_format_transform_pass.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/set_unused_quant_param_to_default_pass.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/tensor_name_pass.cc
        )
set_property(SOURCE ${GRAPH_PASS} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
add_library(graph_pass_mid OBJECT ${GRAPH_PASS})

@@ -0,0 +1,77 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/legacy_optimizer/graph/tensor_name_pass.h"
#include "tools/converter/converter_context.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/common/tensor_util.h"
namespace mindspore::lite {
STATUS TensorNamePass::Run(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
for (int i = 0; i < static_cast<int>(graph->inputIndex.size()); i++) {
auto tensor_id = graph->inputIndex.at(i);
auto &tensor = graph->allTensors.at(tensor_id);
tensor->name = "graph_input-" + std::to_string(i);
}
for (auto &node : graph->nodes) {
if (node == nullptr || node->primitive == nullptr) {
MS_LOG(ERROR) << " node or node->primitive is nullptr";
return RET_ERROR;
}
for (int i = 0; i < static_cast<int>(node->outputIndex.size()); i++) {
auto tensor_id = node->outputIndex.at(i);
auto &tensor = graph->allTensors.at(tensor_id);
if (tensor->name.empty()) {
tensor->name = node->name + "/output-" + std::to_string(i);
}
}
auto type = node->primitive->value.type;
if (type == PrimitiveType_Conv2D || type == PrimitiveType_DeConv2D || type == PrimitiveType_DepthwiseConv2D ||
type == PrimitiveType_DeDepthwiseConv2D || type == PrimitiveType_FullConnection) {
auto input_size = node->inputIndex.size();
if (input_size > 1) {
auto weight_tensor_id = node->inputIndex.at(1);
auto &weight_tensor = graph->allTensors.at(weight_tensor_id);
if (weight_tensor->name.empty()) {
weight_tensor->name = node->name + "/weight";
}
if (input_size > 2) {
auto bias_tensor_id = node->inputIndex.at(2);
auto &bias_tensor = graph->allTensors.at(bias_tensor_id);
if (bias_tensor->name.empty()) {
bias_tensor->name = node->name + "/bias";
}
}
}
} else {
for (int i = 0; i < static_cast<int>(node->inputIndex.size()); i++) {
auto tensor_id = node->inputIndex.at(i);
auto &tensor = graph->allTensors.at(tensor_id);
if (tensor->name.empty()) {
tensor->name = node->name + "/input-" + std::to_string(i);
}
}
}
}
return RET_OK;
}
} // namespace mindspore::lite
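
To make the naming scheme concrete, here is a small usage sketch (not part of this commit): it hand-builds a MetaGraphT with a single Conv2D node whose tensors are unnamed, runs the pass, and notes the names it assigns. The construction details and the include path of the generated schema header are assumptions based on the object API used above.

#include <memory>
#include "schema/inner/model_generated.h"  // assumed location of the generated object API
#include "tools/converter/legacy_optimizer/graph/tensor_name_pass.h"

void RunTensorNamePassExample() {
  auto graph = std::make_unique<mindspore::schema::MetaGraphT>();
  // Three unnamed tensors: the graph input, the conv weight, and the conv output.
  for (int i = 0; i < 3; ++i) {
    graph->allTensors.emplace_back(std::make_unique<mindspore::schema::TensorT>());
  }
  graph->inputIndex = {0};

  auto conv = std::make_unique<mindspore::schema::CNodeT>();
  conv->name = "conv1";
  conv->primitive = std::make_unique<mindspore::schema::PrimitiveT>();
  conv->primitive->value.type = mindspore::schema::PrimitiveType_Conv2D;
  conv->inputIndex = {0, 1};
  conv->outputIndex = {2};
  graph->nodes.emplace_back(std::move(conv));

  mindspore::lite::TensorNamePass pass;
  pass.Run(graph.get());
  // Names assigned by the pass:
  //   allTensors[0]->name == "graph_input-0"   (graph input, always renamed)
  //   allTensors[1]->name == "conv1/weight"    (second input of a Conv2D-like node)
  //   allTensors[2]->name == "conv1/output-0"  (first output of the node)
}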

@@ -0,0 +1,35 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LITE_NAME_TENSOR_PASS_H
#define LITE_NAME_TENSOR_PASS_H

#include <memory>
#include "tools/converter/optimizer.h"
#include "tools/common/graph_util.h"

namespace mindspore {
namespace lite {
class TensorNamePass : public GraphPass {
 public:
  TensorNamePass() {}

  ~TensorNamePass() override = default;
  STATUS Run(schema::MetaGraphT *graph) override;
};
}  // namespace lite
}  // namespace mindspore
#endif  // LITE_NAME_TENSOR_PASS_H