!198 [opt] momentum duplicate mul constant

Merge pull request !198 from biffex/momentum-duplicate-mul-constant
This commit is contained in:
mindspore-ci-bot 2020-04-10 08:37:41 +08:00 committed by Gitee
commit 1b3b3b1a1c
6 changed files with 112 additions and 4 deletions

View File

@ -45,9 +45,9 @@ namespace mindspore {
namespace opt {
namespace irpass {
OptimizeIRPassLib::OptimizeIRPassLib() {
arithmetic_simplify_ = MakeSubstitution(
ArithmeticSimplify(), "arithmetic_simplify",
{prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, prim::kPrimIdentity, prim::kPrimMomentum});
arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify",
{prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd,
prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul});
special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
{prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType,
prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});

View File

@ -179,6 +179,55 @@ class OptUpdateZeroTensor : public AnfVisitor {
}
};
// {prim::kPrimMul, Tensor1, {prim::kPrimMul, Tensor2, {...}}} ->
// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}}
//
// Hoists two constant tensors that are multiplied at different nesting
// levels into a single constant sub-product, so that downstream constant
// folding can merge Tensor1 * Tensor2 into one tensor and eliminate a
// runtime multiplication.
class ConstantDuplicateMul : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
// Outer pattern: {prim::kPrimMul, Tensor1, {...}}. Visit() records the
// constant-tensor operand in vnode_ and the CNode operand in cnode_, so
// operand order (tensor left or right) does not matter.
AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(node);
if (vnode_ == nullptr || cnode_ == nullptr) {
return nullptr;
}
auto tensor1 = vnode_;
auto mul = cnode_;
Reset();
// Inner pattern: the outer CNode operand must itself be
// {prim::kPrimMul, Tensor2, {...}}.
AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(mul);
if (vnode_ == nullptr || cnode_ == nullptr) {
return nullptr;
}
auto tensor2 = vnode_;
auto cnode = cnode_;
// Reuse the inner Mul's primitive when building the rewritten nodes.
auto PrimMul = GetValueNode<PrimitivePtr>(mul->input(0));
auto fg = node->func_graph();
// ttmul = Mul(Tensor1, Tensor2): the constant sub-product.
auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg);
// Result: Mul({...}, Mul(Tensor1, Tensor2)).
return NewCNode({NewValueNode(PrimMul), cnode, ttmul}, fg);
}
// Called by Match() for each operand: remembers the latest constant-tensor
// ValueNode and the latest CNode seen.
void Visit(const AnfNodePtr &node) override {
if (IsValueNode<tensor::Tensor>(node)) {
vnode_ = node;
}
if (IsCNode(node)) {
cnode_ = node->cast<CNodePtr>();
}
}
// Clears the per-match state before each pattern match.
void Reset() {
vnode_ = nullptr;
cnode_ = nullptr;
}
private:
AnfNodePtr vnode_;
CNodePtr cnode_;
};
class ArithmeticSimplify {
public:
ArithmeticSimplify()
@ -186,12 +235,14 @@ class ArithmeticSimplify {
add_by_zero_(),
tensor_add_by_zero_(),
identity_(prim::kPrimIdentity),
opt_update_zero_tensor_() {
opt_update_zero_tensor_(),
constant_duplicate_mul_() {
eliminaters_.emplace_back(multiply_by_zero_or_one_);
eliminaters_.emplace_back(add_by_zero_);
eliminaters_.emplace_back(tensor_add_by_zero_);
eliminaters_.emplace_back(identity_);
eliminaters_.emplace_back(opt_update_zero_tensor_);
eliminaters_.emplace_back(constant_duplicate_mul_);
}
~ArithmeticSimplify() = default;
@ -212,6 +263,7 @@ class ArithmeticSimplify {
TensorAddByZero tensor_add_by_zero_;
PrimEliminater identity_;
OptUpdateZeroTensor opt_update_zero_tensor_;
ConstantDuplicateMul constant_duplicate_mul_;
std::vector<TransformFuncType> eliminaters_{};
};
} // namespace irpass

View File

@ -400,6 +400,8 @@ static bool SameNodeShallow(const AnfNodePtr& node1, const AnfNodePtr& node2, Fu
auto a2 = GetValueNode(node2);
if (a1->isa<Primitive>() && a2->isa<Primitive>()) {
return a1->cast<PrimitivePtr>()->name() == a2->cast<PrimitivePtr>()->name();
} else if (a1->isa<tensor::Tensor>() && a2->isa<tensor::Tensor>()) {
return a1->cast<tensor::TensorPtr>()->ValueEqual(*(a2->cast<tensor::TensorPtr>()));
} else {
return *a1 == *a2;
}

View File

@ -774,6 +774,14 @@ class Mul(_MathBinaryOp):
>>> mul(input_x, input_y)
[4, 10, 18]
"""
def infer_value(self, x, y):
    """Constant fold Mul: compute x * y when both operand values are known.

    Returns a Tensor holding the element-wise product (kept in x's numpy
    dtype), or None when either operand is unknown at graph-build time.
    """
    if x is None or y is None:
        return None
    lhs = x.asnumpy()
    rhs = y.asnumpy()
    product = np.array(lhs * rhs, lhs.dtype)
    return Tensor(product)
class Square(PrimitiveWithInfer):

View File

@ -543,5 +543,18 @@ TEST_F(TestOptLib, test_print_tuple_wrapper) {
ASSERT_TRUE(CheckOpt(before2, after2, patterns));
ASSERT_TRUE(CheckOpt(before3, before3, patterns));
}
// Verify that arithmetic_simplify_ folds two constant tensors multiplied at
// different nesting levels into one constant sub-product, for every
// left/right placement of the constants in the outer and inner Mul.
TEST_F(TestOptLib, test_constant_duplicate_mul) {
  FuncGraphPtr after = getPyFun.CallAndParseRet("test_constant_duplicate_mul", "after");
  auto patterns = std::vector<SubstitutionPtr>({irpass.arithmetic_simplify_});
  // Each tag builds a "before" graph with a different operand ordering; all
  // must canonicalize to the same "after" graph.
  for (const auto &tag : {"beforell", "beforelr", "beforerl", "beforerr"}) {
    FuncGraphPtr before = getPyFun.CallAndParseRet("test_constant_duplicate_mul", tag);
    ASSERT_TRUE(CheckOpt(before, after, patterns));
  }
}
} // namespace opt
} // namespace mindspore

View File

@ -16,6 +16,8 @@
from mindspore.ops import Primitive, PrimitiveWithInfer
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
from mindspore import Tensor
import numpy as np
# pylint: disable=unused-variable
@ -903,3 +905,34 @@ def test_print_tuple_wrapper(tag):
return print_(make_tuple(x, y, z))
return fns[tag]
def test_constant_duplicate_mul(tag):
    """Build the graph named ``tag`` for the constant-duplicate-mul test.

    The optimizer rewrites ``Mul(c1, Mul(c2, expr))`` into
    ``Mul(expr, Mul(c1, c2))`` so the two constant tensors can be folded.
    The four ``before*`` graphs cover every left/right placement of the
    constant tensor in the outer and inner Mul; ``after`` is the expected
    canonical form.

    Args:
        tag (str): name of the graph function to return.
    """
    fns = FnDict()
    Mul = Primitive('Mul')
    Sqrt = Primitive('Sqrt')

    x = Tensor(np.array([[2, 2], [2, 3]]).astype('float32'))
    tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
    tensor2 = Tensor(np.array([[2.2, 3.1], [3.2, 4.2]]).astype('float32'))

    @fns
    def beforell():
        # constant on the left of both the outer and inner Mul
        return Mul(tensor1, Mul(tensor2, Sqrt(x)))

    @fns
    def beforelr():
        # constant left in the outer Mul, right in the inner Mul
        return Mul(tensor1, Mul(Sqrt(x), tensor2))

    @fns
    def beforerl():
        # constant right in the outer Mul, right in the inner Mul
        return Mul(Mul(Sqrt(x), tensor2), tensor1)

    @fns
    def beforerr():
        # was an exact duplicate of beforerl, leaving one operand placement
        # untested; cover the remaining combination instead
        return Mul(Mul(tensor2, Sqrt(x)), tensor1)

    @fns
    def after():
        return Mul(Sqrt(x), Mul(tensor1, tensor2))

    return fns[tag]