forked from mindspore-Ecosystem/mindspore
!993 [IRFusion pass enhancement] mul_add_fusion pass supports when add's 2nd is mul
Merge pull request !993 from huanghui/mul-add-fusion-pass-enhancement
This commit is contained in:
commit
a55cb95589
|
@ -24,40 +24,57 @@
|
|||
#include "pre_activate/common/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
const BaseRef MulAddFusion::DefinePattern() const {
|
||||
VarPtr mul_x_ = std::make_shared<Var>();
|
||||
VarPtr mul_y_ = std::make_shared<Var>();
|
||||
VarPtr add_y_ = std::make_shared<Var>();
|
||||
bool GetMul(const FuncGraphPtr &graph, const CNodePtr &add, CNodePtr *mul, size_t *mul_index) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(add);
|
||||
|
||||
VectorRef mul({prim::kPrimMul, mul_x_, mul_y_});
|
||||
VectorRef add({prim::kPrimTensorAdd, mul, add_y_});
|
||||
return add;
|
||||
for (size_t index = 1; index < add->size(); ++index) {
|
||||
auto input = add->input(index);
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
if (input->isa<CNode>()) {
|
||||
auto cnode = input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimMul->name()) {
|
||||
if (!opt::IsUsedByOthers(graph, cnode)) {
|
||||
*mul = cnode;
|
||||
*mul_index = index;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &equiv) const {
|
||||
if (graph == nullptr || node == nullptr || equiv == nullptr) {
|
||||
namespace opt {
|
||||
const BaseRef MulAddFusion::DefinePattern() const {
|
||||
VarPtr x = std::make_shared<Var>();
|
||||
VarPtr y = std::make_shared<Var>();
|
||||
VectorRef pattern({prim::kPrimTensorAdd, x, y});
|
||||
return pattern;
|
||||
}
|
||||
|
||||
const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
|
||||
if (graph == nullptr || node == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto add = node->cast<CNodePtr>();
|
||||
if (add == nullptr || add->inputs().size() != kAddInputNum) {
|
||||
return nullptr;
|
||||
}
|
||||
auto mul_anf = add->input(1);
|
||||
if (mul_anf == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto mul = mul_anf->cast<CNodePtr>();
|
||||
if (mul == nullptr || mul->inputs().size() != kMulInputNum) {
|
||||
return nullptr;
|
||||
}
|
||||
if (IsUsedByOthers(graph, mul)) {
|
||||
MS_LOG(DEBUG) << "Mul is used by more then two nodes, cannot fuse";
|
||||
CNodePtr mul = nullptr;
|
||||
size_t mul_index = 0;
|
||||
if (!GetMul(graph, add, &mul, &mul_index) || mul == nullptr || mul_index == 0) {
|
||||
MS_LOG(DEBUG) << "Cannot find used-by-only-one-op Mul in Add's inputs";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto prim = std::make_shared<Primitive>(kFusedMulAddOpName);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), mul->input(1), mul->input(2), add->input(2)};
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
|
||||
for (size_t index = 1; index < mul->size(); ++index) {
|
||||
inputs.push_back(mul->input(index));
|
||||
}
|
||||
inputs.push_back(add->input(add->size() - mul_index));
|
||||
auto fusion_node = graph->NewCNode(inputs);
|
||||
fusion_node->set_scope(add->scope());
|
||||
fusion_node->set_abstract(add->abstract());
|
||||
|
|
|
@ -28,8 +28,28 @@ class TestHWMulAddFusion : public BackendCommon {
|
|||
UT::PyFuncGraphFetcher get_py_fun_;
|
||||
};
|
||||
|
||||
TEST_F(TestHWMulAddFusion, test_mul_add_fusion) {
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_mul_add_fusion", "before");
|
||||
TEST_F(TestHWMulAddFusion, test_mul_add_fusion1) {
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_mul_add_fusion", "before1");
|
||||
std::vector<int> shp{2, 32, 224, 224};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
AbstractBasePtrList args_spec_list;
|
||||
for (size_t i = 0; i < 3; ++i) {
|
||||
args_spec_list.push_back(x_abstract);
|
||||
}
|
||||
auto fg = GetKernelGraph(g, args_spec_list);
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::MulAddFusion>());
|
||||
optimizer->AddPassManager(pm);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(fg);
|
||||
|
||||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_mul_add_fusion", "after");
|
||||
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
||||
}
|
||||
|
||||
TEST_F(TestHWMulAddFusion, test_mul_add_fusion2) {
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_mul_add_fusion", "before2");
|
||||
std::vector<int> shp{2, 32, 224, 224};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
AbstractBasePtrList args_spec_list;
|
||||
|
|
|
@ -21,7 +21,6 @@ fused_mul_add = Primitive('FusedMulAdd')
|
|||
make_tuple = Primitive('make_tuple')
|
||||
tuple_getitem = Primitive('tuple_getitem')
|
||||
|
||||
|
||||
class FnDict:
|
||||
def __init__(self):
|
||||
self.fnDict = {}
|
||||
|
@ -32,16 +31,21 @@ class FnDict:
|
|||
def __getitem__(self, name):
|
||||
return self.fnDict[name]
|
||||
|
||||
|
||||
def test_mul_add_fusion(tag):
|
||||
fns = FnDict()
|
||||
|
||||
@fns
|
||||
def before(x, y, z):
|
||||
def before1(x, y, z):
|
||||
res = mul(x, y)
|
||||
res = add(res, z)
|
||||
return res
|
||||
|
||||
@fns
|
||||
def before2(x, y, z):
|
||||
res = mul(x, y)
|
||||
res = add(z, res)
|
||||
return res
|
||||
|
||||
@fns
|
||||
def after(x, y, z):
|
||||
res = fused_mul_add(x, y, z)
|
||||
|
|
Loading…
Reference in New Issue