!12385 [MSLITE] Fix bug of batchnorm_convert_to_scale_pass"
From: @wang_shaocong Reviewed-by: @zhang_xue_tong,@zhanghaibo5 Signed-off-by: @zhanghaibo5
This commit is contained in:
commit
fb4330c878
|
@ -40,6 +40,7 @@ namespace {
|
||||||
constexpr const float EPS = 1e-8;
|
constexpr const float EPS = 1e-8;
|
||||||
constexpr const float EPS_DEFAULT_FLOAT = 1e-8;
|
constexpr const float EPS_DEFAULT_FLOAT = 1e-8;
|
||||||
constexpr const float POW_NUM = 0.5;
|
constexpr const float POW_NUM = 0.5;
|
||||||
|
constexpr uint32_t kQuadrupleNum = 4;
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
STATUS BatchNormConvertScalePass::Run(MetaGraphT *graph) {
|
STATUS BatchNormConvertScalePass::Run(MetaGraphT *graph) {
|
||||||
|
@ -52,6 +53,11 @@ STATUS BatchNormConvertScalePass::Run(MetaGraphT *graph) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto input_index = node->inputIndex.at(0);
|
||||||
|
if (graph->allTensors.at(input_index)->dims.empty()) {
|
||||||
|
MS_LOG(WARNING) << "The shape of input tensor is uncertain.";
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
auto status = GenNewScaleTensor(graph, node);
|
auto status = GenNewScaleTensor(graph, node);
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "GenNewScaleTensor failed: " << status;
|
MS_LOG(ERROR) << "GenNewScaleTensor failed: " << status;
|
||||||
|
@ -75,9 +81,13 @@ STATUS BatchNormConvertScalePass::ConvertBNToScale(MetaGraphT *graph, const std:
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
// after fusion bn must NHWC
|
// after fusion bn must NHWC
|
||||||
scaleParam->axis = -1;
|
|
||||||
bnNode->primitive->value.value = scaleParam.release();
|
|
||||||
auto input0 = bnNode->inputIndex.at(0);
|
auto input0 = bnNode->inputIndex.at(0);
|
||||||
|
if (graph->allTensors.at(input0)->dims.size() == kQuadrupleNum) {
|
||||||
|
scaleParam->axis = -1;
|
||||||
|
} else {
|
||||||
|
scaleParam->axis = 1;
|
||||||
|
}
|
||||||
|
bnNode->primitive->value.value = scaleParam.release();
|
||||||
bnNode->inputIndex.clear();
|
bnNode->inputIndex.clear();
|
||||||
bnNode->inputIndex.push_back(input0);
|
bnNode->inputIndex.push_back(input0);
|
||||||
graph->allTensors.emplace_back(std::move(newScaleWeightTensor));
|
graph->allTensors.emplace_back(std::move(newScaleWeightTensor));
|
||||||
|
|
Loading…
Reference in New Issue