Add matmul biasadd fusion pass

This commit is contained in:
GinFung 2020-03-30 09:54:05 +08:00
parent 930a1fb0a8
commit 468dbc3557
7 changed files with 192 additions and 0 deletions

View File

@ -43,6 +43,7 @@
#include "pre_activate/ascend/ir_fusion/momentum_lossscale_fusion.h"
#include "pre_activate/ascend/ir_fusion/mul_add_fusion.h"
#include "pre_activate/ascend/ir_fusion/mul_addn_fusion.h"
#include "pre_activate/ascend/ir_fusion/matmul_biasadd_fusion.h"
#include "pre_activate/ascend/format_type/insert_trans_op.h"
#include "pre_activate/pass/getitem_tuple.h"
#include "pre_activate/pass/optimize_dependence.h"
@ -173,6 +174,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
ir_fusion_pm->AddPass(std::make_shared<MomentumLossscaleFusion>());
ir_fusion_pm->AddPass(std::make_shared<MulAddFusion>());
ir_fusion_pm->AddPass(std::make_shared<MulAddNFusion>());
ir_fusion_pm->AddPass(std::make_shared<MatmulBiasaddFusion>());
ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
ir_fusion_pm->AddPass(std::make_shared<TransposeTransDataFusion>());
}

View File

@ -0,0 +1,51 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/ascend/ir_fusion/matmul_biasadd_fusion.h"
#include <memory>
#include "pre_activate/common/helper.h"
#include "session/anf_runtime_algorithm.h"
#include "utils/utils.h"
namespace mindspore {
namespace opt {
namespace {
constexpr size_t kMatMulInputIndex = 1;
constexpr size_t kBiasInputIndex = 2;
} // namespace
const BaseRef MatmulBiasaddFusion::DefinePattern() const {
VarPtr X0 = std::make_shared<Var>();
VarPtr X1 = std::make_shared<Var>();
VarPtr X2 = std::make_shared<Var>();
const auto prim_bias_add = std::make_shared<Primitive>(kBiasAddOpName);
return VectorRef({prim_bias_add, VectorRef({prim::kPrimMatMul, X0, X1}), X2});
}
const AnfNodePtr MatmulBiasaddFusion::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
CheckCNodeInputSize(cnode, kBiasAddInputNum);
AnfNodePtr matmul = cnode->input(kMatMulInputIndex);
MS_EXCEPTION_IF_NULL(matmul);
auto matmul_cnode = matmul->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(matmul_cnode);
matmul_cnode->add_input(cnode->input(kBiasInputIndex));
AnfAlgo::SetNodeAttr(kAttrHasBias, MakeValue(true), matmul);
return matmul;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,34 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MATMUL_BIASADD_FUSION_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MATMUL_BIASADD_FUSION_H_
#include "pre_activate/common/optimizer.h"
namespace mindspore {
namespace opt {
class MatmulBiasaddFusion : public PatternProcessPass {
public:
explicit MatmulBiasaddFusion(bool multigraph = true) : PatternProcessPass("matmul_biasadd_fusion", multigraph) {}
~MatmulBiasaddFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MATMUL_BIASADD_FUSION_H_

View File

@ -84,6 +84,7 @@ constexpr size_t kLayerNormGradInputNum = 6;
constexpr size_t kAdamApplyOneOutputNum = 3;
constexpr size_t kBackendTransDataInputNum = 2;
constexpr size_t kApplyMomentumInputNum = 6;
constexpr size_t kBiasAddInputNum = 3;
enum FusedBatchNormInput {
kX = 1,

View File

@ -110,6 +110,7 @@ constexpr auto kResizeNearestNeighborGrad = "ResizeNearestNeighborGrad";
constexpr auto kFusedMulAddOpName = "FusedMulAdd";
constexpr auto kFusedMulAddNOpName = "FusedMulAddN";
constexpr auto kFusedMulApplyMomentumOpName = "FusedMulApplyMomentum";
constexpr auto kBiasAddOpName = "BiasAdd";
// attr key name
constexpr auto kAttrInputNames = "input_names";
@ -140,6 +141,7 @@ constexpr auto kAttrDynInput = "dynamic";
constexpr auto kAttrDynInputSizes = "dyn_input_sizes";
constexpr auto kAttrSrcFormat = "src_format";
constexpr auto kAttrOutputUsedNum = "output_used_num";
constexpr auto kAttrHasBias = "has_bias";
// attr value
constexpr auto kValueTargetSwitch = "target_switch";

View File

@ -0,0 +1,56 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/ascend/ir_fusion/matmul_biasadd_fusion.h"
#include "common/backend_common_test.h"
#include "common/py_func_graph_fetcher.h"
namespace mindspore {
namespace opt {
class TestHWMatmulBiasaddFusion : public BackendCommon {
public:
TestHWMatmulBiasaddFusion() : get_py_fun_("gtest_input.pre_activate.matmul_biasadd_fusion_test", true) {}
~TestHWMatmulBiasaddFusion() override = default;
UT::PyFuncGraphFetcher get_py_fun_;
};
TEST_F(TestHWMatmulBiasaddFusion, test_matmul_biasadd_fusion) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_matmul_biasadd_fusion", "before");
EXPECT_NE(g, nullptr);
std::vector<int> shpx{1, 3};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shpx);
std::vector<int> shpy{3, 4};
auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shpy);
std::vector<int> shp_bias{4};
auto bias_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_bias);
AbstractBasePtrList args_spec_list;
args_spec_list.push_back(x_abstract);
args_spec_list.push_back(y_abstract);
args_spec_list.push_back(bias_abstract);
auto kg = GetKernelGraph(g, args_spec_list);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::MatmulBiasaddFusion>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(kg);
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_matmul_biasadd_fusion", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,46 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import operations as P
from mindspore.ops import Primitive
MatMul = P.MatMul()
BiasAdd = P.BiasAdd()
make_tuple = Primitive('make_tuple')
class FnDict:
def __init__(self):
self.fnDict = {}
def __call__(self, fn):
self.fnDict[fn.__name__] = fn
def __getitem__(self, name):
return self.fnDict[name]
def test_matmul_biasadd_fusion(tag):
fns = FnDict()
@fns
def before(input0, input1, input2):
matmul = MatMul(input0, input1)
biasadd = BiasAdd(matmul, input2)
return biasadd
@fns
def after(input0, input1, input2):
return make_tuple(MatMul(input0, input1, input2))
return fns[tag]