!5015 add mul add fusion pass

Merge pull request !5015 from zhengjun10/master
mindspore-ci-bot 2020-08-25 22:02:14 +08:00 committed by Gitee
commit 2b23de6161
5 changed files with 215 additions and 0 deletions
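For context, a minimal standalone sketch (illustrative code, not part of the diff) of the identity this pass exploits: an element-wise Mul by a constant weight followed by an Add of a constant bias computes exactly what a single Scale op computes, Scale(x, w, b)[c] = x[c] * w[c] + b[c], broadcast over the channel axis, so the converter can collapse the two nodes into one.

#include <cassert>
#include <cstddef>
#include <vector>

// fused form: one Scale op, y[c] = x[c] * w[c] + b[c]
std::vector<float> Scale(const std::vector<float> &x, const std::vector<float> &w,
                         const std::vector<float> &b) {
  std::vector<float> y(x.size());
  for (std::size_t c = 0; c < x.size(); ++c) y[c] = x[c] * w[c] + b[c];
  return y;
}

int main() {
  std::vector<float> x = {1.f, 2.f, 3.f};    // one NHWC pixel, C = 3
  std::vector<float> w = {0.5f, 2.f, -1.f};  // constant Mul weight
  std::vector<float> b = {1.f, 0.f, 4.f};    // constant Add bias
  // unfused form: a Mul node followed by an Add node
  std::vector<float> t(x.size()), y(x.size());
  for (std::size_t c = 0; c < x.size(); ++c) t[c] = x[c] * w[c];
  for (std::size_t c = 0; c < x.size(); ++c) y[c] = t[c] + b[c];
  assert(y == Scale(x, w, b));  // the fusion preserves the result
  return 0;
}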


@ -76,6 +76,9 @@ int ScaleCPUKernel::InitParameter() {
auto scale_tensor = in_tensors_.at(1);
auto scale_shape = scale_tensor->shape();
if (scale_param_->axis_ < 0) {
scale_param_->axis_ = scale_param_->axis_ + in_shape.size();
}
if (scale_shape.size() + scale_param_->axis_ > in_shape.size()) {
MS_LOG(ERROR) << "Scale tensor shape is incorrect.";
return RET_ERROR;
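For reference, a small standalone sketch (hypothetical values, not part of the kernel) of how the negative-axis normalization and shape check above behave for a typical NHWC input with a per-channel scale:

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> in_shape = {1, 224, 224, 3};  // NHWC input shape
  std::vector<int> scale_shape = {3};            // per-channel scale tensor
  int axis = -1;                                 // the value the new fusion pass sets
  if (axis < 0) {
    axis += static_cast<int>(in_shape.size());   // -1 -> 3, the channel dimension
  }
  // mirrors the check in InitParameter(): the scale tensor must fit into the
  // input shape starting at the resolved axis
  bool ok = scale_shape.size() + axis <= in_shape.size();
  std::printf("axis = %d, scale shape fits = %s\n", axis, ok ? "yes" : "no");
  return 0;
}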


@ -28,6 +28,7 @@
#include "tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/batchnorm_convert_scale_pass.h"
#include "tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.h"
#include "tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.h"
#include "tools/converter/legacy_optimizer/graph/weight_format_transform_pass.h"
#include "tools/converter/legacy_optimizer/graph/format_trans_pass.h"
@ -172,6 +173,17 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
}
}
{
Optimizer fusionOptimizer;
fusionOptimizer.AddPass(new (std::nothrow) MulAddFusionPass());
fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
status = fusionOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed";
return status;
}
}
// do quantization
if (fbQuantizer != nullptr) {
status = fbQuantizer->DoQuantize();
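As a toy model of the pipeline convention used above (not the converter's real Optimizer API): each pass returns a status, RET_NO_CHANGE counts as success, and any other non-OK status aborts the remaining passes.

#include <functional>
#include <vector>

enum Status { RET_OK = 0, RET_NO_CHANGE = 1, RET_ERROR = 2 };

int main() {
  // stand-ins for MulAddFusionPass and IsolatedNodeRemovePass
  std::vector<std::function<Status()>> passes = {
      [] { return RET_NO_CHANGE; },  // the pass found nothing to rewrite: not an error
      [] { return RET_OK; },
  };
  for (const auto &pass : passes) {
    Status status = pass();
    if (status != RET_OK && status != RET_NO_CHANGE) {
      return status;  // a hard failure stops the pipeline, as in Transform() above
    }
  }
  return RET_OK;
}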


@ -2,6 +2,7 @@ add_library(fusion_mid OBJECT
${CMAKE_CURRENT_SOURCE_DIR}/fusion_pattern.cc
${CMAKE_CURRENT_SOURCE_DIR}/fusion_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/matmul_biasadd_fusion_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/mul_add_fusion_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/quant_cast_fusion_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/batchnorm_fold_fusion_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/format_trans_fusion_pass.cc


@ -0,0 +1,147 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include <unordered_map>
#include <vector>
#include <utility>
#include <memory>
#include "tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.h"
#include "utils/log_adapter.h"
#include "securec/include/securec.h"
#include "tools/common/graph_util.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
#include "src/common/op_utils.h"
namespace mindspore {
namespace lite {
#define MUL_ADD_MATCH_PATH_LEN 2
#define ADD_OP_BIAS_INDEX 1
#define MUL_OP_BIAS_INDEX 1
#define MUL_OP_INPUT_NUM 2
#define ADD_OP_INPUT_NUM 2
STATUS MulAddFusionPass::Run(MetaGraphT *graph) { return FusionPass::Run(graph); }
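// pattern: a Mul node whose output feeds the left input of an Add node;
// whether the other Mul/Add inputs are constants is checked later in DoFusion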
STATUS MulAddFusionPass::DefinePattern() {
auto mulOp = std::make_shared<PatternOp>();
mulOp->id = MUL_NAME;
mulOp->types = {schema::PrimitiveType_Mul};
auto baOp = std::make_shared<PatternOp>();
baOp->id = ADD_NAME;
baOp->types = {schema::PrimitiveType_Add};
baOp->left = mulOp;
std::unique_ptr<FusionPattern> fusionPattern(new(std::nothrow) FusionPattern("MulAddFusion"));
if (fusionPattern == nullptr) {
MS_LOG(ERROR) << "new fusionPattern failed";
return RET_ERROR;
}
fusionPattern->AddPatternOp(mulOp);
fusionPattern->AddPatternOp(baOp);
fusionPattern->Finish();
this->patterns.emplace_back(fusionPattern.release());
return RET_OK;
}
STATUS MulAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternName,
std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
MS_ASSERT(graph != nullptr);
if (matchedPath.size() != MUL_ADD_MATCH_PATH_LEN) {
MS_LOG(ERROR) << "Mul-Add-Fusion should have two NodeIndex in matchedPair";
return RET_PARAM_INVALID;
}
auto mulPath = matchedPath[MUL_NAME];
auto addPath = matchedPath[ADD_NAME];
auto &mulNode = graph->nodes.at(mulPath->nodeIdx);
auto &addNode = graph->nodes.at(addPath->nodeIdx);
// cannot check shapes here because there is no shape inference in the converter
MS_ASSERT(mulNode != nullptr);
auto mulNodeInputIndex = mulNode->inputIndex;
MS_ASSERT(mulNodeInputIndex.size() == MUL_OP_INPUT_NUM);
MS_ASSERT(graph->allTensors.size() > mulNodeInputIndex.at(MUL_OP_BIAS_INDEX));
const auto &mulNodeBiasTensor = graph->allTensors.at(mulNodeInputIndex.at(MUL_OP_BIAS_INDEX));
MS_ASSERT(mulNodeBiasTensor != nullptr);
if (mulNodeBiasTensor->nodeType != schema::NodeType_ValueNode) {
// the second Mul input is not a constant tensor, don't fuse
return RET_OK;
}
// if the Add node's second input is not a constant tensor, don't fuse
auto addNodeInputIndex = addNode->inputIndex;
if (addNodeInputIndex.size() != ADD_OP_INPUT_NUM) {
MS_LOG(ERROR) << "add node input tensor number is invalid, node: " << addNode->name;
return RET_ERROR;
}
MS_ASSERT(graph->allTensors.size() > addNodeInputIndex.at(ADD_OP_BIAS_INDEX));
const auto &addNodeBiasTensor = graph->allTensors.at(addNodeInputIndex.at(ADD_OP_BIAS_INDEX));
MS_ASSERT(addNodeBiasTensor != nullptr);
if (addNodeBiasTensor->nodeType != schema::NodeType_ValueNode) {
// the second Add input is not a constant tensor, don't fuse
return RET_OK;
}
// convert mul and add to scale
auto status = AddNewScaleNode(graph, mulNode, addNode.get(), addNodeInputIndex.at(ADD_OP_BIAS_INDEX));
if (RET_OK != status) {
MS_LOG(ERROR) << "AddFullConnectionBiasTensor failed, %d"; // status);
return status;
}
return RET_OK;
}
STATUS MulAddFusionPass::AddNewScaleNode(MetaGraphT *graph, const std::unique_ptr<CNodeT> &mulNode,
CNodeT* addNode, uint32_t addBiasIndex) {
MS_ASSERT(graph != nullptr);
MS_ASSERT(mulNode != nullptr);
MS_ASSERT(addNode != nullptr);
// convert mulNode into a Scale node
mulNode->primitive->value.type = schema::PrimitiveType_Scale;
std::unique_ptr<ScaleT> scaleParam(new (std::nothrow) ScaleT());
if (scaleParam == nullptr) {
MS_LOG(ERROR) << "new scaleParam failed";
return RET_ERROR;
}
// NHWC: apply the scale along the last (channel) axis
scaleParam->axis = -1;
mulNode->primitive->value.value = scaleParam.release();
mulNode->inputIndex.push_back(addBiasIndex);
if (addNode->primitive->value.AsAdd()->activationType != ActivationType_NO_ACTIVATION) {
// the Add carries a fused activation: replace addNode with an Activation node instead of removing it
std::unique_ptr<ActivationT> activationParam(new (std::nothrow) ActivationT());
if (activationParam == nullptr) {
MS_LOG(ERROR) << "new activationParam failed";
return RET_ERROR;
}
activationParam->type = addNode->primitive->value.AsAdd()->activationType;
addNode->primitive->value.type = schema::PrimitiveType_Activation;
addNode->primitive->value.value = activationParam.release();
addNode->inputIndex.pop_back();  // the constant bias input now belongs to the Scale node
return RET_OK;
}
// the Add is fully absorbed into the Scale node, remove it from the graph
auto status = IsolateOneWayNode(graph, addNode);
if (status != RET_OK) {
MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << addNode->name << ", error: " << status;
return status;
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore
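A small standalone sketch (illustrative only) of the two outcomes of AddNewScaleNode above: without a fused activation the Add disappears into the Scale node; with one (e.g. RELU) the Add is rewritten into a separate Activation node applied after the Scale.

#include <algorithm>
#include <cassert>

float ScaleOp(float x, float w, float b) { return x * w + b; }
float ReluOp(float v) { return std::max(v, 0.0f); }

int main() {
  float x = -3.0f, w = 2.0f, b = 1.0f;
  // case 1: Add had no fused activation -> a single Scale node
  assert(ScaleOp(x, w, b) == x * w + b);
  // case 2: Add had a fused RELU -> a Scale node followed by an Activation node
  assert(ReluOp(ScaleOp(x, w, b)) == std::max(x * w + b, 0.0f));
  return 0;
}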


@ -0,0 +1,52 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_PREDICT_MUL_ADD_FUSION_PASS_H
#define MINDSPORE_PREDICT_MUL_ADD_FUSION_PASS_H
#include <string>
#include <unordered_map>
#include <memory>
#include <algorithm>
#include <utility>
#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h"
#include "tools/common/graph_util.h"
namespace mindspore {
namespace lite {
constexpr const char *MUL_NAME = "MUL";
constexpr const char *ADD_NAME = "ADD";
class MulAddFusionPass : public FusionPass {
public:
MulAddFusionPass() = default;
~MulAddFusionPass() = default;
STATUS DefinePattern() override;
STATUS DoFusion(MetaGraphT *graph, const std::string &patternName,
std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override;
STATUS Run(MetaGraphT *graph) override;
protected:
static STATUS AddNewScaleNode(MetaGraphT *graph, const std::unique_ptr<CNodeT> &mulNode,
CNodeT* addNode, uint32_t addBiasIndex);
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_PREDICT_MUL_ADD_FUSION_PASS_H