From 3db8cfa54f10535e80a017d5e7c99696a9d3df3e Mon Sep 17 00:00:00 2001
From: Wei Luning
Date: Sun, 12 Apr 2020 23:18:04 +0800
Subject: [PATCH] add pattern AdjustAllReduceMulAdd; use the old op; add test
 case for bug; temp fix try

---
 mindspore/ccsrc/operator/ops.cc                  |   1 +
 mindspore/ccsrc/operator/ops.h                   |   1 +
 mindspore/ccsrc/optimizer/irpass.cc              |   1 +
 mindspore/ccsrc/optimizer/irpass.h               |   1 +
 .../optimizer/irpass/arithmetic_simplify.h       | 110 ++++++++++++++++++
 .../ccsrc/pipeline/parse/function_block.h        |   3 +-
 mindspore/ccsrc/pipeline/pass.cc                 |   1 +
 mindspore/ops/operations/array_ops.py            |   2 +-
 mindspore/ops/operations/nn_ops.py               |   2 +-
 mindspore/ops/primitive.py                       |   3 +-
 tests/ut/cpp/optimizer/lib_test.cc               |  19 +++
 .../gtest_input/optimizer/opt_test.py            |  45 ++++++-
 tests/ut/python/train/test_amp.py                |  28 ++++-
 13 files changed, 209 insertions(+), 8 deletions(-)

diff --git a/mindspore/ccsrc/operator/ops.cc b/mindspore/ccsrc/operator/ops.cc
index f024032cda9..2b0eeb26f21 100755
--- a/mindspore/ccsrc/operator/ops.cc
+++ b/mindspore/ccsrc/operator/ops.cc
@@ -241,6 +241,7 @@ const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict");
 const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
 const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
 const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
+const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
 
 // Debug ops
 const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary");
diff --git a/mindspore/ccsrc/operator/ops.h b/mindspore/ccsrc/operator/ops.h
index 52c0c1c1333..e6e065f0764 100755
--- a/mindspore/ccsrc/operator/ops.h
+++ b/mindspore/ccsrc/operator/ops.h
@@ -245,6 +245,7 @@ extern const PrimitivePtr kPrimInDict;
 extern const PrimitivePtr kPrimNotInDict;
 
 // Comm ops
+extern const PrimitivePtr kPrimAllReduce;
 extern const PrimitivePtr kPrimMirror;
 extern const PrimitivePtr kPrimVirtualDiv;
 extern const PrimitivePtr kPrimVirtualDataset;
diff --git a/mindspore/ccsrc/optimizer/irpass.cc b/mindspore/ccsrc/optimizer/irpass.cc
index 80708740668..d2f7d603590 100644
--- a/mindspore/ccsrc/optimizer/irpass.cc
+++ b/mindspore/ccsrc/optimizer/irpass.cc
@@ -53,6 +53,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
     {prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType, prim::kPrimGetRefKey, prim::kPrimMirror,
      prim::kPrimVirtualDiv});
   zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor);
+  adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
 
   // ops eliminate
   item_tuple_eliminate_ =
diff --git a/mindspore/ccsrc/optimizer/irpass.h b/mindspore/ccsrc/optimizer/irpass.h
index 02bfee65d6b..e834d69b699 100644
--- a/mindspore/ccsrc/optimizer/irpass.h
+++ b/mindspore/ccsrc/optimizer/irpass.h
@@ -35,6 +35,7 @@ class OptimizeIRPassLib {
   SubstitutionPtr arithmetic_simplify_;
   SubstitutionPtr special_op_eliminate_;
   SubstitutionPtr zero_like_fill_zero_;
+  SubstitutionPtr adjust_all_reduce_mul_add_;
 
   // ops eliminate
   SubstitutionPtr item_tuple_eliminate_;
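The substitution registered above triggers on AddN nodes and rewrites AllReduce(x) * y + z into AllReduce(z + x) * y; the next hunk spells the pattern out. The rewrite only preserves the value because AllReduce is a linear (sum) reduction, y is 1/worker_number, and the weight-decay addend z is identical on every worker. A quick numpy check of that identity (illustrative only, not MindSpore code):

import numpy as np

n = 8                                         # worker_number
xs = [np.random.rand(4) for _ in range(n)]    # per-worker gradients X
z = np.full(4, 3e-4)                          # weight * decay, replicated on all workers
y = 1.0 / n                                   # the scale the pattern calls Y

def all_reduce(shards):
    return sum(shards)                        # an AllReduce(sum) across workers

before = all_reduce(xs) * y + z               # reduce, scale, then add decay
after = all_reduce([z + x for x in xs]) * y   # add decay first, reduce once
assert np.allclose(before, after)             # z is summed n times; y scales it back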
diff --git a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h
index ab191aab205..a12ba9128ca 100644
--- a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h
+++ b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h
@@ -228,6 +228,116 @@ class ConstantDuplicateMul : public AnfVisitor {
   CNodePtr cnode_;
 };
 
+// grad = AllReduce(grad) / worker_number
+// grad = grad + weight * decay
+// ->
+// grad = grad + weight * decay
+// grad = AllReduce(grad) / worker_number
+
+// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} ->
+// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN, {prim::kPrimMakeTuple, Z, X}}}, Y}
+class AdjustAllReduceMulAdd : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    Reset();
+    // {prim::kPrimAddN, Zs}
+    if (!IsPrimitiveCNode(node, prim::kPrimAddN)) {
+      return nullptr;
+    }
+    auto addn = node->cast<CNodePtr>();
+    if (addn->size() != 2) {
+      return nullptr;
+    }
+    AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1));
+    if (x_ == nullptr || y_ == nullptr || z_ == nullptr || all_reduce_fg_ == nullptr) {
+      return nullptr;
+    }
+    auto addn_maketuple = addn->input(1);
+
+    auto fg = all_reduce_fg_;
+    // The AddN inputs may come from another graph; rebuild z in the same graph as the AllReduce node.
+    if (z_->isa<CNode>() && fg != z_->func_graph()) {
+      auto cnode_z = z_->cast<CNodePtr>();
+      z_ = NewCNode(cnode_z->inputs(), fg);
+    }
+
+    auto addn_op_node = addn->input(0);
+    auto make_tuple_op_node = addn->input(1)->cast<CNodePtr>()->input(0);
+
+    AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg);
+    AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg);
+    AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg);
+    AnfNodePtr mul = NewCNode({mul_, all_reduce, y_}, fg);
+    ProcessDependEdge(fg, addn_maketuple, all_reduce);
+    return mul;
+  }
+
+  void ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, const AnfNodePtr &new_node) {
+    // With dynamic loss scale, the matched Mul has users besides the AddN's MakeTuple; re-point them to the new node.
+    auto &users_map = fg->manager()->node_users();
+    auto it = users_map.find(mul_cnode_);
+    if (it != users_map.end()) {
+      auto users = it->second;
+      for (auto &user_pair : users) {
+        auto node = user_pair.first;
+        if (node != addn_maketuple) {
+          if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
+            fg->manager()->SetEdge(node, user_pair.second, new_node);
+          }
+        }
+      }
+    }
+  }
+
+  void Visit(const AnfNodePtr &node) override {
+    if (level_ == 0) {
+      level_ = 1;
+      is_reduce_match_ = false;
+      // {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}
+      AnfVisitor::Match(prim::kPrimMul)(node);
+      level_ = 0;
+      if (is_reduce_match_) {
+        mul_ = node->cast<CNodePtr>()->input(0);
+        mul_cnode_ = node->cast<CNodePtr>();
+        y_ = tmp_;
+      } else {
+        z_ = node;
+      }
+    }
+
+    if (level_ == 1) {
+      // {prim::kPrimAllReduce, X}
+      if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) {
+        auto cnode = node->cast<CNodePtr>();
+        if (cnode->size() > 1) {
+          all_reduce_ = cnode->input(0);
+          x_ = cnode->input(1);
+          is_reduce_match_ = true;
+          all_reduce_fg_ = cnode->func_graph();
+        }
+      } else {
+        tmp_ = node;
+      }
+    }
+  }
+
+  void Reset() {
+    level_ = 0;
+    is_reduce_match_ = false;
+    x_ = nullptr;
+    y_ = nullptr;
+    z_ = nullptr;
+    tmp_ = nullptr;
+    all_reduce_fg_ = nullptr;
+  }
+
+ private:
+  int level_{0};
+  bool is_reduce_match_{false};
+  AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr};
+  AnfNodePtr all_reduce_{nullptr}, mul_{nullptr}, mul_cnode_{nullptr};
+  FuncGraphPtr all_reduce_fg_{nullptr};
+};
+
 class ArithmeticSimplify {
  public:
  ArithmeticSimplify()
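For readers not fluent in AnfVisitor, the two-level Visit above amounts to the following over the two MakeTuple elements: level 0 tries each element as the Mul, level 1 looks for the AllReduce on either Mul operand (tmp_ catches the other operand), and anything that fails the match becomes z_. A deliberately simplified Python rendition on nested tuples (a sketch, not the real AnfNode API):

# Nodes are modeled as tuples like ("Mul", a, b); leaves are strings.
def match_addn_elements(elements):
    x = y = z = None
    for node in elements:                      # the two MakeTuple elements
        if isinstance(node, tuple) and node[0] == "Mul":
            a, b = node[1], node[2]
            if isinstance(a, tuple) and a[0] == "AllReduce":
                x, y = a[1], b                 # Mul(AllReduce(X), Y)
                continue
            if isinstance(b, tuple) and b[0] == "AllReduce":
                x, y = b[1], a                 # Mul(Y, AllReduce(X))
                continue
        z = node                               # the non-AllReduce addend
    return x, y, z

# A Mul without an AllReduce falls through to z, so Mul(z, z) still matches:
print(match_addn_elements([("Mul", ("AllReduce", "g"), "1/n"), ("Mul", "w", "w")]))

Both operand orders are accepted, which is exactly what the beforell/beforelr/beforerl/beforerr test graphs further down exercise.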
diff --git a/mindspore/ccsrc/pipeline/parse/function_block.h b/mindspore/ccsrc/pipeline/parse/function_block.h
index e7842903ee2..5341d33b21e 100644
--- a/mindspore/ccsrc/pipeline/parse/function_block.h
+++ b/mindspore/ccsrc/pipeline/parse/function_block.h
@@ -28,6 +28,7 @@
 #include 
 #include "pipeline/parse/parse_base.h"
 #include "utils/log_adapter.h"
+#include "utils/ordered_map.h"
 
 namespace mindspore {
 namespace parse {
@@ -99,7 +100,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
   std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis_;
 
   // State nodes that need to be inserted before the function's return node.
-  std::unordered_map<AnfNodePtr, std::string> state_assign_;
+  OrderedMap<AnfNodePtr, std::string> state_assign_;
 
   // hold declared global variables in function
   std::set<std::string> global_vars_;
diff --git a/mindspore/ccsrc/pipeline/pass.cc b/mindspore/ccsrc/pipeline/pass.cc
index 4614f194423..f4a3a49b25f 100644
--- a/mindspore/ccsrc/pipeline/pass.cc
+++ b/mindspore/ccsrc/pipeline/pass.cc
@@ -82,6 +82,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
     // Arithmetic simplifications
     irpass.arithmetic_simplify_,
     irpass.addn_zero_filter_,
+    irpass.adjust_all_reduce_mul_add_,
 
     // Miscellaneous
     irpass.item_tuple_eliminate_,
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index e8cdbe5e90f..c88735aa233 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -1213,7 +1213,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
         Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.
 
     Examples:
-        >>> input_x = Tensor([1, 2, 3, 4], mindspore.float)
+        >>> input_x = Tensor([1, 2, 3, 4], mindspore.float32)
         >>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32)
         >>> num_segments = 4
         >>> P.UnsortedSegmentSum()(input_x, segment_ids, num_segments)
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 2f5ca0b55c3..d297c3ad6fd 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -1765,7 +1765,7 @@ class LayerNorm(Primitive):
     `Layer Normalization <https://arxiv.org/abs/1607.06450>`_.
 
     .. math::
-        y = \frac{x - mean]}{\sqrt{variance + \epsilon}} * \gamma + \beta
+        y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
 
     where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.
diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py
index 95e148204b6..908c0245bb8 100644
--- a/mindspore/ops/primitive.py
+++ b/mindspore/ops/primitive.py
@@ -284,7 +284,8 @@ def prim_attr_register(fn):
 
 def constexpr(fn=None, get_instance=True, name=None):
     """
-    Makes a PrimitiveWithInfer operator, which infer the value while compiling.
+    Makes a PrimitiveWithInfer operator that can infer the value while compiling. We can define a function
+    that computes with constant values and use it in construct().
 
     Args:
         fn (function): A `fn` use as the infer_value of the output operator.
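The docstring above is the whole contract: the decorated fn becomes the infer_value of a generated PrimitiveWithInfer, so it runs during compilation on constant arguments and its result is folded into the graph as a constant. A minimal usage sketch (constexpr itself is defined in this file; the helper name is invented for illustration):

from mindspore.ops.primitive import constexpr

@constexpr
def half_last_dim(shape):
    # `shape` must be a compile-time constant (e.g. a Python tuple);
    # the arithmetic happens while compiling, not at run time.
    return shape[-1] // 2

Inside a Cell's construct, a call like half_last_dim((4, 16)) then folds to the constant 8.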
diff --git a/tests/ut/cpp/optimizer/lib_test.cc b/tests/ut/cpp/optimizer/lib_test.cc
index 2d4cf0e78ea..1ed1fed43d1 100644
--- a/tests/ut/cpp/optimizer/lib_test.cc
+++ b/tests/ut/cpp/optimizer/lib_test.cc
@@ -556,5 +556,24 @@ TEST_F(TestOptLib, test_constant_duplicate_mul) {
   ASSERT_TRUE(CheckOpt(beforerl, after, patterns));
   ASSERT_TRUE(CheckOpt(beforerr, after, patterns));
 }
+
+TEST_F(TestOptLib, test_adjust_allreduce_mul_add) {
+  FuncGraphPtr beforell = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforell");
+  FuncGraphPtr beforelr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforelr");
+  FuncGraphPtr beforerl = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerl");
+  FuncGraphPtr beforerr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerr");
+  FuncGraphPtr after1 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after1");
+  FuncGraphPtr before2r = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2r");
+  FuncGraphPtr before2l = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2l");
+  FuncGraphPtr after2 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after2");
+  auto patterns = std::vector<SubstitutionPtr>({irpass.adjust_all_reduce_mul_add_});
+  ASSERT_TRUE(CheckOpt(beforell, after1, patterns));
+  ASSERT_TRUE(CheckOpt(beforelr, after1, patterns));
+  ASSERT_TRUE(CheckOpt(beforerl, after1, patterns));
+  ASSERT_TRUE(CheckOpt(beforerr, after1, patterns));
+  ASSERT_TRUE(CheckOpt(before2l, after2, patterns));
+  ASSERT_TRUE(CheckOpt(before2r, after2, patterns));
+}
+
 }  // namespace opt
 }  // namespace mindspore
diff --git a/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py b/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py
index 6b302841430..53fcb5dabd0 100644
--- a/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py
+++ b/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py
@@ -1045,8 +1045,8 @@ def test_print_tuple_wrapper(tag):
 
 def test_constant_duplicate_mul(tag):
     fns = FnDict()
-    Mul = Primitive('Mul');
-    Sqrt = Primitive('Sqrt');
+    Mul = Primitive('Mul')
+    Sqrt = Primitive('Sqrt')
 
     x = Tensor(np.array([[2, 2], [2, 3]]).astype('float32'))
     tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
@@ -1073,3 +1073,44 @@ def test_constant_duplicate_mul(tag):
         return Mul(Sqrt(x), Mul(tensor1, tensor2))
 
     return fns[tag]
+
+
+def test_adjust_allreduce_mul_add(tag):
+    fns = FnDict()
+    Mul = Primitive('Mul')
+    AddN = Primitive('AddN')
+    AllReduce = Primitive('AllReduce')
+
+    @fns
+    def beforell(x, y, z):
+        return AddN((z, Mul(y, AllReduce(x))))
+
+    @fns
+    def beforelr(x, y, z):
+        return AddN((z, Mul(AllReduce(x), y)))
+
+    @fns
+    def beforerl(x, y, z):
+        return AddN((Mul(y, AllReduce(x)), z))
+
+    @fns
+    def beforerr(x, y, z):
+        return AddN((Mul(AllReduce(x), y), z))
+
+    @fns
+    def after1(x, y, z):
+        return Mul(AllReduce(AddN((z, x))), y)
+
+    @fns
+    def before2r(x, y, z):
+        return AddN((Mul(AllReduce(x), y), Mul(z, z)))
+
+    @fns
+    def before2l(x, y, z):
+        return AddN((Mul(z, z), Mul(AllReduce(x), y)))
+
+    @fns
+    def after2(x, y, z):
+        return Mul(AllReduce(AddN((Mul(z, z), x))), y)
+
+    return fns[tag]
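FnDict, used throughout this file, is just a name-keyed registry: the @fns decorator stores each graph-builder under its function name, and fns[tag] hands the right one back to the C++ test via CallAndParseRet. A minimal stand-in (a sketch of the behavior, not the exact source):

class FnDict:
    def __init__(self):
        self.fns = {}

    def __call__(self, fn):          # used as the @fns decorator
        self.fns[fn.__name__] = fn
        return fn

    def __getitem__(self, name):     # used as fns[tag]
        return self.fns[name]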
diff --git a/tests/ut/python/train/test_amp.py b/tests/ut/python/train/test_amp.py
index fe08809be1e..54f2081b6c0 100644
--- a/tests/ut/python/train/test_amp.py
+++ b/tests/ut/python/train/test_amp.py
@@ -20,9 +20,12 @@ import mindspore.context as context
 from mindspore import Tensor
 from mindspore import amp
 from mindspore import nn
-from mindspore.train import Model
+from mindspore.train import Model, ParallelMode
+from mindspore.common import dtype as mstype
+from mindspore.model_zoo.resnet import resnet50
 from ....dataset_mock import MindData
-
+from mindspore.parallel._auto_parallel_context import auto_parallel_context
+from mindspore.communication.management import init
 
 def setup_module(module):
     context.set_context(mode=context.GRAPH_MODE)
@@ -138,3 +141,22 @@ def test_compile_model_train_O2():
     with pytest.raises(ValueError):
         # not actual run, the metrics step will fail, check if compile ok.
         model.eval(dataset)
+
+
+def test_compile_model_train_O2_parallel():
+    dataset_types = (np.float32, np.float32)
+    dataset_shapes = ((16, 16), (16, 16))
+
+    dataset = MindDataSet(dataset_types, dataset_shapes)
+
+    net = NetNoLoss(16, 16)
+    loss = nn.MSELoss()
+    optimizer = nn.Momentum(net.trainable_params(), 0.1, 0.9, 0.00004, 1024.0)
+
+    context.set_auto_parallel_context(
+        global_rank=0, device_num=8,
+        mirror_mean=True, parameter_broadcast=True,
+        parallel_mode=ParallelMode.DATA_PARALLEL)
+    init()
+
+    model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"acc"}, amp_level="O2")
+    model.train(2, dataset, dataset_sink_mode=False)
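Schematically, the new test wires together exactly the graph shape the optimizer pass targets: with mirror_mean=True the data-parallel gradient is allreduce(g) / device_num, and Momentum's weight_decay (0.00004 above) adds w * decay on top, which is the Mul(AllReduce(X), Y) + Z pattern that adjust_all_reduce_mul_add_ reorders so the AllReduce runs once on the already-decayed gradient. A sketch of the per-parameter value before and after the rewrite (semantics only, not MindSpore internals):

import numpy as np

def grad_before(g_shards, w, device_num=8, decay=0.00004):
    return sum(g_shards) / device_num + w * decay             # AllReduce, scale, add decay

def grad_after(g_shards, w, device_num=8, decay=0.00004):
    return sum(g + w * decay for g in g_shards) / device_num  # decay first, reduce once

shards = [np.ones(2) * i for i in range(8)]
assert np.allclose(grad_before(shards, np.ones(2)), grad_after(shards, np.ones(2)))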