forked from mindspore-Ecosystem/mindspore
add pattern AdjustAllReduceMulAdd
This commit is contained in:
parent
c9fba7f091
commit
ea6958c50a
|
@ -226,6 +226,7 @@ const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not");
|
|||
const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
|
||||
const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
|
||||
const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
|
||||
const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
|
||||
|
||||
// Debug ops
|
||||
const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary");
|
||||
|
|
|
@ -231,6 +231,7 @@ extern const PrimitivePtr kPrimMinimumGrad;
|
|||
extern const PrimitivePtr kPrimMaximumGrad;
|
||||
|
||||
// Comm ops
|
||||
extern const PrimitivePtr kPrimAllReduce;
|
||||
extern const PrimitivePtr kPrimMirror;
|
||||
extern const PrimitivePtr kPrimVirtualDiv;
|
||||
extern const PrimitivePtr kPrimVirtualDataset;
|
||||
|
|
|
@ -48,7 +48,7 @@ namespace irpass {
|
|||
OptimizeIRPassLib::OptimizeIRPassLib() {
|
||||
arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify",
|
||||
{prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd,
|
||||
prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul});
|
||||
prim::kPrimAddN, prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul});
|
||||
special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
|
||||
{prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType,
|
||||
prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
|
||||
|
|
|
@ -228,6 +228,82 @@ class ConstantDuplicateMul : public AnfVisitor {
|
|||
CNodePtr cnode_;
|
||||
};
|
||||
|
||||
// grad = AllReduce(grad) / worker_number
|
||||
// grad = grad + weight * decy
|
||||
// ->
|
||||
// grad = grad + weight * decy
|
||||
// grad = AllReduce(grad) / worker_number
|
||||
|
||||
// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} ->
|
||||
// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y}
|
||||
class AdjustAllReduceMulAdd : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
Reset();
|
||||
// {prim::kPrimAddN, Zs}
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimAddN)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto addn = node->cast<CNodePtr>();
|
||||
if (addn->size() != 2) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1));
|
||||
if (x_ == nullptr || y_ == nullptr || z_ == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto fg = node->func_graph();
|
||||
AnfNodePtr tuple = NewCNode({NewValueNode(prim::kPrimMakeTuple), z_, x_}, fg);
|
||||
AnfNodePtr add = NewCNode({NewValueNode(prim::kPrimAddN), tuple}, fg);
|
||||
AnfNodePtr all_reduce = NewCNode({NewValueNode(prim::kPrimAllReduce), add}, fg);
|
||||
return NewCNode({NewValueNode(prim::kPrimMul), all_reduce, y_}, fg);
|
||||
}
|
||||
|
||||
void Visit(const AnfNodePtr &node) override {
|
||||
if (level_ == 0) {
|
||||
level_ = 1;
|
||||
is_reduce_match_ = false;
|
||||
// {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}
|
||||
AnfVisitor::Match(prim::kPrimMul)(node);
|
||||
level_ = 0;
|
||||
if (is_reduce_match_) {
|
||||
y_ = tmp_;
|
||||
} else {
|
||||
z_ = node;
|
||||
}
|
||||
}
|
||||
|
||||
if (level_ == 1) {
|
||||
// {prim::kPrimAllReduce, X}
|
||||
if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode->size() > 1) {
|
||||
x_ = cnode->input(1);
|
||||
is_reduce_match_ = true;
|
||||
}
|
||||
} else {
|
||||
tmp_ = node;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
level_ = 0;
|
||||
is_reduce_match_ = false;
|
||||
x_ = nullptr;
|
||||
y_ = nullptr;
|
||||
z_ = nullptr;
|
||||
tmp_ = nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
int level_{0};
|
||||
bool is_reduce_match_{false};
|
||||
AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr};
|
||||
};
|
||||
|
||||
class ArithmeticSimplify {
|
||||
public:
|
||||
ArithmeticSimplify()
|
||||
|
@ -243,6 +319,7 @@ class ArithmeticSimplify {
|
|||
eliminaters_.emplace_back(identity_);
|
||||
eliminaters_.emplace_back(opt_update_zero_tensor_);
|
||||
eliminaters_.emplace_back(constant_duplicate_mul_);
|
||||
eliminaters_.emplace_back(adjust_allreduce_mul_add_);
|
||||
}
|
||||
~ArithmeticSimplify() = default;
|
||||
|
||||
|
@ -264,6 +341,7 @@ class ArithmeticSimplify {
|
|||
PrimEliminater identity_;
|
||||
OptUpdateZeroTensor opt_update_zero_tensor_;
|
||||
ConstantDuplicateMul constant_duplicate_mul_;
|
||||
AdjustAllReduceMulAdd adjust_allreduce_mul_add_;
|
||||
std::vector<TransformFuncType> eliminaters_{};
|
||||
};
|
||||
} // namespace irpass
|
||||
|
|
|
@ -1235,7 +1235,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
|
|||
Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor([1, 2, 3, 4], mindspore.float)
|
||||
>>> input_x = Tensor([1, 2, 3, 4], mindspore.float32)
|
||||
>>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32)
|
||||
>>> num_segments = 4
|
||||
>>> P.UnsortedSegmentSum()(input_x, segment_ids, num_segments)
|
||||
|
|
|
@ -1572,7 +1572,7 @@ class LayerNorm(Primitive):
|
|||
`Layer Normalization <https://arxiv.org/abs/1607.06450>`_.
|
||||
|
||||
.. math::
|
||||
y = \frac{x - mean]}{\sqrt{variance + \epsilon}} * \gamma + \beta
|
||||
y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
|
||||
|
||||
where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.
|
||||
|
||||
|
|
|
@ -556,5 +556,24 @@ TEST_F(TestOptLib, test_constant_duplicate_mul) {
|
|||
ASSERT_TRUE(CheckOpt(beforerl, after, patterns));
|
||||
ASSERT_TRUE(CheckOpt(beforerr, after, patterns));
|
||||
}
|
||||
|
||||
TEST_F(TestOptLib, test_adjust_allreduce_mul_add) {
|
||||
FuncGraphPtr beforell = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforell");
|
||||
FuncGraphPtr beforelr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforelr");
|
||||
FuncGraphPtr beforerl = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerl");
|
||||
FuncGraphPtr beforerr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerr");
|
||||
FuncGraphPtr after1 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after1");
|
||||
FuncGraphPtr before2r = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2r");
|
||||
FuncGraphPtr before2l = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2l");
|
||||
FuncGraphPtr after2 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after2");
|
||||
auto patterns = std::vector<SubstitutionPtr>({irpass.arithmetic_simplify_});
|
||||
ASSERT_TRUE(CheckOpt(beforell, after1, patterns));
|
||||
ASSERT_TRUE(CheckOpt(beforelr, after1, patterns));
|
||||
ASSERT_TRUE(CheckOpt(beforerl, after1, patterns));
|
||||
ASSERT_TRUE(CheckOpt(beforerr, after1, patterns));
|
||||
ASSERT_TRUE(CheckOpt(before2l, after2, patterns));
|
||||
ASSERT_TRUE(CheckOpt(before2r, after2, patterns));
|
||||
}
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -908,8 +908,8 @@ def test_print_tuple_wrapper(tag):
|
|||
|
||||
def test_constant_duplicate_mul(tag):
|
||||
fns = FnDict()
|
||||
Mul = Primitive('Mul');
|
||||
Sqrt = Primitive('Sqrt');
|
||||
Mul = Primitive('Mul')
|
||||
Sqrt = Primitive('Sqrt')
|
||||
|
||||
x = Tensor(np.array([[2, 2], [2, 3]]).astype('float32'))
|
||||
tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
|
||||
|
@ -936,3 +936,44 @@ def test_constant_duplicate_mul(tag):
|
|||
return Mul(Sqrt(x), Mul(tensor1, tensor2))
|
||||
|
||||
return fns[tag]
|
||||
|
||||
|
||||
def test_adjust_allreduce_mul_add(tag):
|
||||
fns = FnDict()
|
||||
Mul = Primitive('Mul')
|
||||
AddN = Primitive('AddN')
|
||||
AllReduce = Primitive('AllReduce')
|
||||
|
||||
@fns
|
||||
def beforell(x, y, z):
|
||||
return AddN((z, Mul(y, AllReduce(x))))
|
||||
|
||||
@fns
|
||||
def beforelr(x, y, z):
|
||||
return AddN((z, Mul(AllReduce(x), y)))
|
||||
|
||||
@fns
|
||||
def beforerl(x, y, z):
|
||||
return AddN((Mul(y, AllReduce(x)), z))
|
||||
|
||||
@fns
|
||||
def beforerr(x, y, z):
|
||||
return AddN((Mul(AllReduce(x), y), z))
|
||||
|
||||
@fns
|
||||
def after1(x, y, z):
|
||||
return Mul(AllReduce(AddN((z, x))), y)
|
||||
|
||||
@fns
|
||||
def before2r(x, y, z):
|
||||
return AddN((Mul(AllReduce(x), y), Mul(z, z)))
|
||||
|
||||
@fns
|
||||
def before2l(x, y, z):
|
||||
return AddN((Mul(z, z), Mul(AllReduce(x), y)))
|
||||
|
||||
@fns
|
||||
def after2(x, y, z):
|
||||
return Mul(AllReduce(AddN((Mul(z, z), x))), y)
|
||||
|
||||
return fns[tag]
|
||||
|
|
Loading…
Reference in New Issue