From fa2e454af3617c10433ada517b992cfda4d9a369 Mon Sep 17 00:00:00 2001 From: wang_shaocong Date: Thu, 18 Feb 2021 15:22:12 +0800 Subject: [PATCH] [MSLITE] Fix bug of batchnorm_convert_scale_pass --- .../graph/batchnorm_convert_scale_pass.cc | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.cc index 1a7ead0d785..d052baf8b58 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.cc @@ -40,6 +40,7 @@ namespace { constexpr const float EPS = 1e-8; constexpr const float EPS_DEFAULT_FLOAT = 1e-8; constexpr const float POW_NUM = 0.5; +constexpr uint32_t kQuadrupleNum = 4; } // namespace STATUS BatchNormConvertScalePass::Run(MetaGraphT *graph) { @@ -52,6 +53,11 @@ STATUS BatchNormConvertScalePass::Run(MetaGraphT *graph) { continue; } + auto input_index = node->inputIndex.at(0); + if (graph->allTensors.at(input_index)->dims.empty()) { + MS_LOG(WARNING) << "The shape of input tensor is uncertain."; + return RET_OK; + } auto status = GenNewScaleTensor(graph, node); if (status != RET_OK) { MS_LOG(ERROR) << "GenNewScaleTensor failed: " << status; @@ -75,9 +81,13 @@ STATUS BatchNormConvertScalePass::ConvertBNToScale(MetaGraphT *graph, const std: return RET_ERROR; } // after fusion bn must NHWC - scaleParam->axis = -1; - bnNode->primitive->value.value = scaleParam.release(); auto input0 = bnNode->inputIndex.at(0); + if (graph->allTensors.at(input0)->dims.size() == kQuadrupleNum) { + scaleParam->axis = -1; + } else { + scaleParam->axis = 1; + } + bnNode->primitive->value.value = scaleParam.release(); bnNode->inputIndex.clear(); bnNode->inputIndex.push_back(input0); graph->allTensors.emplace_back(std::move(newScaleWeightTensor));