forked from mindspore-Ecosystem/mindspore
!198 [opt] momentum duplicate mul constant
Merge pull request !198 from biffex/momentum-duplicate-mul-constant
This commit is contained in:
commit
1b3b3b1a1c
|
@ -45,9 +45,9 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
namespace irpass {
|
||||
OptimizeIRPassLib::OptimizeIRPassLib() {
|
||||
arithmetic_simplify_ = MakeSubstitution(
|
||||
ArithmeticSimplify(), "arithmetic_simplify",
|
||||
{prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, prim::kPrimIdentity, prim::kPrimMomentum});
|
||||
arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify",
|
||||
{prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd,
|
||||
prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul});
|
||||
special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
|
||||
{prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType,
|
||||
prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
|
||||
|
|
|
@ -179,6 +179,55 @@ class OptUpdateZeroTensor : public AnfVisitor {
|
|||
}
|
||||
};
|
||||
|
||||
// {prim::kPrimMul, Tensor1, {orim::kPrimMul, Tensor2, {...}}} ->
|
||||
// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}}
|
||||
class ConstantDuplicateMul : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
Reset();
|
||||
// {prim::kPrimMul, Tensor1, {...}}
|
||||
AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(node);
|
||||
if (vnode_ == nullptr || cnode_ == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto tensor1 = vnode_;
|
||||
auto mul = cnode_;
|
||||
|
||||
Reset();
|
||||
// {prim::kPrimMul, Tensor2, {...}}
|
||||
AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(mul);
|
||||
if (vnode_ == nullptr || cnode_ == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto tensor2 = vnode_;
|
||||
auto cnode = cnode_;
|
||||
|
||||
auto PrimMul = GetValueNode<PrimitivePtr>(mul->input(0));
|
||||
auto fg = node->func_graph();
|
||||
auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg);
|
||||
return NewCNode({NewValueNode(PrimMul), cnode, ttmul}, fg);
|
||||
}
|
||||
|
||||
void Visit(const AnfNodePtr &node) override {
|
||||
if (IsValueNode<tensor::Tensor>(node)) {
|
||||
vnode_ = node;
|
||||
}
|
||||
|
||||
if (IsCNode(node)) {
|
||||
cnode_ = node->cast<CNodePtr>();
|
||||
}
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
vnode_ = nullptr;
|
||||
cnode_ = nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
AnfNodePtr vnode_;
|
||||
CNodePtr cnode_;
|
||||
};
|
||||
|
||||
class ArithmeticSimplify {
|
||||
public:
|
||||
ArithmeticSimplify()
|
||||
|
@ -186,12 +235,14 @@ class ArithmeticSimplify {
|
|||
add_by_zero_(),
|
||||
tensor_add_by_zero_(),
|
||||
identity_(prim::kPrimIdentity),
|
||||
opt_update_zero_tensor_() {
|
||||
opt_update_zero_tensor_(),
|
||||
constant_duplicate_mul_() {
|
||||
eliminaters_.emplace_back(multiply_by_zero_or_one_);
|
||||
eliminaters_.emplace_back(add_by_zero_);
|
||||
eliminaters_.emplace_back(tensor_add_by_zero_);
|
||||
eliminaters_.emplace_back(identity_);
|
||||
eliminaters_.emplace_back(opt_update_zero_tensor_);
|
||||
eliminaters_.emplace_back(constant_duplicate_mul_);
|
||||
}
|
||||
~ArithmeticSimplify() = default;
|
||||
|
||||
|
@ -212,6 +263,7 @@ class ArithmeticSimplify {
|
|||
TensorAddByZero tensor_add_by_zero_;
|
||||
PrimEliminater identity_;
|
||||
OptUpdateZeroTensor opt_update_zero_tensor_;
|
||||
ConstantDuplicateMul constant_duplicate_mul_;
|
||||
std::vector<TransformFuncType> eliminaters_{};
|
||||
};
|
||||
} // namespace irpass
|
||||
|
|
|
@ -400,6 +400,8 @@ static bool SameNodeShallow(const AnfNodePtr& node1, const AnfNodePtr& node2, Fu
|
|||
auto a2 = GetValueNode(node2);
|
||||
if (a1->isa<Primitive>() && a2->isa<Primitive>()) {
|
||||
return a1->cast<PrimitivePtr>()->name() == a2->cast<PrimitivePtr>()->name();
|
||||
} else if (a1->isa<tensor::Tensor>() && a2->isa<tensor::Tensor>()) {
|
||||
return a1->cast<tensor::TensorPtr>()->ValueEqual(*(a2->cast<tensor::TensorPtr>()));
|
||||
} else {
|
||||
return *a1 == *a2;
|
||||
}
|
||||
|
|
|
@ -774,6 +774,14 @@ class Mul(_MathBinaryOp):
|
|||
>>> mul(input_x, input_y)
|
||||
[4, 10, 18]
|
||||
"""
|
||||
def infer_value(self, x, y):
    """Constant-fold Mul at infer time.

    Returns a Tensor holding ``x * y`` (cast to x's numpy dtype) when both
    inputs carry concrete values, otherwise None so inference falls through.
    """
    if x is None or y is None:
        return None
    x_np = x.asnumpy()
    y_np = y.asnumpy()
    # Keep the dtype of the first operand, matching the original behavior.
    product = np.array(x_np * y_np, x_np.dtype)
    return Tensor(product)
|
||||
|
||||
|
||||
class Square(PrimitiveWithInfer):
|
||||
|
|
|
@ -543,5 +543,18 @@ TEST_F(TestOptLib, test_print_tuple_wrapper) {
|
|||
ASSERT_TRUE(CheckOpt(before2, after2, patterns));
|
||||
ASSERT_TRUE(CheckOpt(before3, before3, patterns));
|
||||
}
|
||||
|
||||
TEST_F(TestOptLib, test_constant_duplicate_mul) {
  // Every operand order of Mul(const, Mul(const, x)) — tensor left/right in
  // the outer and inner mul — must canonicalize to the same graph.
  FuncGraphPtr after = getPyFun.CallAndParseRet("test_constant_duplicate_mul", "after");
  auto patterns = std::vector<SubstitutionPtr>({irpass.arithmetic_simplify_});
  const std::vector<std::string> before_tags = {"beforell", "beforelr", "beforerl", "beforerr"};
  for (const auto &tag : before_tags) {
    FuncGraphPtr before = getPyFun.CallAndParseRet("test_constant_duplicate_mul", tag);
    ASSERT_TRUE(CheckOpt(before, after, patterns));
  }
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
from mindspore.ops import Primitive, PrimitiveWithInfer
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations import _grad_ops as G
|
||||
from mindspore import Tensor
|
||||
import numpy as np
|
||||
|
||||
# pylint: disable=unused-variable
|
||||
|
||||
|
@ -903,3 +905,34 @@ def test_print_tuple_wrapper(tag):
|
|||
return print_(make_tuple(x, y, z))
|
||||
|
||||
return fns[tag]
|
||||
|
||||
def test_constant_duplicate_mul(tag):
    """Build graphs for the constant-duplicate-mul simplification.

    The pass rewrites Mul(const, Mul(const, x)) into Mul(x, Mul(const, const)).
    The four 'before' graphs cover operand order: the first letter is the
    position of the constant tensor in the outer mul (l/r), the second letter
    its position in the inner mul.

    Fix: 'beforerl' previously duplicated 'beforerr' (both used
    Mul(Mul(Sqrt(x), tensor2), tensor1)), so the right-outer/left-inner
    order was never exercised.
    """
    fns = FnDict()
    Mul = Primitive('Mul')
    Sqrt = Primitive('Sqrt')

    x = Tensor(np.array([[2, 2], [2, 3]]).astype('float32'))
    tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
    tensor2 = Tensor(np.array([[2.2, 3.1], [3.2, 4.2]]).astype('float32'))

    @fns
    def beforell():
        return Mul(tensor1, Mul(tensor2, Sqrt(x)))

    @fns
    def beforelr():
        return Mul(tensor1, Mul(Sqrt(x), tensor2))

    @fns
    def beforerl():
        # outer: tensor on the right; inner: tensor on the left
        return Mul(Mul(tensor2, Sqrt(x)), tensor1)

    @fns
    def beforerr():
        # outer: tensor on the right; inner: tensor on the right
        return Mul(Mul(Sqrt(x), tensor2), tensor1)

    @fns
    def after():
        return Mul(Sqrt(x), Mul(tensor1, tensor2))

    return fns[tag]
|
||||
|
|
Loading…
Reference in New Issue