From 62bbf560c66e7dee5a8c555feb8aa3fa13d118ca Mon Sep 17 00:00:00 2001
From: biffex
Date: Thu, 9 Apr 2020 15:04:24 +0800
Subject: [PATCH] constant duplicate mul for momentum

---
 mindspore/ccsrc/optimizer/irpass.cc        |  6 +--
 .../optimizer/irpass/arithmetic_simplify.h | 54 ++++++++++++++++++-
 mindspore/ccsrc/utils/graph_utils.cc       |  2 +
 mindspore/ops/operations/math_ops.py       |  8 +++
 tests/ut/cpp/optimizer/lib_test.cc         | 13 +++++
 .../gtest_input/optimizer/opt_test.py      | 33 ++++++++++++
 6 files changed, 112 insertions(+), 4 deletions(-)

diff --git a/mindspore/ccsrc/optimizer/irpass.cc b/mindspore/ccsrc/optimizer/irpass.cc
index cdc960792ff..0991c31b00a 100644
--- a/mindspore/ccsrc/optimizer/irpass.cc
+++ b/mindspore/ccsrc/optimizer/irpass.cc
@@ -45,9 +45,9 @@ namespace mindspore {
 namespace opt {
 namespace irpass {
 OptimizeIRPassLib::OptimizeIRPassLib() {
-  arithmetic_simplify_ = MakeSubstitution(
-    ArithmeticSimplify(), "arithmetic_simplify",
-    {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, prim::kPrimIdentity, prim::kPrimMomentum});
+  arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify",
+                                          {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd,
+                                           prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul});
   special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
                                            {prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType, prim::kPrimGetRefKey,
                                             prim::kPrimMirror, prim::kPrimVirtualDiv});
diff --git a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h
index 8c5610ed1b3..ab191aab205 100644
--- a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h
+++ b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h
@@ -179,6 +179,55 @@ class OptUpdateZeroTensor : public AnfVisitor {
   }
 };
 
+// {prim::kPrimMul, Tensor1, {prim::kPrimMul, Tensor2, {...}}} ->
+// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}}
+class ConstantDuplicateMul : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    Reset();
+    // {prim::kPrimMul, Tensor1, {...}}
+    AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(node);
+    if (vnode_ == nullptr || cnode_ == nullptr) {
+      return nullptr;
+    }
+    auto tensor1 = vnode_;
+    auto mul = cnode_;
+
+    Reset();
+    // {prim::kPrimMul, Tensor2, {...}}
+    AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(mul);
+    if (vnode_ == nullptr || cnode_ == nullptr) {
+      return nullptr;
+    }
+    auto tensor2 = vnode_;
+    auto cnode = cnode_;
+
+    auto PrimMul = GetValueNode(mul->input(0));
+    auto fg = node->func_graph();
+    auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg);
+    return NewCNode({NewValueNode(PrimMul), cnode, ttmul}, fg);
+  }
+
+  void Visit(const AnfNodePtr &node) override {
+    if (IsValueNode<tensor::Tensor>(node)) {
+      vnode_ = node;
+    }
+
+    if (IsCNode(node)) {
+      cnode_ = node->cast<CNodePtr>();
+    }
+  }
+
+  void Reset() {
+    vnode_ = nullptr;
+    cnode_ = nullptr;
+  }
+
+ private:
+  AnfNodePtr vnode_;
+  CNodePtr cnode_;
+};
+
 class ArithmeticSimplify {
  public:
   ArithmeticSimplify()
@@ -186,12 +235,14 @@ class ArithmeticSimplify {
         add_by_zero_(),
         tensor_add_by_zero_(),
         identity_(prim::kPrimIdentity),
-        opt_update_zero_tensor_() {
+        opt_update_zero_tensor_(),
+        constant_duplicate_mul_() {
     eliminaters_.emplace_back(multiply_by_zero_or_one_);
     eliminaters_.emplace_back(add_by_zero_);
     eliminaters_.emplace_back(tensor_add_by_zero_);
     eliminaters_.emplace_back(identity_);
     eliminaters_.emplace_back(opt_update_zero_tensor_);
+    eliminaters_.emplace_back(constant_duplicate_mul_);
   }
   ~ArithmeticSimplify() = default;
 
@@ -212,6 +263,7 @@ class ArithmeticSimplify {
   TensorAddByZero tensor_add_by_zero_;
   PrimEliminater identity_;
   OptUpdateZeroTensor opt_update_zero_tensor_;
+  ConstantDuplicateMul constant_duplicate_mul_;
   std::vector<TransformFuncType> eliminaters_{};
 };
 }  // namespace irpass
diff --git a/mindspore/ccsrc/utils/graph_utils.cc b/mindspore/ccsrc/utils/graph_utils.cc
index 938df2c291c..55ef8dc3d5a 100644
--- a/mindspore/ccsrc/utils/graph_utils.cc
+++ b/mindspore/ccsrc/utils/graph_utils.cc
@@ -400,6 +400,8 @@ static bool SameNodeShallow(const AnfNodePtr& node1, const AnfNodePtr& node2, Fu
     auto a2 = GetValueNode(node2);
     if (a1->isa<Primitive>() && a2->isa<Primitive>()) {
       return a1->cast<PrimitivePtr>()->name() == a2->cast<PrimitivePtr>()->name();
+    } else if (a1->isa<tensor::Tensor>() && a2->isa<tensor::Tensor>()) {
+      return a1->cast<tensor::TensorPtr>()->ValueEqual(*(a2->cast<tensor::TensorPtr>()));
     } else {
       return *a1 == *a2;
     }
diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py
index d003f6ee8b3..e5e89615dff 100644
--- a/mindspore/ops/operations/math_ops.py
+++ b/mindspore/ops/operations/math_ops.py
@@ -771,6 +771,14 @@ class Mul(_MathBinaryOp):
         >>> mul(input_x, input_y)
         [4, 10, 18]
     """
+    def infer_value(self, x, y):
+        if x is not None and y is not None:
+            x = x.asnumpy()
+            y = y.asnumpy()
+            out = x * y
+            out = np.array(out, x.dtype)
+            return Tensor(out)
+        return None
 
 
 class Square(PrimitiveWithInfer):
diff --git a/tests/ut/cpp/optimizer/lib_test.cc b/tests/ut/cpp/optimizer/lib_test.cc
index ff3c00d37a2..2d4cf0e78ea 100644
--- a/tests/ut/cpp/optimizer/lib_test.cc
+++ b/tests/ut/cpp/optimizer/lib_test.cc
@@ -543,5 +543,18 @@ TEST_F(TestOptLib, test_print_tuple_wrapper) {
   ASSERT_TRUE(CheckOpt(before2, after2, patterns));
   ASSERT_TRUE(CheckOpt(before3, before3, patterns));
 }
+
+TEST_F(TestOptLib, test_constant_duplicate_mul) {
+  FuncGraphPtr beforell = getPyFun.CallAndParseRet("test_constant_duplicate_mul", "beforell");
+  FuncGraphPtr beforelr = getPyFun.CallAndParseRet("test_constant_duplicate_mul", "beforelr");
+  FuncGraphPtr beforerl = getPyFun.CallAndParseRet("test_constant_duplicate_mul", "beforerl");
+  FuncGraphPtr beforerr = getPyFun.CallAndParseRet("test_constant_duplicate_mul", "beforerr");
+  FuncGraphPtr after = getPyFun.CallAndParseRet("test_constant_duplicate_mul", "after");
+  auto patterns = std::vector<SubstitutionPtr>({irpass.arithmetic_simplify_});
+  ASSERT_TRUE(CheckOpt(beforell, after, patterns));
+  ASSERT_TRUE(CheckOpt(beforelr, after, patterns));
+  ASSERT_TRUE(CheckOpt(beforerl, after, patterns));
+  ASSERT_TRUE(CheckOpt(beforerr, after, patterns));
+}
 }  // namespace opt
 }  // namespace mindspore
diff --git a/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py b/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py
index 53eb2130f00..d494ad27d38 100644
--- a/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py
+++ b/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py
@@ -16,6 +16,8 @@
 from mindspore.ops import Primitive, PrimitiveWithInfer
 from mindspore.ops import operations as P
 from mindspore.ops.operations import _grad_ops as G
+from mindspore import Tensor
+import numpy as np
 
 # pylint: disable=unused-variable
 
@@ -903,3 +905,34 @@ def test_print_tuple_wrapper(tag):
         return print_(make_tuple(x, y, z))
 
     return fns[tag]
+
+def test_constant_duplicate_mul(tag):
+    fns = FnDict()
+    Mul = Primitive('Mul')
+    Sqrt = Primitive('Sqrt')
+
+    x = Tensor(np.array([[2, 2], [2, 3]]).astype('float32'))
+    tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
+    tensor2 = Tensor(np.array([[2.2, 3.1], [3.2, 4.2]]).astype('float32'))
+
+    @fns
+    def beforell():
+        return Mul(tensor1, Mul(tensor2, Sqrt(x)))
+
+    @fns
+    def beforelr():
+        return Mul(tensor1, Mul(Sqrt(x), tensor2))
+
+    @fns
+    def beforerl():
+        return Mul(Mul(tensor2, Sqrt(x)), tensor1)
+
+    @fns
+    def beforerr():
+        return Mul(Mul(Sqrt(x), tensor2), tensor1)
+
+    @fns
+    def after():
+        return Mul(Sqrt(x), Mul(tensor1, tensor2))
+
+    return fns[tag]
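
Note (editorial, not part of the patch): the new ConstantDuplicateMul pass rewrites
Mul(const1, Mul(const2, y)) into Mul(y, Mul(const1, const2)), so the two constant tensors
end up multiplied together; with the Mul.infer_value added above, that constant-only inner
Mul can be folded at compile time, leaving a single runtime multiplication. A minimal
numpy sketch of the numerical equivalence (illustrative only; names are made up here):

    import numpy as np

    c1 = np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32)  # constant tensor1
    c2 = np.array([[2.2, 3.1], [3.2, 4.2]], dtype=np.float32)  # constant tensor2
    y = np.sqrt(np.array([[2, 2], [2, 3]], dtype=np.float32))  # non-constant operand

    before = c1 * (c2 * y)   # original graph: two multiplications touching constants
    folded = c1 * c2         # what infer_value can precompute into one constant
    after = y * folded       # rewritten graph: one runtime multiplication
    assert np.allclose(before, after)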