diff --git a/mindspore/ccsrc/frontend/optimizer/recompute.cc b/mindspore/ccsrc/frontend/optimizer/recompute.cc
index 13b408c5ab6..ce5896f179a 100644
--- a/mindspore/ccsrc/frontend/optimizer/recompute.cc
+++ b/mindspore/ccsrc/frontend/optimizer/recompute.cc
@@ -33,8 +33,8 @@ namespace {
 constexpr auto kGradientsFlag = "Gradients";
 
 bool CanNotRecomputed(const CNodePtr &node) {
-  static std::unordered_set<PrimitivePtr> not_recomputed_op_list{prim::kPrimAllGather, prim::kPrimDropoutGenMask,
-                                                                 prim::kPrimLoad, prim::kPrimTupleGetItem};
+  static std::unordered_set<PrimitivePtr> not_recomputed_op_list{prim::kPrimDropoutGenMask, prim::kPrimLoad,
+                                                                 prim::kPrimTupleGetItem};
 
   return std::any_of(not_recomputed_op_list.begin(), not_recomputed_op_list.end(),
                      [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
index c0c89beb245..78b5ddd034e 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
@@ -134,6 +134,7 @@ constexpr char FUSION[] = "fusion";
 constexpr char DO_MIRROR[] = "do_mirror";
 constexpr char RECOMPUTE[] = "recompute";
 constexpr char RECOMPUTE_COMM_OP[] = "recompute_comm_op";
+constexpr char NOT_RECOMPUTE[] = "not_recompute";
 constexpr char NUM_SAMPLED[] = "num_sampled";
 constexpr char NUM_TRUE[] = "num_true";
 constexpr char SEED[] = "seed";
@@ -193,7 +194,7 @@ constexpr char FORWARD_REDUCE_SCATTER[] = "forward_reduce_scatter";
 constexpr char FIELD_SIZE[] = "field_size";
 constexpr char OPTIMIZER_SUB_STRING[] = "optimizer";
 constexpr char DEVICE[] = "Device";
-constexpr char PARALLEL_OPTIMIZER_ALLGATHER[] = "parallel_optimizer_allgather";
+constexpr char PARALLEL_OPTIMIZER_ALLGATHER[] = "parallel_optimizer_allgather_not_recompute";
 constexpr char CELLLIST_KEYWORD_PATTERN[] = "-CellList/(\\d+)-";
 constexpr char OUT_CHANNEL[] = "out_channel";
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index 043f8dd9833..aac8dd49f76 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -261,6 +261,9 @@ void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const An
   PrimitivePtr new_node_prim = new_node_value->value()->cast<PrimitivePtr>();
   new_node_prim->set_instance_name(instance_name);
   new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
+  if (instance_name.find(NOT_RECOMPUTE) != std::string::npos) {
+    new_node_prim->set_attr("recompute", MakeValue(false));
+  }
   new_node->set_scope(scope);
   node_input[0]->set_scope(scope);
   manager->SetEdge(node, SizeToLong(index), new_node);
@@ -290,6 +293,9 @@ static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, cons
   auto new_node_prim = GetValueNode<PrimitivePtr>(node_input[0]);
   new_node_prim->set_instance_name(instance_name);
   new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
+  if (instance_name.find(NOT_RECOMPUTE) != std::string::npos) {
+    new_node_prim->set_attr("recompute", MakeValue(false));
+  }
   new_node->set_scope(scope);
   node_input[0]->set_scope(scope);
   manager->Replace(pre_node, new_node);
@@ -394,6 +400,18 @@ void InsertRedistribution(const RedistributionOpListPtr &redistribution_oplist_p
     std::string op_name = (redistribution_oplist_ptr->first)[index].first;
     std::string instance_name_base = REDISTRIBUTION_OP;
     std::string instance_name = instance_name_base + "_" + CreateInstanceName(pre_node, index) + op_name;
+    auto prim_out = GetCNodePrimitive(node);
+    auto prim_in = GetCNodePrimitive(pre_node);
+    if (prim_out != nullptr && prim_in != nullptr) {
+      auto prim_out_attr = prim_out->attrs();
+      auto prim_in_attr = prim_in->attrs();
+      if (prim_out_attr.find(RECOMPUTE_COMM_OP) != prim_out_attr.end() &&
+          prim_in_attr.find(RECOMPUTE_COMM_OP) != prim_in_attr.end() &&
+          COMMUNICATION_OPS.find(op_name) != COMMUNICATION_OPS.end()) {
+        MS_LOG(INFO) << "The redistribution node would not be recomputed.";
+        instance_name = instance_name + "_" + NOT_RECOMPUTE;
+      }
+    }
     InsertNode(op, node, LongToSize(pos), target_node, func_graph, instance_name);
     if ((redistribution_oplist_ptr->second)[index].first) {
       target_node = node->input(LongToSize(pos));
@@ -443,12 +461,7 @@ TensorLayout GetTensorInLayout(const CNodePtr &middle_node, const PrimitivePtr &
 }
 
 std::string GetPrimName(const CNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(node);
-  if (!IsValueNode<Primitive>(node->input(0))) {
-    MS_LOG(EXCEPTION) << "The node is not a primitive";
-  }
-  auto value_node = node->input(0)->cast<ValueNodePtr>();
-  auto prim = GetValueNode<PrimitivePtr>(value_node);
+  auto prim = GetCNodePrimitive(node);
   MS_EXCEPTION_IF_NULL(prim);
   return prim->name();
 }
@@ -881,6 +894,11 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
     PrimitivePtr prim = GetValueNode<PrimitivePtr>(replace_node->input(0));
     PrimitivePtr origin_prim = GetValueNode<PrimitivePtr>(node->input(0));
     SetUserAttrs(origin_prim->attrs(), prim);
+    if (origin_prim->attrs().find(RECOMPUTE_COMM_OP) != origin_prim->attrs().end() &&
+        COMMUNICATION_OPS.find(prim->name()) != COMMUNICATION_OPS.end()) {
+      MS_LOG(INFO) << "The redistribution node in reshape would not be recomputed.";
+      prim->set_attr("recompute", MakeValue(false));
+    }
     if (index == replace_op.size() - 1) {
       replace_node->set_user_data(node->user_data());
       replace_node->set_primal_attrs(node->primal_attrs());
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.h b/mindspore/ccsrc/frontend/parallel/step_parallel.h
index 996cc11ba33..b1d3fce2ba8 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.h
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.h
@@ -187,6 +187,8 @@ std::string MirrorOpName();
 
 CommInfo GetCommInfo();
 
+std::string GetPrimName(const CNodePtr &node);
+
 void ReorderForPipelineSplit(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager, int64_t pipeline_stages);
 }  // namespace parallel
 }  // namespace mindspore
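The step_parallel.cc changes above hinge on two primitive attributes: both the producer and the consumer primitive of a redistribution edge must carry "recompute_comm_op", and the communication operator inserted between them then receives a "not_recompute" suffix in its instance name, which InsertNode and ReplaceNode translate into the attribute recompute=false. A minimal Python-side sketch of tagging a forward primitive this way, using only the Primitive.add_prim_attr call (this is what the new Cell._mp_comm_recompute helper in cell.py below does for every primitive a cell owns); the variable name is illustrative:

from mindspore.ops import operations as P

# Tag a forward operator so the parallel pass can recognize it when it inserts
# redistribution communication ops around this primitive.
matmul = P.MatMul()
matmul.add_prim_attr("recompute_comm_op", False)
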
diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py
index 8ab61f3a042..c522e0ce6af 100755
--- a/mindspore/nn/cell.py
+++ b/mindspore/nn/cell.py
@@ -21,6 +21,7 @@ from collections import OrderedDict
 
 import numpy
 
+from mindspore._checkparam import args_type_check
 from mindspore import log as logger
 from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
 from mindspore.common._decorator import deprecated
@@ -85,6 +86,7 @@ class Cell(Cell_):
         self._cells = OrderedDict()
         self._params_list = OrderedDict()
         self._tensor_list = OrderedDict()
+        self._primitives = OrderedDict()
         self.training = False
         self.requires_grad = False
         self.pynative = False
@@ -510,6 +512,7 @@ class Cell(Cell_):
         else:
             if isinstance(value, Primitive):
                 value.set_prim_instance_name(name)
+                self._primitives[name] = value
             object.__setattr__(self, name, value)
         if name not in Cell.IGNORE_LIST:
             self._attr_synced = False
@@ -1287,7 +1290,26 @@ class Cell(Cell_):
         elif not self._scope is None and self._scope.startswith(prefix):
             self._scope = self._scope[len(prefix):]
 
-    def recompute(self, mode=True, output_recompute=False):
+    def _mp_comm_recompute(self, mp_comm_recompute=True):
+        for _, value in self._primitives.items():
+            if value:
+                value.add_prim_attr("recompute_comm_op", mp_comm_recompute)
+        for cell in self.cells():
+            cell._mp_comm_recompute(mp_comm_recompute)
+
+    def _recompute(self, mode=True, output_recompute=False):
+        if context.get_context("mode") == context.PYNATIVE_MODE:
+            raise TypeError("Recompute is not supported in pynative mode currently.")
+        Validator.check_bool(mode)
+        Validator.check_bool(output_recompute)
+        self._set_recompute_scope(mode)
+        if mode and not output_recompute:
+            self.add_flags(output_no_recompute=True)
+        for cell in self.cells():
+            cell._recompute(mode, True)
+
+    @args_type_check(mode=bool, output_recompute=bool, mp_comm_recompute=bool)
+    def recompute(self, **kwargs):
         """
         Set the cell recomputed. All the primitive in the cell will be set recomputed. If a primitive
         set recomputed feeds into some backward nodes for computing gradient, rather than storing the
@@ -1304,16 +1326,25 @@ class Cell(Cell_):
            mode (bool): Specifies whether the cell is recomputed. Default: True.
            output_recompute (bool): Specifies whether the output of this cell is recomputed when
                the mode is true. Note that when the mode is false, this arg is not working. Default: False.
+           mp_comm_recompute (bool): Specifies whether the model parallel communication operators in the
+               cell are recomputed in auto parallel or semi auto parallel mode. Default: True.
         """
-        if context.get_context("mode") == context.PYNATIVE_MODE:
-            raise TypeError("Recompute is not supported in pynative mode currently.")
-        Validator.check_bool(mode)
-        Validator.check_bool(output_recompute)
-        self._set_recompute_scope(mode)
-        if mode and not output_recompute:
-            self.add_flags(output_no_recompute=True)
-        for cell in self.cells():
-            cell.recompute(mode, True)
+        if not kwargs:
+            self._recompute()
+        if 'mode' in kwargs.keys() or 'output_recompute' in kwargs.keys():
+            mode = True
+            output_recompute = False
+            if 'mode' in kwargs.keys():
+                mode = kwargs['mode']
+            if 'output_recompute' in kwargs.keys():
+                output_recompute = kwargs['output_recompute']
+            self._recompute(mode, output_recompute)
+        if 'mp_comm_recompute' in kwargs.keys():
+            self._mp_comm_recompute(kwargs['mp_comm_recompute'])
+        for key, _ in kwargs.items():
+            if key not in ('mode', 'output_recompute', 'mp_comm_recompute'):
+                raise ValueError("Recompute keyword %s is not recognized!" % key)
+
 
     def infer_param_pipeline_stage(self):
         """
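The keyword form keeps the old no-argument behaviour of recompute() and adds an opt-out for model parallel communication operators. A usage sketch, assuming graph mode and a semi auto parallel context; the layer and its shapes are illustrative, not part of the patch:

import mindspore.nn as nn
from mindspore import context

context.set_context(mode=context.GRAPH_MODE)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8)

block = nn.Dense(128, 768, activation='relu')
block.recompute()                                    # recompute the whole cell, as before
block.recompute(mode=True, output_recompute=False)   # the same thing, spelled with keywords
block.recompute(mp_comm_recompute=False)             # keep its model parallel comm ops out of recomputation
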
diff --git a/tests/ut/python/optimizer/test_recompute.py b/tests/ut/python/optimizer/test_recompute.py
index 28bbb38de8d..0e35c7f22a7 100644
--- a/tests/ut/python/optimizer/test_recompute.py
+++ b/tests/ut/python/optimizer/test_recompute.py
@@ -38,7 +38,7 @@ def test_set_recompute_true():
 
 def test_set_recompute_false():
     net = Net()
-    net.pool.recompute(False)
+    net.pool.recompute(mode=False)
     assert net.pool.get_scope() is None
 
 
@@ -51,32 +51,32 @@ def test_set_recompute_true_twice():
 
 def test_set_recompute_false_twice():
     net = Net()
-    net.pool.recompute(False)
-    net.pool.recompute(False)
+    net.pool.recompute(mode=False)
+    net.pool.recompute(mode=False)
     assert net.pool.get_scope() is None
 
 
 def test_reset_recompute1():
     net = Net()
-    net.pool.recompute(True)
-    net.pool.recompute(False)
+    net.pool.recompute(mode=True)
+    net.pool.recompute(mode=False)
     assert net.pool.get_scope() == ""
 
 
 def test_reset_recompute2():
     net = Net()
-    net.pool.recompute(False)
-    net.pool.recompute(True)
+    net.pool.recompute(mode=False)
+    net.pool.recompute(mode=True)
     assert net.pool.get_scope() == recompute_prefix
 
 
 def test_set_scope_and_set_recompute_repeatedly():
     net = Net()
-    net.pool.recompute(True)
+    net.pool.recompute(mode=True)
     assert net.pool.get_scope() == recompute_prefix
-    net.pool.recompute(False)
+    net.pool.recompute(mode=False)
     assert net.pool.get_scope() == ""
-    net.pool.recompute(True)
+    net.pool.recompute(mode=True)
     assert net.pool.get_scope() == recompute_prefix
-    net.pool.recompute(False)
+    net.pool.recompute(mode=False)
     assert net.pool.get_scope() == ""
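The updated tests above only exercise the mode keyword. Two further checks follow directly from the new cell.py logic: an unrecognized keyword raises ValueError, and PyNative mode is still rejected with TypeError. A hedged sketch of such tests, reusing the Net fixture defined in test_recompute.py; these functions and the keyword name are illustrative, not part of the patch:

import pytest
from mindspore import context


def test_recompute_unknown_keyword():
    net = Net()
    with pytest.raises(ValueError):
        net.pool.recompute(not_a_real_option=True)   # hypothetical keyword, rejected by recompute()


def test_recompute_rejected_in_pynative():
    context.set_context(mode=context.PYNATIVE_MODE)
    net = Net()
    with pytest.raises(TypeError):
        net.pool.recompute()
    context.set_context(mode=context.GRAPH_MODE)
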
diff --git a/tests/ut/python/parallel/test_comm_not_recompute.py b/tests/ut/python/parallel/test_comm_not_recompute.py
new file mode 100644
index 00000000000..9472a85f218
--- /dev/null
+++ b/tests/ut/python/parallel/test_comm_not_recompute.py
@@ -0,0 +1,87 @@
+# Copyright 2019 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+import mindspore.nn as nn
+import mindspore as ms
+from mindspore import Tensor, context, Parameter
+from mindspore.common.api import _executor
+from mindspore.ops import operations as P
+from mindspore.common.initializer import initializer
+from mindspore.context import _Context
+from ....train_step_wrap import train_step_with_loss_warp
+
+class MatMulCell(nn.Cell):
+    def __init__(self):
+        super(MatMulCell, self).__init__()
+        self.reshape = P.Reshape()
+        self.matmul0 = P.MatMul()
+        self.weight = Parameter(initializer("ones", [128, 64], ms.float32), name="weight")
+        self.relu = P.ReLU().shard(((1, 8),))
+    def construct(self, x):
+        x = self.matmul0(x, self.weight)
+        x = self.reshape(x, (32, 128))
+        x = self.relu(x)
+        return x
+
+class DenseMutMulNet(nn.Cell):
+    def __init__(self):
+        super(DenseMutMulNet, self).__init__()
+        self.fc1 = nn.Dense(128, 768, activation='relu')
+        self.fc2 = nn.Dense(128, 768, activation='relu')
+        self.fc3 = nn.Dense(128, 768, activation='relu')
+        self.fc4 = nn.Dense(768, 768, activation='relu')
+        self.fc1.recompute()
+        self.fc2.recompute()
+        self.fc3.recompute()
+        self.fc1.matmul.shard(((1, 1), (1, 8)))
+        self.fc2.matmul.shard(((1, 1), (1, 8)))
+        self.fc3.matmul.shard(((1, 1), (1, 8)))
+        self.relu4 = nn.ReLU()
+        self.relu5 = nn.ReLU()
+        self.transpose = P.Transpose()
+        self.matmul1 = P.MatMul()
+        self.matmul2 = P.MatMul()
+        self.matmul_cell = MatMulCell()
+        self.matmul_cell.recompute()
+        self.fc1.recompute(mp_comm_recompute=False)
+        self.fc2.recompute(mp_comm_recompute=False)
+        self.fc3.recompute(mp_comm_recompute=False)
+        self.matmul_cell.recompute(mp_comm_recompute=False)
+
+    def construct(self, x):
+        x = self.matmul_cell(x)
+        q = self.fc1(x)
+        k = self.fc2(x)
+        v = self.fc3(x)
+        k = self.transpose(k, (1, 0))
+        c = self.relu4(self.matmul1(q, k))
+        s = self.relu5(self.matmul2(c, v))
+        s = self.fc4(s)
+        return s
+
+
+def test_dmnet_train_step():
+    context.reset_auto_parallel_context()
+    _Context().set_backend_policy("vm")
+    context.set_context(mode=context.GRAPH_MODE)
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8)
+    input_ = Tensor(np.ones([64, 128]).astype(np.float32) * 0.01)
+    label = Tensor(np.zeros([32, 768]).astype(np.float32))
+    net = train_step_with_loss_warp(DenseMutMulNet())
+    net.set_auto_parallel()
+    net.set_train()
+    _executor.compile(net, input_, label)
+    _Context().set_backend_policy("ge")
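To see what the mp_comm_recompute keyword actually records on the Python side, the primitive attributes can be inspected after the network above is constructed. A speculative sketch: it assumes Primitive.add_prim_attr stores the value in the primitive's attrs dictionary, which is where the parallel pass reads RECOMPUTE_COMM_OP from; the expected output is illustrative, not verified here:

# Build the test network above; its __init__ already calls
# fc1.recompute(mp_comm_recompute=False), so the flag should be present
# on every primitive owned by fc1, including its MatMul.
net = DenseMutMulNet()
print(net.fc1.matmul.attrs.get("recompute_comm_op"))   # expected: False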