forked from mindspore-Ecosystem/mindspore
mul_add_fusion pass: also fuse when the Mul is Add's second input
This commit is contained in:
parent
c176bbe4c8
commit
12eaaf710b
|
@ -24,40 +24,57 @@
|
||||||
#include "pre_activate/common/helper.h"
|
#include "pre_activate/common/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
bool GetMul(const FuncGraphPtr &graph, const CNodePtr &add, CNodePtr *mul, size_t *mul_index) {
|
||||||
const BaseRef MulAddFusion::DefinePattern() const {
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
VarPtr mul_x_ = std::make_shared<Var>();
|
MS_EXCEPTION_IF_NULL(add);
|
||||||
VarPtr mul_y_ = std::make_shared<Var>();
|
|
||||||
VarPtr add_y_ = std::make_shared<Var>();
|
|
||||||
|
|
||||||
VectorRef mul({prim::kPrimMul, mul_x_, mul_y_});
|
for (size_t index = 1; index < add->size(); ++index) {
|
||||||
VectorRef add({prim::kPrimTensorAdd, mul, add_y_});
|
auto input = add->input(index);
|
||||||
return add;
|
MS_EXCEPTION_IF_NULL(input);
|
||||||
|
if (input->isa<CNode>()) {
|
||||||
|
auto cnode = input->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimMul->name()) {
|
||||||
|
if (!opt::IsUsedByOthers(graph, cnode)) {
|
||||||
|
*mul = cnode;
|
||||||
|
*mul_index = index;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &equiv) const {
|
namespace opt {
|
||||||
if (graph == nullptr || node == nullptr || equiv == nullptr) {
|
const BaseRef MulAddFusion::DefinePattern() const {
|
||||||
|
VarPtr x = std::make_shared<Var>();
|
||||||
|
VarPtr y = std::make_shared<Var>();
|
||||||
|
VectorRef pattern({prim::kPrimTensorAdd, x, y});
|
||||||
|
return pattern;
|
||||||
|
}
|
||||||
|
|
||||||
|
const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
|
||||||
|
if (graph == nullptr || node == nullptr) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto add = node->cast<CNodePtr>();
|
auto add = node->cast<CNodePtr>();
|
||||||
if (add == nullptr || add->inputs().size() != kAddInputNum) {
|
if (add == nullptr || add->inputs().size() != kAddInputNum) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto mul_anf = add->input(1);
|
CNodePtr mul = nullptr;
|
||||||
if (mul_anf == nullptr) {
|
size_t mul_index = 0;
|
||||||
return nullptr;
|
if (!GetMul(graph, add, &mul, &mul_index) || mul == nullptr || mul_index == 0) {
|
||||||
}
|
MS_LOG(DEBUG) << "Cannot find used-by-only-one-op Mul in Add's inputs";
|
||||||
auto mul = mul_anf->cast<CNodePtr>();
|
|
||||||
if (mul == nullptr || mul->inputs().size() != kMulInputNum) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
if (IsUsedByOthers(graph, mul)) {
|
|
||||||
MS_LOG(DEBUG) << "Mul is used by more then two nodes, cannot fuse";
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto prim = std::make_shared<Primitive>(kFusedMulAddOpName);
|
auto prim = std::make_shared<Primitive>(kFusedMulAddOpName);
|
||||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), mul->input(1), mul->input(2), add->input(2)};
|
std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
|
||||||
|
for (size_t index = 1; index < mul->size(); ++index) {
|
||||||
|
inputs.push_back(mul->input(index));
|
||||||
|
}
|
||||||
|
inputs.push_back(add->input(add->size() - mul_index));
|
||||||
auto fusion_node = graph->NewCNode(inputs);
|
auto fusion_node = graph->NewCNode(inputs);
|
||||||
fusion_node->set_scope(add->scope());
|
fusion_node->set_scope(add->scope());
|
||||||
fusion_node->set_abstract(add->abstract());
|
fusion_node->set_abstract(add->abstract());
|
||||||
|
|
|
@ -28,8 +28,28 @@ class TestHWMulAddFusion : public BackendCommon {
|
||||||
UT::PyFuncGraphFetcher get_py_fun_;
|
UT::PyFuncGraphFetcher get_py_fun_;
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TestHWMulAddFusion, test_mul_add_fusion) {
|
TEST_F(TestHWMulAddFusion, test_mul_add_fusion1) {
|
||||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_mul_add_fusion", "before");
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_mul_add_fusion", "before1");
|
||||||
|
std::vector<int> shp{2, 32, 224, 224};
|
||||||
|
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||||
|
AbstractBasePtrList args_spec_list;
|
||||||
|
for (size_t i = 0; i < 3; ++i) {
|
||||||
|
args_spec_list.push_back(x_abstract);
|
||||||
|
}
|
||||||
|
auto fg = GetKernelGraph(g, args_spec_list);
|
||||||
|
|
||||||
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||||
|
auto pm = std::make_shared<opt::PassManager>();
|
||||||
|
pm->AddPass(std::make_shared<opt::MulAddFusion>());
|
||||||
|
optimizer->AddPassManager(pm);
|
||||||
|
FuncGraphPtr new_graph = optimizer->Optimize(fg);
|
||||||
|
|
||||||
|
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_mul_add_fusion", "after");
|
||||||
|
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestHWMulAddFusion, test_mul_add_fusion2) {
|
||||||
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_mul_add_fusion", "before2");
|
||||||
std::vector<int> shp{2, 32, 224, 224};
|
std::vector<int> shp{2, 32, 224, 224};
|
||||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||||
AbstractBasePtrList args_spec_list;
|
AbstractBasePtrList args_spec_list;
|
||||||
|
|
|
@ -21,7 +21,6 @@ fused_mul_add = Primitive('FusedMulAdd')
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive('tuple_getitem')
|
||||||
|
|
||||||
|
|
||||||
class FnDict:
|
class FnDict:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.fnDict = {}
|
self.fnDict = {}
|
||||||
|
@ -32,16 +31,21 @@ class FnDict:
|
||||||
def __getitem__(self, name):
|
def __getitem__(self, name):
|
||||||
return self.fnDict[name]
|
return self.fnDict[name]
|
||||||
|
|
||||||
|
|
||||||
def test_mul_add_fusion(tag):
|
def test_mul_add_fusion(tag):
|
||||||
fns = FnDict()
|
fns = FnDict()
|
||||||
|
|
||||||
@fns
|
@fns
|
||||||
def before(x, y, z):
|
def before1(x, y, z):
|
||||||
res = mul(x, y)
|
res = mul(x, y)
|
||||||
res = add(res, z)
|
res = add(res, z)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def before2(x, y, z):
|
||||||
|
res = mul(x, y)
|
||||||
|
res = add(z, res)
|
||||||
|
return res
|
||||||
|
|
||||||
@fns
|
@fns
|
||||||
def after(x, y, z):
|
def after(x, y, z):
|
||||||
res = fused_mul_add(x, y, z)
|
res = fused_mul_add(x, y, z)
|
||||||
|
|
Loading…
Reference in New Issue