forked from mindspore-Ecosystem/mindspore
!28894 add fc fusion with add
Merge pull request !28894 from wangyanling/fcaddfusion
This commit is contained in:
commit
72d0d70a6b
|
@ -59,6 +59,7 @@
|
||||||
#include "tools/optimizer/fusion/scale_activation_fusion.h"
|
#include "tools/optimizer/fusion/scale_activation_fusion.h"
|
||||||
#include "tools/optimizer/fusion/scale_scale_fusion.h"
|
#include "tools/optimizer/fusion/scale_scale_fusion.h"
|
||||||
#include "tools/optimizer/fusion/fullconnected_fusion.h"
|
#include "tools/optimizer/fusion/fullconnected_fusion.h"
|
||||||
|
#include "tools/optimizer/fusion/fullconnected_add_fusion.h"
|
||||||
#include "tools/optimizer/fusion/add_concat_activation_fusion.h"
|
#include "tools/optimizer/fusion/add_concat_activation_fusion.h"
|
||||||
#include "tools/optimizer/fusion/matmul_activation_fusion.h"
|
#include "tools/optimizer/fusion/matmul_activation_fusion.h"
|
||||||
#include "tools/optimizer/fusion/activation_fusion.h"
|
#include "tools/optimizer/fusion/activation_fusion.h"
|
||||||
|
@ -217,6 +218,7 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::
|
||||||
fusion_pm->AddPass(std::make_shared<opt::ScaleActivationFusion>());
|
fusion_pm->AddPass(std::make_shared<opt::ScaleActivationFusion>());
|
||||||
fusion_pm->AddPass(std::make_shared<opt::ScaleScaleFusion>());
|
fusion_pm->AddPass(std::make_shared<opt::ScaleScaleFusion>());
|
||||||
fusion_pm->AddPass(std::make_shared<opt::FullConnectedFusion>());
|
fusion_pm->AddPass(std::make_shared<opt::FullConnectedFusion>());
|
||||||
|
fusion_pm->AddPass(std::make_shared<opt::FullconnectedAddFusion>());
|
||||||
fusion_pm->AddPass(std::make_shared<opt::TensorDotFusion>());
|
fusion_pm->AddPass(std::make_shared<opt::TensorDotFusion>());
|
||||||
fusion_pm->AddPass(std::make_shared<opt::MatMulActivationFusion>());
|
fusion_pm->AddPass(std::make_shared<opt::MatMulActivationFusion>());
|
||||||
optimizer->AddPassManager(fusion_pm);
|
optimizer->AddPassManager(fusion_pm);
|
||||||
|
|
|
@ -1126,5 +1126,24 @@ int DetermineCertainVarInputHasInferred(const CNodePtr &cnode, size_t index, boo
|
||||||
*infer_succ = infer_infos[item_index];
|
*infer_succ = infer_infos[item_index];
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
bool CheckAndGetCnodeIndex(const CNodePtr &cnode, size_t *index, const PrimitivePtr &primitive_type) {
|
||||||
|
MS_CHECK_TRUE_RET(cnode != nullptr, false);
|
||||||
|
MS_CHECK_TRUE_RET(index != nullptr, false);
|
||||||
|
if (cnode->size() != kInputSizeThree) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
size_t dst_index = 0;
|
||||||
|
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||||
|
if (CheckPrimitiveType(cnode->input(i), primitive_type)) {
|
||||||
|
dst_index = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (dst_index == 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
*index = dst_index;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -141,6 +141,8 @@ std::pair<CNodePtr, int> GetRealCertainVarInput(const CNodePtr &cnode, size_t in
|
||||||
|
|
||||||
int DetermineCertainVarInputHasInferred(const CNodePtr &cnode, size_t index, bool *infer_succ);
|
int DetermineCertainVarInputHasInferred(const CNodePtr &cnode, size_t index, bool *infer_succ);
|
||||||
|
|
||||||
|
// Searches the inputs of `cnode` (starting at position 1) for the first one matching `primitive_type`;
// on success stores its position in `*index` and returns true, otherwise returns false.
bool CheckAndGetCnodeIndex(const CNodePtr &cnode, size_t *index, const PrimitivePtr &primitive_type);
|
||||||
|
|
||||||
template <const PrimitivePtr *prim = nullptr>
|
template <const PrimitivePtr *prim = nullptr>
|
||||||
inline bool IsSpecifiedNode(const BaseRef &n) {
|
inline bool IsSpecifiedNode(const BaseRef &n) {
|
||||||
if (utils::isa<AnfNodePtr>(n)) {
|
if (utils::isa<AnfNodePtr>(n)) {
|
||||||
|
|
|
@ -0,0 +1,197 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "tools/optimizer/fusion/fullconnected_add_fusion.h"
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include "ops/fusion/add_fusion.h"
|
||||||
|
#include "ops/fusion/full_connection.h"
|
||||||
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
|
#include "nnacl/op_base.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
namespace {
|
||||||
|
bool IsPrimitiveProper(const CNodePtr &add_cnode, const CNodePtr &fc_cnode, int index) {
|
||||||
|
auto add_primc = GetValueNode<PrimitiveCPtr>(add_cnode->input(0));
|
||||||
|
MS_CHECK_TRUE_RET(add_primc != nullptr, false);
|
||||||
|
if (IsQuantParameterNode(add_primc)) {
|
||||||
|
MS_LOG(INFO) << add_cnode->fullname_with_scope() << " is quant node";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto add_param_node = add_cnode->input(kInputSizeThree - index);
|
||||||
|
if (!utils::isa<ValueNode>(add_param_node) &&
|
||||||
|
(!utils::isa<Parameter>(add_param_node) || !add_param_node->cast<ParameterPtr>()->default_param())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto abstract = add_param_node->abstract();
|
||||||
|
MS_CHECK_TRUE_RET(abstract != nullptr, false);
|
||||||
|
std::vector<int64_t> bias_shape;
|
||||||
|
if (FetchShapeFromAbstract(abstract, &bias_shape) != lite::RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "Fetch shape from abstract failed.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (bias_shape.size() > DIMENSION_1D) {
|
||||||
|
MS_LOG(INFO) << "only support bias with shape size of 1.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (fc_cnode->size() > kInputSizeThree) {
|
||||||
|
auto fc_bias_node = fc_cnode->input(kInputIndexThree);
|
||||||
|
if (!utils::isa<ValueNode>(fc_bias_node) &&
|
||||||
|
(!utils::isa<Parameter>(fc_bias_node) || !fc_bias_node->cast<ParameterPtr>()->default_param())) {
|
||||||
|
MS_LOG(INFO) << fc_cnode->fullname_with_scope() << "'s bias is not parameter";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto fc_primc = GetValueNode<std::shared_ptr<ops::FullConnection>>(fc_cnode->input(0));
|
||||||
|
MS_CHECK_TRUE_RET(fc_primc != nullptr, false);
|
||||||
|
if (fc_primc->GetAttr(ops::kActivationType) != nullptr &&
|
||||||
|
fc_primc->get_activation_type() != ActivationType::NO_ACTIVATION) {
|
||||||
|
MS_LOG(INFO) << fc_cnode->fullname_with_scope() << " has activation attr";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (IsQuantParameterNode(fc_primc)) {
|
||||||
|
MS_LOG(INFO) << fc_cnode->fullname_with_scope() << "is quant node";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
int CalNewCnodeBias(const AnfNodePtr &add_weight_node, const CNodePtr &fc_cnode) {
|
||||||
|
MS_CHECK_TRUE_RET(add_weight_node != nullptr, RET_ERROR);
|
||||||
|
MS_CHECK_TRUE_RET(fc_cnode != nullptr, RET_ERROR);
|
||||||
|
auto fc_bias_node = fc_cnode->input(kInputIndexThree);
|
||||||
|
MS_CHECK_TRUE_RET(fc_bias_node != nullptr, RET_ERROR);
|
||||||
|
std::shared_ptr<tensor::Tensor> fc_bias_tensor = GetTensorInfo(fc_bias_node);
|
||||||
|
MS_CHECK_TRUE_RET(fc_bias_tensor != nullptr, RET_ERROR);
|
||||||
|
if (fc_bias_tensor->data_type() != kNumberTypeFloat32) {
|
||||||
|
MS_LOG(INFO) << "only support float32 data type";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
std::vector<int64_t> fc_bias_shape = fc_bias_tensor->shape();
|
||||||
|
auto fc_bias_data = reinterpret_cast<float *>(fc_bias_tensor->data_c());
|
||||||
|
MS_CHECK_TRUE_RET(fc_bias_data != nullptr, RET_ERROR);
|
||||||
|
|
||||||
|
std::shared_ptr<tensor::Tensor> add_weight_tensor = GetTensorInfo(add_weight_node);
|
||||||
|
MS_CHECK_TRUE_RET(add_weight_tensor != nullptr, RET_ERROR);
|
||||||
|
if (add_weight_tensor->data_type() != kNumberTypeFloat32) {
|
||||||
|
MS_LOG(INFO) << "only support float32 data type";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
std::vector<int64_t> add_weight_shape = add_weight_tensor->shape();
|
||||||
|
MS_CHECK_TRUE_RET(fc_bias_shape == add_weight_shape, RET_ERROR);
|
||||||
|
auto add_weight_data = reinterpret_cast<float *>(add_weight_tensor->data_c());
|
||||||
|
MS_CHECK_TRUE_RET(add_weight_data != nullptr, RET_ERROR);
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < fc_bias_shape[0]; ++i) {
|
||||||
|
fc_bias_data[i] += add_weight_data[i];
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Builds the pattern {AddFusion, FullConnection, SeqVar}: an AddFusion node whose first
// listed input is a FullConnection node, followed by a SeqVar placeholder.
// Returns an empty VectorRef if any pattern variable could not be created.
VectorRef FullconnectedAddFusion::DefineFcAddFusionPattern() const {
  auto fc_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimFullConnection>);
  MS_CHECK_TRUE_RET(fc_var != nullptr, {});
  auto add_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
  MS_CHECK_TRUE_RET(add_var != nullptr, {});
  auto trailing_inputs = std::make_shared<SeqVar>();
  MS_CHECK_TRUE_RET(trailing_inputs != nullptr, {});
  VectorRef fc_add_pattern({add_var, fc_var, trailing_inputs});
  return fc_add_pattern;
}
|
||||||
|
|
||||||
|
// Builds the pattern {BiasAdd, FullConnection, SeqVar}: a BiasAdd node whose first listed
// input is a FullConnection node, followed by a SeqVar placeholder.
// Returns an empty VectorRef if any pattern variable could not be created.
VectorRef FullconnectedAddFusion::DefineFcBiasAddPattern() const {
  auto fc_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimFullConnection>);
  MS_CHECK_TRUE_RET(fc_var != nullptr, {});
  auto bias_add_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBiasAdd>);
  MS_CHECK_TRUE_RET(bias_add_var != nullptr, {});
  auto trailing_inputs = std::make_shared<SeqVar>();
  MS_CHECK_TRUE_RET(trailing_inputs != nullptr, {});
  VectorRef fc_bias_add_pattern({bias_add_var, fc_var, trailing_inputs});
  return fc_bias_add_pattern;
}
|
||||||
|
|
||||||
|
std::unordered_map<std::string, VectorRef> FullconnectedAddFusion::DefinePatterns() const {
|
||||||
|
std::unordered_map<std::string, VectorRef> patterns;
|
||||||
|
patterns["FcAddFusionPatternName"] = DefineFcAddFusionPattern();
|
||||||
|
patterns["FcBiasAddPatternName"] = DefineFcBiasAddPattern();
|
||||||
|
return patterns;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Folds an AddFusion/BiasAdd node into the FullConnection node that feeds it.
//
// `node` is the matched add node. When the fusion applies:
//   - a FullConnection with kInputSizeThree entries gets the add's other operand appended as
//     a new input;
//   - a FullConnection with kInputSizeFour entries has the operand accumulated into its
//     existing bias (see CalNewCnodeBias);
//   - a non-trivial activation on an AddFusion is copied onto the FullConnection;
//   - the FullConnection takes over the add node's name and replaces it in the graph.
//
// The function always returns nullptr; the replacement itself is carried out through the
// graph manager. When `func_graph` or `node` is null, RET_NULL_PTR is recorded in the global
// return code. `pattern_name` and `equiv` are not used.
AnfNodePtr FullconnectedAddFusion::Process(const std::string &pattern_name, const FuncGraphPtr &func_graph,
                                           const AnfNodePtr &node, const EquivPtr &equiv) const {
  if (func_graph == nullptr || node == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }

  auto add_cnode = node->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(add_cnode != nullptr, nullptr);
  if (IsMarkedTrainOp(add_cnode)) {
    return nullptr;
  }
  const bool is_add_fusion = CheckPrimitiveType(node, prim::kPrimAddFusion);
  if (!is_add_fusion && !CheckPrimitiveType(node, prim::kPrimBiasAdd)) {
    return nullptr;
  }

  size_t index = 0;
  if (!CheckAndGetCnodeIndex(add_cnode, &index, prim::kPrimFullConnection)) {
    return nullptr;
  }
  auto fc_cnode = add_cnode->input(index)->cast<CNodePtr>();
  // Checked at runtime because fc_cnode is dereferenced immediately below.
  MS_CHECK_TRUE_RET(fc_cnode != nullptr, nullptr);
  if (IsMarkedTrainOp(fc_cnode)) {
    return nullptr;
  }

  // The FullConnection output must feed only this add, otherwise other consumers would see
  // the biased result.
  if (IsMultiOutputTensors(func_graph, fc_cnode)) {
    return nullptr;
  }

  if (!IsPrimitiveProper(add_cnode, fc_cnode, static_cast<int>(index))) {
    return nullptr;
  }

  auto manager = func_graph->manager();
  MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
  auto add_param_node = add_cnode->input(kInputSizeThree - index);

  // Resolve every primitive that may fail to resolve BEFORE mutating the graph or the bias
  // data, so that a failure never leaves a half-applied fusion behind.
  std::shared_ptr<ops::AddFusion> add_primc = nullptr;
  std::shared_ptr<ops::FullConnection> fc_primc = nullptr;
  if (is_add_fusion) {
    add_primc = GetValueNode<std::shared_ptr<ops::AddFusion>>(add_cnode->input(0));
    MS_CHECK_TRUE_RET(add_primc != nullptr, nullptr);
    if (add_primc->GetAttr(ops::kActivationType) != nullptr &&
        add_primc->get_activation_type() != ActivationType::NO_ACTIVATION) {
      fc_primc = GetValueNode<std::shared_ptr<ops::FullConnection>>(fc_cnode->input(0));
      MS_CHECK_TRUE_RET(fc_primc != nullptr, nullptr);
    }
  }

  if (fc_cnode->size() == kInputSizeThree) {
    // No bias yet: the add operand becomes the bias input.
    manager->AddEdge(fc_cnode, add_param_node);
  } else if (fc_cnode->size() == kInputSizeFour) {
    // Existing bias: accumulate the add operand into it.
    if (CalNewCnodeBias(add_param_node, fc_cnode) != RET_OK) {
      MS_LOG(INFO) << add_cnode->fullname_with_scope() << " failed to fusion with " << fc_cnode->fullname_with_scope();
      return nullptr;
    }
  } else {
    // Any other input count cannot absorb the addend; replacing the add node anyway would
    // silently drop the addition, so leave the graph untouched.
    MS_LOG(INFO) << add_cnode->fullname_with_scope() << " failed to fusion with " << fc_cnode->fullname_with_scope();
    return nullptr;
  }

  // fc_primc is only set when the AddFusion carries an activation that must be preserved.
  if (fc_primc != nullptr) {
    fc_primc->set_activation_type(add_primc->get_activation_type());
  }
  fc_cnode->set_fullname_with_scope(node->fullname_with_scope());
  (void)manager->Replace(node, fc_cnode);
  return nullptr;
}
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,40 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_FULLCONNECTED_ADD_FUSION_H_
|
||||||
|
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_FULLCONNECTED_ADD_FUSION_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include "backend/optimizer/common/optimizer.h"
|
||||||
|
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
|
||||||
|
|
||||||
|
namespace mindspore::opt {
|
||||||
|
class FullconnectedAddFusion : public MultiplePatternProcessPass {
|
||||||
|
public:
|
||||||
|
explicit FullconnectedAddFusion(const std::string &name = "FullconnectedAddFusion", bool multigraph = true)
|
||||||
|
: MultiplePatternProcessPass(name, multigraph) {}
|
||||||
|
~FullconnectedAddFusion() override = default;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unordered_map<std::string, VectorRef> DefinePatterns() const override;
|
||||||
|
VectorRef DefineFcAddFusionPattern() const;
|
||||||
|
VectorRef DefineFcBiasAddPattern() const;
|
||||||
|
AnfNodePtr Process(const std::string &pattern_name, const FuncGraphPtr &func_graph, const AnfNodePtr &,
|
||||||
|
const EquivPtr &) const override;
|
||||||
|
};
|
||||||
|
} // namespace mindspore::opt
|
||||||
|
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_FULLCONNECTED_ADD_FUSION_H_
|
|
@ -19,7 +19,6 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "tools/common/tensor_util.h"
|
#include "tools/common/tensor_util.h"
|
||||||
#include "ops/fusion/full_connection.h"
|
#include "ops/fusion/full_connection.h"
|
||||||
#include "ops/fusion/conv2d_fusion.h"
|
|
||||||
#include "tools/optimizer/common/gllo_utils.h"
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
#include "tools/converter/quant_param_holder.h"
|
#include "tools/converter/quant_param_holder.h"
|
||||||
#include "nnacl/op_base.h"
|
#include "nnacl/op_base.h"
|
||||||
|
|
|
@ -25,25 +25,6 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
namespace {
|
namespace {
|
||||||
bool CheckAndGetMatMulIndex(const CNodePtr &cnode, size_t *index) {
|
|
||||||
MS_ASSERT(cnode != nullptr && index != nullptr);
|
|
||||||
if (cnode->size() != kInputSizeThree) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
size_t matmul_index = 0;
|
|
||||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
|
||||||
if (CheckPrimitiveType(cnode->input(i), prim::kPrimMatMulFusion)) {
|
|
||||||
matmul_index = i;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (matmul_index == 0) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
*index = matmul_index;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool IsPrimitiveProper(const CNodePtr &add_cnode, const CNodePtr &matmul_cnode, int index) {
|
bool IsPrimitiveProper(const CNodePtr &add_cnode, const CNodePtr &matmul_cnode, int index) {
|
||||||
auto add_primc = GetValueNode<PrimitiveCPtr>(add_cnode->input(0));
|
auto add_primc = GetValueNode<PrimitiveCPtr>(add_cnode->input(0));
|
||||||
MS_CHECK_TRUE_RET(add_primc != nullptr, false);
|
MS_CHECK_TRUE_RET(add_primc != nullptr, false);
|
||||||
|
@ -169,7 +150,7 @@ AnfNodePtr MatMulAddFusion::Process(const std::string &pattern_name, const FuncG
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t index = 0;
|
size_t index = 0;
|
||||||
if (!CheckAndGetMatMulIndex(add_cnode, &index)) {
|
if (!CheckAndGetCnodeIndex(add_cnode, &index, prim::kPrimMatMulFusion)) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto matmul_cnode = add_cnode->input(index)->cast<CNodePtr>();
|
auto matmul_cnode = add_cnode->input(index)->cast<CNodePtr>();
|
||||||
|
|
Loading…
Reference in New Issue