!28894 add fc fusion with add

Merge pull request !28894 from wangyanling/fcaddfusion
i-robot 2022-01-18 01:17:59 +00:00 committed by Gitee
commit 72d0d70a6b
7 changed files with 261 additions and 21 deletions

View File

@@ -59,6 +59,7 @@
#include "tools/optimizer/fusion/scale_activation_fusion.h"
#include "tools/optimizer/fusion/scale_scale_fusion.h"
#include "tools/optimizer/fusion/fullconnected_fusion.h"
#include "tools/optimizer/fusion/fullconnected_add_fusion.h"
#include "tools/optimizer/fusion/add_concat_activation_fusion.h"
#include "tools/optimizer/fusion/matmul_activation_fusion.h"
#include "tools/optimizer/fusion/activation_fusion.h"
@@ -217,6 +218,7 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::
  fusion_pm->AddPass(std::make_shared<opt::ScaleActivationFusion>());
  fusion_pm->AddPass(std::make_shared<opt::ScaleScaleFusion>());
  fusion_pm->AddPass(std::make_shared<opt::FullConnectedFusion>());
  fusion_pm->AddPass(std::make_shared<opt::FullconnectedAddFusion>());
  fusion_pm->AddPass(std::make_shared<opt::TensorDotFusion>());
  fusion_pm->AddPass(std::make_shared<opt::MatMulActivationFusion>());
  optimizer->AddPassManager(fusion_pm);
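
For orientation, the rewrite performed by the newly registered pass can be sketched as follows (schematic node notation, not converter code; the implementation appears later in this diff):

// before:  out = Add(FullConnection(x, weight[, bias]), add_const)
// after:   out = FullConnection(x, weight, new_bias)
//
// If the FullConnection has no bias, add_const simply becomes its third input;
// if it already has one, the two 1-D constants are summed offline by
// CalNewCnodeBias. An activation attribute on the Add is migrated to the
// fused FullConnection.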

View File

@@ -1126,5 +1126,24 @@ int DetermineCertainVarInputHasInferred(const CNodePtr &cnode, size_t index, bool *infer_succ) {
  *infer_succ = infer_infos[item_index];
  return RET_OK;
}

bool CheckAndGetCnodeIndex(const CNodePtr &cnode, size_t *index, const PrimitivePtr &primitive_type) {
  MS_CHECK_TRUE_RET(cnode != nullptr, false);
  MS_CHECK_TRUE_RET(index != nullptr, false);
  if (cnode->size() != kInputSizeThree) {
    return false;
  }
  size_t dst_index = 0;
  for (size_t i = 1; i < cnode->size(); ++i) {
    if (CheckPrimitiveType(cnode->input(i), primitive_type)) {
      dst_index = i;
      break;
    }
  }
  if (dst_index == 0) {
    return false;
  }
  *index = dst_index;
  return true;
}
} // namespace opt
} // namespace mindspore
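
The helper above generalizes the former CheckAndGetMatMulIndex from matmul_add_fusion.cc (removed at the bottom of this diff). A representative call site, mirroring FullconnectedAddFusion::Process below:

  size_t index = 0;
  if (!CheckAndGetCnodeIndex(add_cnode, &index, prim::kPrimFullConnection)) {
    return nullptr;  // neither input of the two-input Add comes from a FullConnection
  }
  auto fc_cnode = add_cnode->input(index)->cast<CNodePtr>();  // index is 1 or 2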

View File

@@ -141,6 +141,8 @@ std::pair<CNodePtr, int> GetRealCertainVarInput(const CNodePtr &cnode, size_t in
int DetermineCertainVarInputHasInferred(const CNodePtr &cnode, size_t index, bool *infer_succ);
bool CheckAndGetCnodeIndex(const CNodePtr &cnode, size_t *index, const PrimitivePtr &primitive_type);
template <const PrimitivePtr *prim = nullptr>
inline bool IsSpecifiedNode(const BaseRef &n) {
  if (utils::isa<AnfNodePtr>(n)) {

View File

@@ -0,0 +1,197 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/fusion/fullconnected_add_fusion.h"
#include <vector>
#include <memory>
#include "ops/fusion/add_fusion.h"
#include "ops/fusion/full_connection.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "nnacl/op_base.h"
namespace mindspore {
namespace opt {
namespace {
bool IsPrimitiveProper(const CNodePtr &add_cnode, const CNodePtr &fc_cnode, int index) {
  auto add_primc = GetValueNode<PrimitiveCPtr>(add_cnode->input(0));
  MS_CHECK_TRUE_RET(add_primc != nullptr, false);
  if (IsQuantParameterNode(add_primc)) {
    MS_LOG(INFO) << add_cnode->fullname_with_scope() << " is quant node";
    return false;
  }
  auto add_param_node = add_cnode->input(kInputSizeThree - index);
  if (!utils::isa<ValueNode>(add_param_node) &&
      (!utils::isa<Parameter>(add_param_node) || !add_param_node->cast<ParameterPtr>()->default_param())) {
    return false;
  }
  auto abstract = add_param_node->abstract();
  MS_CHECK_TRUE_RET(abstract != nullptr, false);
  std::vector<int64_t> bias_shape;
  if (FetchShapeFromAbstract(abstract, &bias_shape) != lite::RET_OK) {
    MS_LOG(ERROR) << "Fetch shape from abstract failed.";
    return false;
  }
  if (bias_shape.size() > DIMENSION_1D) {
    MS_LOG(INFO) << "only support bias with shape size of 1.";
    return false;
  }
  if (fc_cnode->size() > kInputSizeThree) {
    auto fc_bias_node = fc_cnode->input(kInputIndexThree);
    if (!utils::isa<ValueNode>(fc_bias_node) &&
        (!utils::isa<Parameter>(fc_bias_node) || !fc_bias_node->cast<ParameterPtr>()->default_param())) {
      MS_LOG(INFO) << fc_cnode->fullname_with_scope() << "'s bias is not parameter";
      return false;
    }
  }
  auto fc_primc = GetValueNode<std::shared_ptr<ops::FullConnection>>(fc_cnode->input(0));
  MS_CHECK_TRUE_RET(fc_primc != nullptr, false);
  if (fc_primc->GetAttr(ops::kActivationType) != nullptr &&
      fc_primc->get_activation_type() != ActivationType::NO_ACTIVATION) {
    MS_LOG(INFO) << fc_cnode->fullname_with_scope() << " has activation attr";
    return false;
  }
  if (IsQuantParameterNode(fc_primc)) {
    MS_LOG(INFO) << fc_cnode->fullname_with_scope() << " is quant node";
    return false;
  }
  return true;
}

int CalNewCnodeBias(const AnfNodePtr &add_weight_node, const CNodePtr &fc_cnode) {
  MS_CHECK_TRUE_RET(add_weight_node != nullptr, RET_ERROR);
  MS_CHECK_TRUE_RET(fc_cnode != nullptr, RET_ERROR);
  auto fc_bias_node = fc_cnode->input(kInputIndexThree);
  MS_CHECK_TRUE_RET(fc_bias_node != nullptr, RET_ERROR);
  std::shared_ptr<tensor::Tensor> fc_bias_tensor = GetTensorInfo(fc_bias_node);
  MS_CHECK_TRUE_RET(fc_bias_tensor != nullptr, RET_ERROR);
  if (fc_bias_tensor->data_type() != kNumberTypeFloat32) {
    MS_LOG(INFO) << "only support float32 data type";
    return RET_ERROR;
  }
  std::vector<int64_t> fc_bias_shape = fc_bias_tensor->shape();
  auto fc_bias_data = reinterpret_cast<float *>(fc_bias_tensor->data_c());
  MS_CHECK_TRUE_RET(fc_bias_data != nullptr, RET_ERROR);
  std::shared_ptr<tensor::Tensor> add_weight_tensor = GetTensorInfo(add_weight_node);
  MS_CHECK_TRUE_RET(add_weight_tensor != nullptr, RET_ERROR);
  if (add_weight_tensor->data_type() != kNumberTypeFloat32) {
    MS_LOG(INFO) << "only support float32 data type";
    return RET_ERROR;
  }
  std::vector<int64_t> add_weight_shape = add_weight_tensor->shape();
  MS_CHECK_TRUE_RET(fc_bias_shape == add_weight_shape, RET_ERROR);
  auto add_weight_data = reinterpret_cast<float *>(add_weight_tensor->data_c());
  MS_CHECK_TRUE_RET(add_weight_data != nullptr, RET_ERROR);
  for (int64_t i = 0; i < fc_bias_shape[0]; ++i) {
    fc_bias_data[i] += add_weight_data[i];
  }
  return RET_OK;
}
} // namespace

VectorRef FullconnectedAddFusion::DefineFcAddFusionPattern() const {
  auto is_fc1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimFullConnection>);
  MS_CHECK_TRUE_RET(is_fc1 != nullptr, {});
  auto is_add = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
  MS_CHECK_TRUE_RET(is_add != nullptr, {});
  auto is_seq_var = std::make_shared<SeqVar>();
  MS_CHECK_TRUE_RET(is_seq_var != nullptr, {});
  return VectorRef({is_add, is_fc1, is_seq_var});
}

VectorRef FullconnectedAddFusion::DefineFcBiasAddPattern() const {
  auto is_fc1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimFullConnection>);
  MS_CHECK_TRUE_RET(is_fc1 != nullptr, {});
  auto is_bias_add = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBiasAdd>);
  MS_CHECK_TRUE_RET(is_bias_add != nullptr, {});
  auto is_seq_var = std::make_shared<SeqVar>();
  MS_CHECK_TRUE_RET(is_seq_var != nullptr, {});
  return VectorRef({is_bias_add, is_fc1, is_seq_var});
}

std::unordered_map<std::string, VectorRef> FullconnectedAddFusion::DefinePatterns() const {
  std::unordered_map<std::string, VectorRef> patterns;
  patterns["FcAddFusionPatternName"] = DefineFcAddFusionPattern();
  patterns["FcBiasAddPatternName"] = DefineFcBiasAddPattern();
  return patterns;
}

AnfNodePtr FullconnectedAddFusion::Process(const std::string &pattern_name, const FuncGraphPtr &func_graph,
                                           const AnfNodePtr &node, const EquivPtr &equiv) const {
  if (func_graph == nullptr || node == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }
  auto add_cnode = node->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(add_cnode != nullptr, nullptr);
  if (IsMarkedTrainOp(add_cnode)) {
    return nullptr;
  }
  if (!CheckPrimitiveType(node, prim::kPrimAddFusion) && !CheckPrimitiveType(node, prim::kPrimBiasAdd)) {
    return nullptr;
  }
  size_t index = 0;
  if (!CheckAndGetCnodeIndex(add_cnode, &index, prim::kPrimFullConnection)) {
    return nullptr;
  }
  auto fc_cnode = add_cnode->input(index)->cast<CNodePtr>();
  MS_ASSERT(fc_cnode != nullptr);
  if (IsMarkedTrainOp(fc_cnode)) {
    return nullptr;
  }
  if (IsMultiOutputTensors(func_graph, fc_cnode)) {
    return nullptr;
  }
  if (!IsPrimitiveProper(add_cnode, fc_cnode, index)) {
    return nullptr;
  }
  auto manager = func_graph->manager();
  auto add_param_node = add_cnode->input(kInputSizeThree - index);
  MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
  if (fc_cnode->size() == kInputSizeThree) {
    manager->AddEdge(fc_cnode, add_param_node);
  } else if (fc_cnode->size() == kInputSizeFour) {
    if (CalNewCnodeBias(add_param_node, fc_cnode) != RET_OK) {
      MS_LOG(INFO) << add_cnode->fullname_with_scope() << " failed to fuse with " << fc_cnode->fullname_with_scope();
      return nullptr;
    }
  }
  if (CheckPrimitiveType(node, prim::kPrimAddFusion)) {
    auto add_primc = GetValueNode<std::shared_ptr<ops::AddFusion>>(add_cnode->input(0));
    MS_CHECK_TRUE_RET(add_primc != nullptr, nullptr);
    if (add_primc->GetAttr(ops::kActivationType) != nullptr &&
        add_primc->get_activation_type() != ActivationType::NO_ACTIVATION) {
      auto fc_primc = GetValueNode<std::shared_ptr<ops::FullConnection>>(fc_cnode->input(0));
      MS_CHECK_TRUE_RET(fc_primc != nullptr, nullptr);
      fc_primc->set_activation_type(add_primc->get_activation_type());
    }
  }
  fc_cnode->set_fullname_with_scope(node->fullname_with_scope());
  (void)manager->Replace(node, fc_cnode);
  return nullptr;
}
} // namespace opt
} // namespace mindspore
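
To make the bias folding concrete, the following is a minimal standalone sketch of the arithmetic CalNewCnodeBias applies, written against plain std::vector rather than the tensor::Tensor API; the values are illustrative:

#include <cstdio>
#include <vector>

int main() {
  // Existing FullConnection bias and the Add's 1-D constant, one value per
  // output channel; the fusion requires matching shapes and float32 data.
  std::vector<float> fc_bias = {0.5f, -1.0f, 2.0f};
  std::vector<float> add_weight = {0.1f, 0.2f, 0.3f};
  // Same loop as CalNewCnodeBias: fold the addend into the bias in place.
  for (size_t i = 0; i < fc_bias.size(); ++i) {
    fc_bias[i] += add_weight[i];
  }
  for (float v : fc_bias) {
    std::printf("%g ", v);  // prints: 0.6 -0.8 2.3
  }
  std::printf("\n");
  return 0;
}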

View File

@@ -0,0 +1,40 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_FULLCONNECTED_ADD_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_FULLCONNECTED_ADD_FUSION_H_
#include <string>
#include <unordered_map>
#include "backend/optimizer/common/optimizer.h"
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
namespace mindspore::opt {
class FullconnectedAddFusion : public MultiplePatternProcessPass {
 public:
  explicit FullconnectedAddFusion(const std::string &name = "FullconnectedAddFusion", bool multigraph = true)
      : MultiplePatternProcessPass(name, multigraph) {}
  ~FullconnectedAddFusion() override = default;

 private:
  std::unordered_map<std::string, VectorRef> DefinePatterns() const override;
  VectorRef DefineFcAddFusionPattern() const;
  VectorRef DefineFcBiasAddPattern() const;
  AnfNodePtr Process(const std::string &pattern_name, const FuncGraphPtr &func_graph, const AnfNodePtr &,
                     const EquivPtr &) const override;
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_FULLCONNECTED_ADD_FUSION_H_
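
Using the pass requires only the single registration shown in the anf_transform.cc hunk at the top of this diff; the defaulted constructor arguments mean no extra configuration at the call site:

  fusion_pm->AddPass(std::make_shared<opt::FullconnectedAddFusion>());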

View File

@@ -19,7 +19,6 @@
#include <vector>
#include "tools/common/tensor_util.h"
#include "ops/fusion/full_connection.h"
#include "ops/fusion/conv2d_fusion.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/converter/quant_param_holder.h"
#include "nnacl/op_base.h"

View File

@@ -25,25 +25,6 @@
namespace mindspore {
namespace opt {
namespace {
-bool CheckAndGetMatMulIndex(const CNodePtr &cnode, size_t *index) {
-  MS_ASSERT(cnode != nullptr && index != nullptr);
-  if (cnode->size() != kInputSizeThree) {
-    return false;
-  }
-  size_t matmul_index = 0;
-  for (size_t i = 1; i < cnode->size(); ++i) {
-    if (CheckPrimitiveType(cnode->input(i), prim::kPrimMatMulFusion)) {
-      matmul_index = i;
-      break;
-    }
-  }
-  if (matmul_index == 0) {
-    return false;
-  }
-  *index = matmul_index;
-  return true;
-}
bool IsPrimitiveProper(const CNodePtr &add_cnode, const CNodePtr &matmul_cnode, int index) {
  auto add_primc = GetValueNode<PrimitiveCPtr>(add_cnode->input(0));
  MS_CHECK_TRUE_RET(add_primc != nullptr, false);
@@ -169,7 +150,7 @@ AnfNodePtr MatMulAddFusion::Process(const std::string &pattern_name, const FuncG
  }
  size_t index = 0;
-  if (!CheckAndGetMatMulIndex(add_cnode, &index)) {
+  if (!CheckAndGetCnodeIndex(add_cnode, &index, prim::kPrimMatMulFusion)) {
    return nullptr;
  }
  auto matmul_cnode = add_cnode->input(index)->cast<CNodePtr>();