diff --git a/mindspore/ccsrc/operator/ops.cc b/mindspore/ccsrc/operator/ops.cc index ed287b03b78..12e6b70a6f7 100644 --- a/mindspore/ccsrc/operator/ops.cc +++ b/mindspore/ccsrc/operator/ops.cc @@ -223,7 +223,6 @@ const PrimitivePtr kPrimIsNot = std::make_shared("is_not"); const PrimitivePtr kPrimMirror = std::make_shared("_MirrorOperator"); const PrimitivePtr kPrimVirtualDiv = std::make_shared("_VirtualDiv"); const PrimitivePtr kPrimVirtualDataset = std::make_shared("_VirtualDataset"); -const PrimitivePtr kPrimAllReduce = std::make_shared("AllReduce"); // Debug ops const PrimitivePtr kPrimScalarSummary = std::make_shared("ScalarSummary"); diff --git a/mindspore/ccsrc/operator/ops.h b/mindspore/ccsrc/operator/ops.h index 6c88e30e701..5fbf2b70679 100644 --- a/mindspore/ccsrc/operator/ops.h +++ b/mindspore/ccsrc/operator/ops.h @@ -228,7 +228,6 @@ extern const PrimitivePtr kPrimMinimumGrad; extern const PrimitivePtr kPrimMaximumGrad; // Comm ops -extern const PrimitivePtr kPrimAllReduce; extern const PrimitivePtr kPrimMirror; extern const PrimitivePtr kPrimVirtualDiv; extern const PrimitivePtr kPrimVirtualDataset; diff --git a/mindspore/ccsrc/optimizer/irpass.cc b/mindspore/ccsrc/optimizer/irpass.cc index 25784c8a0b0..0991c31b00a 100644 --- a/mindspore/ccsrc/optimizer/irpass.cc +++ b/mindspore/ccsrc/optimizer/irpass.cc @@ -47,7 +47,7 @@ namespace irpass { OptimizeIRPassLib::OptimizeIRPassLib() { arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify", {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, - prim::kPrimAddN, prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul}); + prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul}); special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate", {prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType, prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv}); diff --git a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h 
b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h index 0d48fc1463b..ab191aab205 100644 --- a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h +++ b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h @@ -228,82 +228,6 @@ class ConstantDuplicateMul : public AnfVisitor { CNodePtr cnode_; }; -// grad = AllReduce(grad) / worker_number -// grad = grad + weight * decy -// -> -// grad = grad + weight * decy -// grad = AllReduce(grad) / worker_number - -// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} -> -// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y} -class AdjustAllReduceMulAdd : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - // {prim::kPrimAddN, Zs} - if (!IsPrimitiveCNode(node, prim::kPrimAddN)) { - return nullptr; - } - auto addn = node->cast(); - if (addn->size() != 2) { - return nullptr; - } - - AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1)); - if (x_ == nullptr || y_ == nullptr || z_ == nullptr) { - return nullptr; - } - - auto fg = node->func_graph(); - AnfNodePtr tuple = NewCNode({NewValueNode(prim::kPrimMakeTuple), z_, x_}, fg); - AnfNodePtr add = NewCNode({NewValueNode(prim::kPrimAddN), tuple}, fg); - AnfNodePtr all_reduce = NewCNode({NewValueNode(prim::kPrimAllReduce), add}, fg); - return NewCNode({NewValueNode(prim::kPrimMul), all_reduce, y_}, fg); - } - - void Visit(const AnfNodePtr &node) override { - if (level_ == 0) { - level_ = 1; - is_reduce_match_ = false; - // {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y} - AnfVisitor::Match(prim::kPrimMul)(node); - level_ = 0; - if (is_reduce_match_) { - y_ = tmp_; - } else { - z_ = node; - } - } - - if (level_ == 1) { - // {prim::kPrimAllReduce, X} - if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) { - auto cnode = node->cast(); - if (cnode->size() > 1) { - x_ = cnode->input(1); - is_reduce_match_ = 
true; - } - } else { - tmp_ = node; - } - } - } - - void Reset() { - level_ = 0; - is_reduce_match_ = false; - x_ = nullptr; - y_ = nullptr; - z_ = nullptr; - tmp_ = nullptr; - } - - private: - int level_{0}; - bool is_reduce_match_{false}; - AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr}; -}; - class ArithmeticSimplify { public: ArithmeticSimplify() @@ -319,7 +243,6 @@ class ArithmeticSimplify { eliminaters_.emplace_back(identity_); eliminaters_.emplace_back(opt_update_zero_tensor_); eliminaters_.emplace_back(constant_duplicate_mul_); - eliminaters_.emplace_back(adjust_allreduce_mul_add_); } ~ArithmeticSimplify() = default; @@ -341,7 +264,6 @@ class ArithmeticSimplify { PrimEliminater identity_; OptUpdateZeroTensor opt_update_zero_tensor_; ConstantDuplicateMul constant_duplicate_mul_; - AdjustAllReduceMulAdd adjust_allreduce_mul_add_; std::vector eliminaters_{}; }; } // namespace irpass diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 2264e727737..ac7f8ed699c 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -1235,8 +1235,8 @@ class UnsortedSegmentSum(PrimitiveWithInfer): Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`. Examples: - >>> input_x = Tensor([1, 2, 3, 4], mindspore.float32) - >>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32) + >>> input_x = Tensor([1, 2, 3, 4], mindspore.float32) + >>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32) >>> num_segments = 4 >>> type = P.UnsortedSegmentSum()(input_x, segment_ids, num_segments) """ diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index e780500af7d..3cc67184847 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -1499,7 +1499,7 @@ class LayerNorm(Primitive): `Layer Normalization `_. ..
math:: - y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta + y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon. diff --git a/tests/ut/cpp/optimizer/lib_test.cc b/tests/ut/cpp/optimizer/lib_test.cc index 8e348c698a6..2d4cf0e78ea 100644 --- a/tests/ut/cpp/optimizer/lib_test.cc +++ b/tests/ut/cpp/optimizer/lib_test.cc @@ -556,24 +556,5 @@ TEST_F(TestOptLib, test_constant_duplicate_mul) { ASSERT_TRUE(CheckOpt(beforerl, after, patterns)); ASSERT_TRUE(CheckOpt(beforerr, after, patterns)); } - -TEST_F(TestOptLib, test_adjust_allreduce_mul_add) { - FuncGraphPtr beforell = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforell"); - FuncGraphPtr beforelr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforelr"); - FuncGraphPtr beforerl = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerl"); - FuncGraphPtr beforerr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerr"); - FuncGraphPtr after1 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after1"); - FuncGraphPtr before2r = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2r"); - FuncGraphPtr before2l = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2l"); - FuncGraphPtr after2 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after2"); - auto patterns = std::vector({irpass.arithmetic_simplify_}); - ASSERT_TRUE(CheckOpt(beforell, after1, patterns)); - ASSERT_TRUE(CheckOpt(beforelr, after1, patterns)); - ASSERT_TRUE(CheckOpt(beforerl, after1, patterns)); - ASSERT_TRUE(CheckOpt(beforerr, after1, patterns)); - ASSERT_TRUE(CheckOpt(before2l, after2, patterns)); - ASSERT_TRUE(CheckOpt(before2r, after2, patterns)); -} - } // namespace opt } // namespace mindspore diff --git a/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py b/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py index
d74aa159521..d494ad27d38 100644 --- a/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py +++ b/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py @@ -908,8 +908,8 @@ def test_print_tuple_wrapper(tag): def test_constant_duplicate_mul(tag): fns = FnDict() - Mul = Primitive('Mul') - Sqrt = Primitive('Sqrt') + Mul = Primitive('Mul') + Sqrt = Primitive('Sqrt') x = Tensor(np.array([[2, 2], [2, 3]]).astype('float32')) tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')) @@ -936,44 +936,3 @@ def test_constant_duplicate_mul(tag): return Mul(Sqrt(x), Mul(tensor1, tensor2)) return fns[tag] - - -def test_adjust_allreduce_mul_add(tag): - fns = FnDict() - Mul = Primitive('Mul') - AddN = Primitive('AddN') - AllReduce = Primitive('AllReduce') - - @fns - def beforell(x, y, z): - return AddN((z, Mul(y, AllReduce(x)))) - - @fns - def beforelr(x, y, z): - return AddN((z, Mul(AllReduce(x), y))) - - @fns - def beforerl(x, y, z): - return AddN((Mul(y, AllReduce(x)), z)) - - @fns - def beforerr(x, y, z): - return AddN((Mul(AllReduce(x), y), z)) - - @fns - def after1(x, y, z): - return Mul(AllReduce(AddN((z, x))), y) - - @fns - def before2r(x, y, z): - return AddN((Mul(AllReduce(x), y), Mul(z, z))) - - @fns - def before2l(x, y, z): - return AddN((Mul(z, z), Mul(AllReduce(x), y))) - - @fns - def after2(x, y, z): - return Mul(AllReduce(AddN((Mul(z, z), x))), y) - - return fns[tag]