forked from mindspore-Ecosystem/mindspore
!246 [opt pass] AdjustAllReduceMulAdd
Merge pull request !246 from vlne-v1/I1E3PI-opt-pass-adjust-allreduce-apply-weight-decy-seq
This commit is contained in:
commit 71b63c3fcf
@@ -230,6 +230,7 @@ const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict");
 const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
 const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
 const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
+const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
 
 // Debug ops
 const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary");
@@ -233,6 +233,7 @@ extern const PrimitivePtr kPrimInDict;
 extern const PrimitivePtr kPrimNotInDict;
 
 // Comm ops
+extern const PrimitivePtr kPrimAllReduce;
 extern const PrimitivePtr kPrimMirror;
 extern const PrimitivePtr kPrimVirtualDiv;
 extern const PrimitivePtr kPrimVirtualDataset;
@@ -48,7 +48,7 @@ namespace irpass {
 OptimizeIRPassLib::OptimizeIRPassLib() {
   arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify",
                                           {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd,
-                                           prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul});
+                                           prim::kPrimAddN, prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul});
   special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
                                            {prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType,
                                             prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
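Reviewer note: MakeSubstitution's third argument is the set of root primitives that can trigger the pass, and the new rewrite matches on an AddN root, so prim::kPrimAddN has to join the list or the pattern would never be tried. A rough Python sketch of that dispatch, under the assumption that substitutions are gated on the root node's primitive (names here are illustrative, not the real optimizer API):

    def make_substitution(transform, name, trigger_prims):
        # Gate the (possibly expensive) pattern match on a cheap primitive check.
        def substitution(node):
            if node.prim in trigger_prims:
                return transform(node)  # returns a rewritten node or None
            return None
        return substitution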
@@ -228,6 +228,82 @@ class ConstantDuplicateMul : public AnfVisitor {
   CNodePtr cnode_;
 };
 
+// grad = AllReduce(grad) / worker_number
+// grad = grad + weight * decay
+// ->
+// grad = grad + weight * decay
+// grad = AllReduce(grad) / worker_number
+//
+// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} ->
+// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN, {prim::kPrimMakeTuple, Z, X}}}, Y}
+class AdjustAllReduceMulAdd : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    Reset();
+    // {prim::kPrimAddN, Zs}
+    if (!IsPrimitiveCNode(node, prim::kPrimAddN)) {
+      return nullptr;
+    }
+    auto addn = node->cast<CNodePtr>();
+    if (addn->size() != 2) {
+      return nullptr;
+    }
+
+    AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1));
+    if (x_ == nullptr || y_ == nullptr || z_ == nullptr) {
+      return nullptr;
+    }
+
+    auto fg = node->func_graph();
+    AnfNodePtr tuple = NewCNode({NewValueNode(prim::kPrimMakeTuple), z_, x_}, fg);
+    AnfNodePtr add = NewCNode({NewValueNode(prim::kPrimAddN), tuple}, fg);
+    AnfNodePtr all_reduce = NewCNode({NewValueNode(prim::kPrimAllReduce), add}, fg);
+    return NewCNode({NewValueNode(prim::kPrimMul), all_reduce, y_}, fg);
+  }
+
+  void Visit(const AnfNodePtr &node) override {
+    if (level_ == 0) {
+      level_ = 1;
+      is_reduce_match_ = false;
+      // {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}
+      AnfVisitor::Match(prim::kPrimMul)(node);
+      level_ = 0;
+      if (is_reduce_match_) {
+        y_ = tmp_;
+      } else {
+        z_ = node;
+      }
+    }
+
+    if (level_ == 1) {
+      // {prim::kPrimAllReduce, X}
+      if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) {
+        auto cnode = node->cast<CNodePtr>();
+        if (cnode->size() > 1) {
+          x_ = cnode->input(1);
+          is_reduce_match_ = true;
+        }
+      } else {
+        tmp_ = node;
+      }
+    }
+  }
+
+  void Reset() {
+    level_ = 0;
+    is_reduce_match_ = false;
+    x_ = nullptr;
+    y_ = nullptr;
+    z_ = nullptr;
+    tmp_ = nullptr;
+  }
+
+ private:
+  int level_{0};
+  bool is_reduce_match_{false};
+  AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr};
+};
+
 class ArithmeticSimplify {
  public:
   ArithmeticSimplify()
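Why the reorder is safe: with a sum-type AllReduce over n workers, Y = 1/worker_number, and a weight-decay term Z that every worker computes identically, AllReduce(Z) = n * Z, so folding Z into the reduce input changes nothing: AllReduce(X) * Y + Z == AllReduce(X + Z) * Y. A minimal NumPy check of that identity (a sketch with simulated workers, not the MindSpore runtime):

    import numpy as np

    n = 4                                              # assumed worker count
    grads = [np.random.rand(2, 2) for _ in range(n)]   # per-worker gradients X
    z = np.full((2, 2), 0.05)                          # weight * decay, replicated on all workers
    y = 1.0 / n                                        # averaging factor Y

    def all_reduce(tensors):
        # Sum-type AllReduce: every worker ends up holding the global sum.
        return sum(tensors)

    before = all_reduce(grads) * y + z                 # decay added after averaging
    after = all_reduce([g + z for g in grads]) * y     # decay folded into the reduce
    assert np.allclose(before, after)

The payoff is sequencing: the decay add no longer waits on the AllReduce result, so it can overlap with communication.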
@@ -243,6 +319,7 @@ class ArithmeticSimplify {
     eliminaters_.emplace_back(identity_);
     eliminaters_.emplace_back(opt_update_zero_tensor_);
     eliminaters_.emplace_back(constant_duplicate_mul_);
+    eliminaters_.emplace_back(adjust_allreduce_mul_add_);
   }
   ~ArithmeticSimplify() = default;
@@ -264,6 +341,7 @@ class ArithmeticSimplify {
   PrimEliminater identity_;
   OptUpdateZeroTensor opt_update_zero_tensor_;
   ConstantDuplicateMul constant_duplicate_mul_;
+  AdjustAllReduceMulAdd adjust_allreduce_mul_add_;
   std::vector<TransformFuncType> eliminaters_{};
 };
 }  // namespace irpass
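For context, the constructor registers the new eliminater last, which matters if the eliminaters_ vector is tried in order with the first successful rewrite winning. A hedged sketch of that control flow (illustrative, not the actual ArithmeticSimplify::operator()):

    def arithmetic_simplify(eliminaters, node):
        # Try each registered transform in registration order;
        # the first one that produces a rewrite wins.
        for transform in eliminaters:
            new_node = transform(node)
            if new_node is not None:
                return new_node
        return None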
@@ -1235,7 +1235,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
         Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.
 
     Examples:
-        >>> input_x = Tensor([1, 2, 3, 4], mindspore.float)
+        >>> input_x = Tensor([1, 2, 3, 4], mindspore.float32)
        >>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32)
        >>> num_segments = 4
        >>> P.UnsortedSegmentSum()(input_x, segment_ids, num_segments)
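The docstring fix swaps the nonexistent mindspore.float for mindspore.float32. For readers unfamiliar with the op, a NumPy equivalent of the example (a sketch; the real Primitive also validates shapes and dtypes):

    import numpy as np

    def unsorted_segment_sum(x, segment_ids, num_segments):
        # Scatter-add each input value into the output slot named by its id;
        # ids need not be sorted, and untouched segments stay zero.
        out = np.zeros(num_segments, dtype=x.dtype)
        np.add.at(out, segment_ids, x)
        return out

    print(unsorted_segment_sum(np.array([1., 2., 3., 4.], dtype=np.float32),
                               np.array([0, 0, 1, 2]), 4))
    # [3. 3. 4. 0.]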
@@ -1622,7 +1622,7 @@ class LayerNorm(Primitive):
     `Layer Normalization <https://arxiv.org/abs/1607.06450>`_.
 
     .. math::
-        y = \frac{x - mean]}{\sqrt{variance + \epsilon}} * \gamma + \beta
+        y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
 
     where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.
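With the stray bracket removed, the formula reads y = (x - mean) / sqrt(variance + epsilon) * gamma + beta. A NumPy sketch of that computation, normalizing over the last axis (the Primitive's begin_norm_axis/begin_params_axis handling is omitted here):

    import numpy as np

    def layer_norm(x, gamma, beta, eps=1e-7):
        mean = x.mean(axis=-1, keepdims=True)        # mean over the normalized axis
        variance = x.var(axis=-1, keepdims=True)     # variance over the same axis
        return (x - mean) / np.sqrt(variance + eps) * gamma + beta

    x = np.array([[1., 2., 3.], [4., 5., 6.]], dtype=np.float32)
    out = layer_norm(x, np.ones(3, np.float32), np.zeros(3, np.float32))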
@@ -556,5 +556,24 @@ TEST_F(TestOptLib, test_constant_duplicate_mul) {
   ASSERT_TRUE(CheckOpt(beforerl, after, patterns));
   ASSERT_TRUE(CheckOpt(beforerr, after, patterns));
 }
+
+TEST_F(TestOptLib, test_adjust_allreduce_mul_add) {
+  FuncGraphPtr beforell = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforell");
+  FuncGraphPtr beforelr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforelr");
+  FuncGraphPtr beforerl = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerl");
+  FuncGraphPtr beforerr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerr");
+  FuncGraphPtr after1 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after1");
+  FuncGraphPtr before2r = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2r");
+  FuncGraphPtr before2l = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2l");
+  FuncGraphPtr after2 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after2");
+  auto patterns = std::vector<SubstitutionPtr>({irpass.arithmetic_simplify_});
+  ASSERT_TRUE(CheckOpt(beforell, after1, patterns));
+  ASSERT_TRUE(CheckOpt(beforelr, after1, patterns));
+  ASSERT_TRUE(CheckOpt(beforerl, after1, patterns));
+  ASSERT_TRUE(CheckOpt(beforerr, after1, patterns));
+  ASSERT_TRUE(CheckOpt(before2l, after2, patterns));
+  ASSERT_TRUE(CheckOpt(before2r, after2, patterns));
+}
+
 }  // namespace opt
 }  // namespace mindspore
@@ -908,8 +908,8 @@ def test_print_tuple_wrapper(tag):
 
 def test_constant_duplicate_mul(tag):
     fns = FnDict()
-    Mul = Primitive('Mul');
-    Sqrt = Primitive('Sqrt');
+    Mul = Primitive('Mul')
+    Sqrt = Primitive('Sqrt')
 
     x = Tensor(np.array([[2, 2], [2, 3]]).astype('float32'))
     tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
|
@ -936,3 +936,44 @@ def test_constant_duplicate_mul(tag):
|
||||||
return Mul(Sqrt(x), Mul(tensor1, tensor2))
|
return Mul(Sqrt(x), Mul(tensor1, tensor2))
|
||||||
|
|
||||||
return fns[tag]
|
return fns[tag]
|
||||||
|
|
||||||
|
|
||||||
|
def test_adjust_allreduce_mul_add(tag):
|
||||||
|
fns = FnDict()
|
||||||
|
Mul = Primitive('Mul')
|
||||||
|
AddN = Primitive('AddN')
|
||||||
|
AllReduce = Primitive('AllReduce')
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def beforell(x, y, z):
|
||||||
|
return AddN((z, Mul(y, AllReduce(x))))
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def beforelr(x, y, z):
|
||||||
|
return AddN((z, Mul(AllReduce(x), y)))
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def beforerl(x, y, z):
|
||||||
|
return AddN((Mul(y, AllReduce(x)), z))
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def beforerr(x, y, z):
|
||||||
|
return AddN((Mul(AllReduce(x), y), z))
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def after1(x, y, z):
|
||||||
|
return Mul(AllReduce(AddN((z, x))), y)
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def before2r(x, y, z):
|
||||||
|
return AddN((Mul(AllReduce(x), y), Mul(z, z)))
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def before2l(x, y, z):
|
||||||
|
return AddN((Mul(z, z), Mul(AllReduce(x), y)))
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def after2(x, y, z):
|
||||||
|
return Mul(AllReduce(AddN((Mul(z, z), x))), y)
|
||||||
|
|
||||||
|
return fns[tag]
|
||||||
|
|
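Note on the test harness: @fns registers each graph-builder under its function name so CallAndParseRet can fetch it by tag from the C++ side. A minimal FnDict consistent with this usage (the real helper lives in the test utilities; this sketch is an assumption):

    class FnDict:
        def __init__(self):
            self.fns = {}

        def __call__(self, fn):
            # Used as the @fns decorator: register fn under its own name.
            self.fns[fn.__name__] = fn
            return fn

        def __getitem__(self, name):
            # Used by `return fns[tag]` to pick the requested graph builder.
            return self.fns[name]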