converter format pass optimize
This commit is contained in:
parent
9ab4f1237b
commit
1a12ad2615
|
@ -16,8 +16,8 @@ tracking
|
|||
mtk_isface
|
||||
mtk_landmark
|
||||
mtk_pose_tuku
|
||||
mtk_face_recognition_v1
|
||||
mtk_2012_ATLANTA_10class_20190614_v41
|
||||
# mtk_face_recognition_v1
|
||||
# mtk_2012_ATLANTA_10class_20190614_v41
|
||||
mtk_detect-deeper-halfdeeper-mbv1-lastearlySSD-shortcut-400-400_nopostprocess_simplified
|
||||
detect-deeper-halfdeeper-mbv1-shortcut-400-400_nopostprocess_simplified
|
||||
hiai_face_detect_rfb
|
||||
|
@ -37,7 +37,7 @@ ml_hardware_pose
|
|||
ml_bank_recog
|
||||
2012_ATLANTA_10class_20190131_v4.0
|
||||
mnet
|
||||
recognition
|
||||
# recognition
|
||||
ml_face_landmark
|
||||
model_hebing_3branch
|
||||
hiai_cv_focusShootOCRModel_07
|
||||
|
@ -48,9 +48,9 @@ hiai_cv_focusShootOCRModel_04
|
|||
hiai_cv_focusShootOCRModel_06
|
||||
hiai_cpu_face_hat
|
||||
hiai_video_seg
|
||||
hiai_semantic_seg
|
||||
# hiai_semantic_seg
|
||||
hiai_human_seg
|
||||
hiai_face_recognition_1
|
||||
# hiai_face_recognition_1
|
||||
hiai_cpu_face_detect
|
||||
hiai_cpu_face_attr
|
||||
hiai_face_attr1
|
||||
|
|
|
@ -27,8 +27,8 @@
|
|||
#include "tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h"
|
||||
#include "tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h"
|
||||
#include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h"
|
||||
#include "tools/converter/legacy_optimizer/fusion/batchnorm_convert_scale_pass.h"
|
||||
#include "tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.h"
|
||||
#include "tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.h"
|
||||
#include "tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.h"
|
||||
#include "tools/converter/legacy_optimizer/graph/weight_format_transform_pass.h"
|
||||
#include "tools/converter/legacy_optimizer/graph/format_trans_pass.h"
|
||||
|
|
|
@ -7,7 +7,6 @@ add_library(fusion_mid OBJECT
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/batchnorm_fold_fusion_pass.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/format_trans_fusion_pass.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/format_trans_transpose_fusion_pass.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/batchnorm_convert_scale_pass.cc
|
||||
)
|
||||
|
||||
target_link_libraries(fusion_mid securec)
|
||||
|
|
|
@ -1,100 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_PREDICT_BATCHNORM_CONVERT_SCALE_PASS_H
|
||||
#define MINDSPORE_PREDICT_BATCHNORM_CONVERT_SCALE_PASS_H
|
||||
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
struct BNWeightTensors {
|
||||
TensorT *meanTensor = nullptr;
|
||||
TensorT *varianceTensor = nullptr;
|
||||
TensorT *scaleTensor = nullptr;
|
||||
TensorT *biasTensor = nullptr;
|
||||
};
|
||||
class BatchNormConvertScalePass : public FusionPass {
|
||||
public:
|
||||
BatchNormConvertScalePass() = default;
|
||||
|
||||
~BatchNormConvertScalePass() = default;
|
||||
|
||||
STATUS DefinePattern() override;
|
||||
|
||||
STATUS DoFusion(MetaGraphT *graph, const std::string &patternName,
|
||||
std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override;
|
||||
|
||||
STATUS Run(MetaGraphT *graph) override;
|
||||
|
||||
protected:
|
||||
STATUS GetTransParam(MetaGraphT *graph, const std::shared_ptr<Path> &bnPath);
|
||||
|
||||
// Get and check BNNode weight tensor
|
||||
STATUS GetBnWeightTensors(MetaGraphT *graph, const std::shared_ptr<Path> &bnPath, BNWeightTensors* bnWeightTensors);
|
||||
|
||||
STATUS GetBnEpsilon(MetaGraphT *graph);
|
||||
|
||||
STATUS FindNodes(MetaGraphT *graph, const std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath);
|
||||
|
||||
STATUS GenNewScaleTensor(MetaGraphT *graph, const std::shared_ptr<Path> &bnPath);
|
||||
|
||||
STATUS ConvertBNToScale(MetaGraphT *graph, const std::shared_ptr<Path> &bnPath);
|
||||
|
||||
CNodeT *inputNode = nullptr;
|
||||
CNodeT *bnNode = nullptr;
|
||||
|
||||
std::string inputOpName = "Input";
|
||||
std::string bnOpName = "BatchNorm";
|
||||
std::string bnPatternName = "BnToScaleFusion";
|
||||
uint32_t bnChannel = 0;
|
||||
float eps = 0;
|
||||
TensorT *bnMeanTensor = nullptr;
|
||||
float *transScale = nullptr;
|
||||
float *transBias = nullptr;
|
||||
std::unique_ptr<TensorT> newScaleWeightTensor = nullptr;
|
||||
std::unique_ptr<TensorT> newScaleBiasTensor = nullptr;
|
||||
|
||||
OpDefCopyer ScaleOpCopyer = [](CNodeT *inOpDef) -> std::unique_ptr<CNodeT> {
|
||||
std::unique_ptr<CNodeT> newOpDef(new(std::nothrow) CNodeT);
|
||||
if (newOpDef == nullptr) {
|
||||
MS_LOG(ERROR) << "new OpDefT failed";
|
||||
return nullptr;
|
||||
}
|
||||
newOpDef->name = inOpDef->name;
|
||||
newOpDef->quantType = inOpDef->quantType;
|
||||
newOpDef->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
newOpDef->primitive->value.type = schema::PrimitiveType_Scale;
|
||||
auto scaleParam = new(std::nothrow) ScaleT;
|
||||
if (scaleParam == nullptr) {
|
||||
MS_LOG(ERROR) << "new scaleParam failed";
|
||||
return nullptr;
|
||||
}
|
||||
auto inParam = inOpDef->primitive->value.AsScale();
|
||||
MS_ASSERT(inParam != nullptr);
|
||||
scaleParam->axis = inParam->axis;
|
||||
newOpDef->primitive->value.value = scaleParam;
|
||||
return std::move(newOpDef);
|
||||
};
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_PREDICT_BATCHNORM_CONVERT_SCALE_PASS_H
|
|
@ -8,4 +8,5 @@ add_library(graph_pass_mid OBJECT
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/weight_format_transform_pass.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/topological_sort_pass.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/unused_node_remove_pass.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/batchnorm_convert_scale_pass.cc
|
||||
)
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/legacy_optimizer/fusion/batchnorm_convert_scale_pass.h"
|
||||
#include "tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.h"
|
||||
#include <cfloat>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
@ -44,123 +44,56 @@ constexpr const float EPS_DEFAULT_FLOAT = 1e-8;
|
|||
constexpr const float POW_NUM = 0.5;
|
||||
constexpr const int32_t NCHW_DIM_C = 1;
|
||||
}
|
||||
STATUS BatchNormConvertScalePass::Run(MetaGraphT *graph) { return FusionPass::Run(graph); }
|
||||
|
||||
STATUS BatchNormConvertScalePass::DefinePattern() {
|
||||
// with preNode
|
||||
{
|
||||
auto inputOp = std::make_shared<PatternOp>();
|
||||
inputOp->id = inputOpName;
|
||||
inputOp->types = {schema::PrimitiveType_NONE};
|
||||
inputOp->isPlaceHold = true;
|
||||
|
||||
auto bnOp = std::make_shared<PatternOp>();
|
||||
bnOp->id = bnOpName;
|
||||
bnOp->types = {schema::PrimitiveType_FusedBatchNorm, schema::PrimitiveType_BatchNorm};
|
||||
bnOp->left = inputOp;
|
||||
|
||||
std::unique_ptr<FusionPattern> fusionPattern(new(std::nothrow) FusionPattern(bnPatternName));
|
||||
if (fusionPattern == nullptr) {
|
||||
MS_LOG(ERROR) << "new fusionPattern failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
fusionPattern->AddPatternOp(inputOp);
|
||||
fusionPattern->AddPatternOp(bnOp);
|
||||
fusionPattern->Finish();
|
||||
|
||||
this->patterns.emplace_back(fusionPattern.release());
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
STATUS BatchNormConvertScalePass::DoFusion(MetaGraphT *graph, const std::string &patternName,
|
||||
std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
|
||||
STATUS BatchNormConvertScalePass::Run(MetaGraphT *graph) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
if (patternName != bnPatternName) {
|
||||
MS_LOG(ERROR) << "BatchNormConvertScale-Fusion match failed";
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
auto status = FindNodes(graph, matchedPath);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "FindNodes failed: " << status;
|
||||
return status;
|
||||
}
|
||||
auto type = bnNode->primitive->value.type;
|
||||
if (type != schema::PrimitiveType_FusedBatchNorm && type != schema::PrimitiveType_BatchNorm) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto bnPath = matchedPath.at(bnOpName);
|
||||
status = GenNewScaleTensor(graph, bnPath);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "GenNewScaleTensor failed: " << status;
|
||||
delete[] transScale;
|
||||
delete[] transBias;
|
||||
transScale = nullptr;
|
||||
transBias = nullptr;
|
||||
return status;
|
||||
}
|
||||
|
||||
status = ConvertBNToScale(graph, bnPath);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "GenNewScaleTensor failed: " << status;
|
||||
delete[] transScale;
|
||||
delete[] transBias;
|
||||
transScale = nullptr;
|
||||
transBias = nullptr;
|
||||
return status;
|
||||
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
|
||||
auto &node = *iter;
|
||||
auto type = node->primitive->value.type;
|
||||
if (type != schema::PrimitiveType_FusedBatchNorm && type != schema::PrimitiveType_BatchNorm) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto status = GenNewScaleTensor(graph, node);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "GenNewScaleTensor failed: " << status;
|
||||
return status;
|
||||
}
|
||||
status = ConvertBNToScale(graph, node);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "GenNewScaleTensor failed: " << status;
|
||||
return status;
|
||||
}
|
||||
}
|
||||
delete[] transScale;
|
||||
delete[] transBias;
|
||||
transScale = nullptr;
|
||||
transBias = nullptr;
|
||||
return RET_OK;
|
||||
}
|
||||
STATUS BatchNormConvertScalePass::ConvertBNToScale(MetaGraphT *graph, const std::shared_ptr<Path> &bnPath) {
|
||||
auto scaleNode = std::unique_ptr<CNodeT>(new(std::nothrow) CNodeT);
|
||||
if (scaleNode == nullptr) {
|
||||
MS_LOG(ERROR) << "new TransNode failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
scaleNode->name = bnNode->name;
|
||||
scaleNode->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (scaleNode->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
scaleNode->primitive->value.type = schema::PrimitiveType_Scale;
|
||||
STATUS BatchNormConvertScalePass::ConvertBNToScale(MetaGraphT *graph, const std::unique_ptr<CNodeT> &bnNode) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
MS_ASSERT(bnNode != nullptr);
|
||||
bnNode->primitive->value.type = schema::PrimitiveType_Scale;
|
||||
std::unique_ptr<ScaleT> scaleParam(new ScaleT());
|
||||
if (scaleParam == nullptr) {
|
||||
MS_LOG(ERROR) << "new transposeParam failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
scaleParam->axis = NCHW_DIM_C;
|
||||
scaleNode->primitive->value.value = scaleParam.release();
|
||||
auto scaleIter = graph->nodes.begin() + bnPath->nodeIdx;
|
||||
STATUS errorCode = RET_OK;
|
||||
scaleIter =
|
||||
InsertNode(graph, scaleIter, kBefore, 0, std::move(scaleNode), &errorCode, ScaleOpCopyer);
|
||||
if (errorCode != RET_OK) {
|
||||
MS_LOG(ERROR) << "InsertNode failed: %d"; // errorCode);
|
||||
return errorCode;
|
||||
}
|
||||
auto &newScaleNode = *(scaleIter - 1);
|
||||
bnNode->primitive->value.value = scaleParam.release();
|
||||
auto input0 = bnNode->inputIndex.at(0);
|
||||
bnNode->inputIndex.clear();
|
||||
bnNode->inputIndex.push_back(input0);
|
||||
graph->allTensors.emplace_back(std::move(newScaleWeightTensor));
|
||||
auto weightTensorIdx = graph->allTensors.size() - 1;
|
||||
graph->allTensors.emplace_back(std::move(newScaleBiasTensor));
|
||||
auto biasTensorIdx = graph->allTensors.size() - 1;
|
||||
newScaleNode->inputIndex.push_back(weightTensorIdx);
|
||||
newScaleNode->inputIndex.push_back(biasTensorIdx);
|
||||
// delete bn node
|
||||
auto status = IsolateOneWayNode(graph, bnPath->nodeIdx + 1, true);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "IsolateOneWayNode " << bnNode->name.c_str() << " failed, error: " << status;
|
||||
return status;
|
||||
}
|
||||
bnNode->inputIndex.push_back(weightTensorIdx);
|
||||
bnNode->inputIndex.push_back(biasTensorIdx);
|
||||
return RET_OK;
|
||||
}
|
||||
STATUS BatchNormConvertScalePass::GenNewScaleTensor(MetaGraphT *graph, const std::shared_ptr<Path> &bnPath) {
|
||||
STATUS BatchNormConvertScalePass::GenNewScaleTensor(MetaGraphT *graph, const std::unique_ptr<CNodeT> &bnNode) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
GetTransParam(graph, bnPath);
|
||||
MS_ASSERT(bnNode != nullptr);
|
||||
GetTransParam(graph, bnNode);
|
||||
newScaleWeightTensor = std::unique_ptr<TensorT>(new(std::nothrow) TensorT);
|
||||
if (newScaleWeightTensor == nullptr) {
|
||||
MS_LOG(ERROR) << "new weightTensor failed";
|
||||
|
@ -175,8 +108,11 @@ STATUS BatchNormConvertScalePass::GenNewScaleTensor(MetaGraphT *graph, const std
|
|||
auto ret = memcpy_s(newScaleWeightTensor->data.data(), weightShapeSize * sizeof(float), transScale,
|
||||
weightShapeSize * sizeof(float));
|
||||
if (ret != RET_OK) {
|
||||
delete transScale;
|
||||
MS_LOG(ERROR) << "memcpy error: " << ret;
|
||||
delete[] transScale;
|
||||
delete[] transBias;
|
||||
transScale = nullptr;
|
||||
transBias = nullptr;
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
|
@ -195,39 +131,25 @@ STATUS BatchNormConvertScalePass::GenNewScaleTensor(MetaGraphT *graph, const std
|
|||
ret = memcpy_s(newScaleBiasTensor->data.data(), weightShapeSize * sizeof(float), transBias,
|
||||
weightShapeSize * sizeof(float));
|
||||
if (ret != RET_OK) {
|
||||
delete transBias;
|
||||
MS_LOG(ERROR) << "memcpy error: " << ret;
|
||||
delete[] transScale;
|
||||
delete[] transBias;
|
||||
transScale = nullptr;
|
||||
transBias = nullptr;
|
||||
return RET_ERROR;
|
||||
}
|
||||
delete[] transScale;
|
||||
delete[] transBias;
|
||||
transScale = nullptr;
|
||||
transBias = nullptr;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS BatchNormConvertScalePass::FindNodes(MetaGraphT *graph,
|
||||
const std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
|
||||
STATUS BatchNormConvertScalePass::GetTransParam(MetaGraphT *graph, const std::unique_ptr<CNodeT> &bnNode) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
auto inputPath = matchedPath.at(inputOpName);
|
||||
auto bnPath = matchedPath.at(bnOpName);
|
||||
MS_ASSERT(inputPath != nullptr);
|
||||
MS_ASSERT(bnPath != nullptr);
|
||||
if (inputPath->subGraphIdx != bnPath->subGraphIdx) {
|
||||
MS_LOG(ERROR) << "matched nodes should from same subGraph";
|
||||
return RET_ERROR;
|
||||
}
|
||||
MS_ASSERT(graph->nodes.size() > inputPath->nodeIdx);
|
||||
MS_ASSERT(graph->nodes.size() > bnPath->nodeIdx);
|
||||
inputNode = graph->nodes.at(inputPath->nodeIdx).get();
|
||||
bnNode = graph->nodes.at(bnPath->nodeIdx).get();
|
||||
MS_ASSERT(inputNode != nullptr);
|
||||
MS_ASSERT(bnNode != nullptr);
|
||||
return RET_OK;
|
||||
}
|
||||
STATUS BatchNormConvertScalePass::GetTransParam(MetaGraphT *graph, const std::shared_ptr<Path> &bnPath) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
MS_ASSERT(bnPath != nullptr);
|
||||
|
||||
BNWeightTensors bnWeightTensors;
|
||||
|
||||
auto status = GetBnWeightTensors(graph, bnPath, &bnWeightTensors);
|
||||
auto status = GetBnWeightTensors(graph, &bnWeightTensors, bnNode);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "GetBnWeightTensors error";
|
||||
return status;
|
||||
|
@ -241,7 +163,7 @@ STATUS BatchNormConvertScalePass::GetTransParam(MetaGraphT *graph, const std::sh
|
|||
auto *varianceData = reinterpret_cast<float *>(varianceTensor->data.data());
|
||||
|
||||
eps = EPS_DEFAULT_FLOAT;
|
||||
status = GetBnEpsilon(graph);
|
||||
status = GetBnEpsilon(bnNode);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "GetBnEpsilon failed";
|
||||
return status;
|
||||
|
@ -298,12 +220,11 @@ STATUS BatchNormConvertScalePass::GetTransParam(MetaGraphT *graph, const std::sh
|
|||
// bias --1
|
||||
// estimated_mean --2
|
||||
// estimated_variance --3
|
||||
STATUS BatchNormConvertScalePass::GetBnWeightTensors(MetaGraphT *graph, const std::shared_ptr<Path> &bnPath,
|
||||
BNWeightTensors* bnWeightTensors) {
|
||||
if (graph == nullptr || bnPath == nullptr) {
|
||||
MS_LOG(ERROR) << "null pointer dereferencing.";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
STATUS BatchNormConvertScalePass::GetBnWeightTensors(MetaGraphT *graph, BNWeightTensors *bnWeightTensors,
|
||||
const std::unique_ptr<CNodeT> &bnNode) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
MS_ASSERT(bnNode != nullptr);
|
||||
MS_ASSERT(bnWeightTensors != nullptr);
|
||||
MS_ASSERT(graph->allTensors.size() > bnNode->inputIndex.at(1));
|
||||
auto bnWeightTensorIdxes = bnNode->inputIndex;
|
||||
bnWeightTensorIdxes.erase(bnWeightTensorIdxes.begin());
|
||||
|
@ -357,15 +278,9 @@ STATUS BatchNormConvertScalePass::GetBnWeightTensors(MetaGraphT *graph, const st
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS BatchNormConvertScalePass::GetBnEpsilon(MetaGraphT *graph) {
|
||||
if (graph == nullptr) {
|
||||
MS_LOG(ERROR) << "null pointer dereferencing.";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
if (bnNode == nullptr) {
|
||||
MS_LOG(ERROR) << "null pointer dereferencing.";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
STATUS BatchNormConvertScalePass::GetBnEpsilon(const std::unique_ptr<CNodeT> &bnNode) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
MS_ASSERT(bnNode != nullptr);
|
||||
if (bnNode->primitive->value.type == schema::PrimitiveType_FusedBatchNorm) {
|
||||
eps = bnNode->primitive->value.AsFusedBatchNorm()->epsilon;
|
||||
} else if (bnNode->primitive->value.type == schema::PrimitiveType_BatchNorm) {
|
|
@ -0,0 +1,66 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_PREDICT_BATCHNORM_CONVERT_SCALE_PASS_H
|
||||
#define MINDSPORE_PREDICT_BATCHNORM_CONVERT_SCALE_PASS_H
|
||||
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "tools/converter/optimizer.h"
|
||||
|
||||
using mindspore::schema::TensorT;
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
struct BNWeightTensors {
|
||||
schema::TensorT *meanTensor = nullptr;
|
||||
TensorT *varianceTensor = nullptr;
|
||||
TensorT *scaleTensor = nullptr;
|
||||
TensorT *biasTensor = nullptr;
|
||||
};
|
||||
class BatchNormConvertScalePass : public GraphPass {
|
||||
public:
|
||||
BatchNormConvertScalePass() = default;
|
||||
|
||||
~BatchNormConvertScalePass() = default;
|
||||
|
||||
STATUS Run(MetaGraphT *graph) override;
|
||||
|
||||
protected:
|
||||
STATUS GetTransParam(MetaGraphT *graph, const std::unique_ptr<CNodeT> &bnNode);
|
||||
|
||||
// Get and check BNNode weight tensor
|
||||
STATUS GetBnWeightTensors(MetaGraphT *graph, BNWeightTensors *bnWeightTensors, const std::unique_ptr<CNodeT> &bnNode);
|
||||
|
||||
STATUS GetBnEpsilon(const std::unique_ptr<CNodeT> &bnNode);
|
||||
|
||||
STATUS GenNewScaleTensor(MetaGraphT *graph, const std::unique_ptr<CNodeT> &bnNode);
|
||||
|
||||
STATUS ConvertBNToScale(MetaGraphT *graph, const std::unique_ptr<CNodeT> &bnNode);
|
||||
|
||||
uint32_t bnChannel = 0;
|
||||
float eps = 0;
|
||||
TensorT *bnMeanTensor = nullptr;
|
||||
float *transScale = nullptr;
|
||||
float *transBias = nullptr;
|
||||
std::unique_ptr<TensorT> newScaleWeightTensor = nullptr;
|
||||
std::unique_ptr<TensorT> newScaleBiasTensor = nullptr;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_PREDICT_BATCHNORM_CONVERT_SCALE_PASS_H
|
|
@ -121,7 +121,8 @@ STATUS EltwiseFormatTransPass::Run(schema::MetaGraphT *graph) {
|
|||
MS_ASSERT(graph != nullptr);
|
||||
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
|
||||
auto &node = *iter;
|
||||
if (node->primitive->value.type != PrimitiveType_Eltwise) {
|
||||
auto type = node->primitive->value.type;
|
||||
if (type != PrimitiveType_Eltwise && type != PrimitiveType_Activation) {
|
||||
continue;
|
||||
}
|
||||
auto node_name = node->name;
|
||||
|
|
|
@ -295,6 +295,9 @@ ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, in
|
|||
MS_ASSERT(param_value != nullptr);
|
||||
param_value->set_tensor_addr(bias_data);
|
||||
param_value->set_tensor_size(kernel_num * sizeof(float) / sizeof(uint8_t));
|
||||
param_value->set_format(weight_tensor->format());
|
||||
param_value->set_tensor_type(weight_tensor->tensor_type());
|
||||
param_value->set_tensor_shape(shape);
|
||||
bias_parameter->set_default_param(param_value);
|
||||
return bias_parameter;
|
||||
}
|
||||
|
|
|
@ -86,6 +86,7 @@ const ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *ten
|
|||
MS_ASSERT(param_value != nullptr);
|
||||
param_value->set_tensor_shape(shape);
|
||||
param_value->set_tensor_type(type_id);
|
||||
param_value->set_format(tensor->GetFormat());
|
||||
if (tensor->Data() != nullptr) {
|
||||
auto size = tensor->ElementsNum();
|
||||
auto tensor_data = new (std::nothrow) float[size];
|
||||
|
|
|
@ -51,13 +51,13 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c
|
|||
auto act_primitivec = utils::cast<std::shared_ptr<mindspore::lite::Activation>>(primitivec);
|
||||
MS_ASSERT(act_primitivec != nullptr);
|
||||
if (act_primitivec->GetType() != activation_type) {
|
||||
return node;
|
||||
return nullptr;
|
||||
}
|
||||
AnfNodePtr pre_node = act_node->input(1);
|
||||
CheckIfAnfNodeIsNull(pre_node);
|
||||
if (pre_node != nullptr && pre_node->isa<CNode>()) {
|
||||
if (IsMultiOutputTensors(func_graph, pre_node)) {
|
||||
return node;
|
||||
return nullptr;
|
||||
}
|
||||
auto conv_node = pre_node->cast<CNodePtr>();
|
||||
auto node_type = GetCNodeType(conv_node);
|
||||
|
@ -80,9 +80,9 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c
|
|||
return pre_node;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "conv activation pass match only conv2d or depthwise_conv2d ";
|
||||
MS_LOG(ERROR) << "conv activation pass match only conv2d or depthwise_conv2d ";
|
||||
}
|
||||
}
|
||||
return node;
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace mindspore::opt
|
||||
|
|
|
@ -179,7 +179,8 @@ const AnfNodePtr ConvBiasaddFusion::Process(const FuncGraphPtr &func_graph, cons
|
|||
MS_ASSERT(primc != nullptr);
|
||||
primc->SetHasBias(true);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unsupported opType, " << type;
|
||||
MS_LOG(ERROR) << "Unsupported opType, " << type;
|
||||
return nullptr;
|
||||
}
|
||||
return conv_node;
|
||||
}
|
||||
|
|
|
@ -85,12 +85,13 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co
|
|||
auto trans_scale = new (std::nothrow) float[kernel_nums];
|
||||
if (trans_scale == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor_data is nullptr";
|
||||
delete[] trans_scale;
|
||||
return nullptr;
|
||||
}
|
||||
auto trans_bias = new (std::nothrow) float[kernel_nums];
|
||||
if (trans_bias == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor_data is nullptr";
|
||||
delete trans_scale;
|
||||
delete[] trans_bias;
|
||||
return nullptr;
|
||||
}
|
||||
GenTransParam(transform_node, kernel_nums, trans_scale, trans_bias);
|
||||
|
@ -111,7 +112,8 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co
|
|||
MS_ASSERT(primc != nullptr);
|
||||
primc->SetHasBias(true);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unsupported opType, " << type;
|
||||
MS_LOG(ERROR) << "Unsupported opType, " << type;
|
||||
return nullptr;
|
||||
}
|
||||
pre_node->set_abstract(abstr);
|
||||
const auto &prim = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(transform_node->input(0));
|
||||
|
@ -187,6 +189,7 @@ const void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph,
|
|||
bias_data = new (std::nothrow) float[kernel_num];
|
||||
if (bias_data == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor_data is nullptr";
|
||||
delete[] bias_data;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue