From 188d39da83a49a2cc3f45fd741c1e283add93ddb Mon Sep 17 00:00:00 2001
From: yao_yf
Date: Sat, 6 Nov 2021 16:12:23 +0800
Subject: [PATCH] slice_activation_in_recompute

Slice the activations of recomputed cells to reduce the memory they hold
between the forward and backward passes.
---
 .../optimizer/irpass/recompute_prepare.h      |   8 +
 .../slice_activation_in_recompute.cc          | 309 ++++++++++++++++++
 .../optimizer/slice_activation_in_recompute.h |  28 ++
 .../tensor_layout/construct_operator.h        |   2 +-
 mindspore/ccsrc/pipeline/jit/pass.cc          |  10 +-
 mindspore/ccsrc/utils/utils.h                 |   1 +
 mindspore/nn/cell.py                          |  14 +-
 mindspore/parallel/nn/op_parallel_config.py   |   1 -
 mindspore/parallel/nn/transformer.py          |  99 +++++-
 .../parallel/test_comm_not_recompute.py       |  51 ++-
 .../parallel/test_parallel_transformer.py     |   4 +-
 11 files changed, 501 insertions(+), 26 deletions(-)
 create mode 100644 mindspore/ccsrc/frontend/optimizer/slice_activation_in_recompute.cc
 create mode 100644 mindspore/ccsrc/frontend/optimizer/slice_activation_in_recompute.h

diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/recompute_prepare.h b/mindspore/ccsrc/frontend/optimizer/irpass/recompute_prepare.h
index 0027f688124..88e830a4dc1 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass/recompute_prepare.h
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/recompute_prepare.h
@@ -21,6 +21,7 @@
 #include "frontend/optimizer/irpass.h"
 #include "frontend/optimizer/optimizer.h"
 #include "frontend/optimizer/anf_visitor.h"
+#include "frontend/parallel/context.h"
 #include "ir/func_graph.h"
 
 namespace mindspore {
@@ -47,6 +48,13 @@ class SetCellOutputNoRecompute : public AnfVisitor {
     for (const auto &real_output : real_outputs) {
       // Set the attr of cnode in case of shared primitives.
       real_output->AddAttr(kAttrRecompute, MakeValue(false));
+      if (parallel::ParallelContext::GetInstance()->parallel_mode() == parallel::SEMI_AUTO_PARALLEL ||
+          parallel::ParallelContext::GetInstance()->parallel_mode() == parallel::AUTO_PARALLEL) {
+        auto prim = GetCNodePrimitive(real_output);
+        if (prim->HasAttr(kAttrSliceActivation) && GetValue<bool>(prim->GetAttr(kAttrSliceActivation))) {
+          real_output->AddAttr(kAttrSliceActivation, MakeValue(true));
+        }
+      }
     }
   }
   fg->erase_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE);
diff --git a/mindspore/ccsrc/frontend/optimizer/slice_activation_in_recompute.cc b/mindspore/ccsrc/frontend/optimizer/slice_activation_in_recompute.cc
new file mode 100644
index 00000000000..11463a30952
--- /dev/null
+++ b/mindspore/ccsrc/frontend/optimizer/slice_activation_in_recompute.cc
@@ -0,0 +1,309 @@
+/**
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/optimizer/slice_activation_in_recompute.h"
+#include <algorithm>
+#include <list>
+#include <memory>
+#include <queue>
+#include <string>
+#include <utility>
+#include <vector>
+#include "mindspore/core/base/core_ops.h"
+#include "utils/utils.h"
+#include "frontend/parallel/tensor_layout/construct_operator.h"
+#include "frontend/parallel/step_parallel.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+constexpr auto kGradientsFlag = "Gradients";
+const int64_t max_loop_size = 100;
+bool IsBpropNode(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  if (!node->isa<CNode>()) {
+    return false;
+  }
+  return node->fullname_with_scope().find(kGradientsFlag) == 0;
+}
+
+CNodePtr CreateStridedSliceCNode(const parallel::Shape &begin, const parallel::Shape &end,
+                                 const parallel::Shape &strides, const AnfNodePtr &node) {
+  auto slice_op = parallel::CreateStridedSliceOp(0, begin, end, strides);
+  auto slice_input = parallel::CreateInput(slice_op, node, parallel::STRIDEDSLICE);
+  auto func_graph = node->func_graph();
+  CNodePtr new_node = func_graph->NewCNode(slice_input);
+  return new_node;
+}
+
+CNodePtr CreateAllGatherCNode(const AnfNodePtr &node, std::string group) {
+  auto op = parallel::CreateAllGatherOp(group);
+  auto allgather_input = parallel::CreateInput(op, node, "recompute_slice_allgather");
+  auto func_graph = node->func_graph();
+  CNodePtr new_node = func_graph->NewCNode(allgather_input);
+  return new_node;
+}
+
+parallel::Group InferRepeatedRankList(const CNodePtr &cnode) {
+  OperatorInfoPtr operator_info = cnode->user_data<parallel::OperatorInfo>();
+  std::vector<parallel::TensorInfo> output_info = operator_info->outputs_tensor_info();
+  if (output_info.size() != 1) {
+    MS_LOG(EXCEPTION) << "The output_info size is wrong, node is " << cnode->DebugString();
+  }
+  auto tensor_layout = output_info[0].tensor_layout();
+  auto tensor_map = tensor_layout.origin_tensor_map();
+  std::vector<parallel::Group> groups;
+  operator_info->CreateGroupByTensorMap(tensor_map.array(), &groups);
+  return groups[0];
+}
+
+bool IsDuplicateNode(const AnfNodePtr &node) {
+  if (!node->isa<CNode>()) {
+    return false;
+  }
+  if (node->cast<CNodePtr>()->HasAttr(kAttrDuplicated)) {
+    return true;
+  }
+  if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad)) {
+    auto manager = node->func_graph()->manager();
+    auto node_users = manager->node_users()[node];
+    return std::any_of(node_users.begin(), node_users.end(),
+                       [](auto node_user) { return IsDuplicateNode(node_user.first); });
+  }
+  return false;
+}
+
+CNodePtr BpropInput(const CNodePtr cnode) {
+  for (size_t i = 1; i < cnode->inputs().size(); ++i) {
+    if (IsBpropNode(cnode->input(i))) {
+      return cnode->input(i)->cast<CNodePtr>();
+    }
+  }
+  return nullptr;
+}
+
+std::vector<CNodePtr> NextBpropNodes(const CNodePtr cnode) {
+  std::vector<CNodePtr> next_bprop_nodes;
+  auto manager = cnode->func_graph()->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  auto cnode_users = manager->node_users()[cnode];
+  std::queue<CNodePtr> cnode_queue;
+  for (auto &node : cnode_users) {
+    if (!node.first->isa<CNode>()) {
+      continue;
+    }
+    cnode_queue.push(node.first->cast<CNodePtr>());
+  }
+  int64_t loop_size = 0;
+  while (!cnode_queue.empty() && loop_size < max_loop_size) {
+    auto cur_cnode = cnode_queue.front();
+    cnode_queue.pop();
+    if (IsBpropNode(cur_cnode)) {
+      next_bprop_nodes.push_back(cur_cnode);
+      continue;
+    }
+    auto cur_cnode_users = manager->node_users()[cur_cnode];
+    for (auto &node : cur_cnode_users) {
+      if (!node.first->isa<CNode>()) {
+        continue;
+      }
+      cnode_queue.push(node.first->cast<CNodePtr>());
+    }
+    loop_size++;
+  }
+  if (loop_size < max_loop_size) {
+    return next_bprop_nodes;
+  }
+  return {};
+}
+
+void GroupingNextNodes(const CNodePtr &node, std::vector<std::pair<std::shared_ptr<AnfNode>, int>> *duplicate_users,
+                       std::vector<std::pair<std::shared_ptr<AnfNode>, int>> *forward_users) {
+  auto manager = node->func_graph()->manager();
+  auto root_node = node;
+  auto node_users = manager->node_users()[root_node];
+  for (auto node_user : node_users) {
+    if (IsDuplicateNode(node_user.first)) {
+      duplicate_users->push_back(node_user);
+    } else {
+      forward_users->push_back(node_user);
+    }
+  }
+}
+
+void InsertSliceAllGatherNode(const std::vector<std::pair<std::shared_ptr<AnfNode>, int>> &node_users,
+                              const std::pair<std::shared_ptr<AnfNode>, int> &forward_node_user,
+                              const std::shared_ptr<CNode> &node, std::vector<CNodePtr> *slice_allgathers,
+                              int64_t recompute_order_id) {
+  auto manager = node->func_graph()->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  auto output_shape = node->abstract()->BuildShape();
+  std::vector<int64_t> out_shape_element = output_shape->cast<abstract::ShapePtr>()->shape();
+  if (out_shape_element.empty()) {
+    return;
+  }
+  int64_t global_rank_id = parallel::g_device_manager->global_rank();
+  int64_t stage_num = parallel::g_device_manager->stage_num();
+  int64_t device_num = parallel::g_device_manager->DeviceNum();
+  int64_t stage_device_num = device_num / stage_num;
+  int64_t local_rank_id = global_rank_id % stage_device_num;
+  auto group = InferRepeatedRankList(node);
+  if (out_shape_element[0] % group.GetDevNum() != 0) {
+    MS_LOG(WARNING) << "The first dim of the output shape: " << out_shape_element[0]
+                    << " is not divisible by the repeated size: " << group.GetDevNum()
+                    << ". The slice will not be applied to this node: " << node->DebugString();
+    return;
+  }
+  int64_t group_device_num = group.GetDevNum();
+  std::vector<int64_t> slice_begin(out_shape_element.size(), 0);
+  slice_begin[0] = (local_rank_id % group_device_num) * (out_shape_element[0] / group_device_num);
+  std::vector<int64_t> slice_end = out_shape_element;
+  slice_end[0] = (local_rank_id % group_device_num + 1) * (out_shape_element[0] / group_device_num);
+  std::vector<int64_t> slice_strides(out_shape_element.size(), 1);
+  CNodePtr slice_cnode = CreateStridedSliceCNode(slice_begin, slice_end, slice_strides, node);
+  slice_cnode->set_abstract(node->abstract()->Clone());
+  std::vector<int64_t> slice_shape = out_shape_element;
+  slice_shape[0] = out_shape_element[0] / group_device_num;
+  std::shared_ptr<abstract::Shape> slice_base_shape = std::make_shared<abstract::Shape>(slice_shape);
+  slice_cnode->abstract()->set_shape(slice_base_shape);
+  for (auto &node_user : node_users) {
+    manager->SetEdge(node_user.first, node_user.second, slice_cnode);
+  }
+
+  CNodePtr allgather_cnode = CreateAllGatherCNode(slice_cnode, group.name());
+  allgather_cnode->set_abstract(node->abstract()->Clone());
+  allgather_cnode->AddAttr("recompute_order", MakeValue(recompute_order_id));
+  if (node->HasPrimalAttr(parallel::MICRO)) {
+    allgather_cnode->AddPrimalAttr(parallel::MICRO, node->GetPrimalAttr(parallel::MICRO));
+  }
+  manager->Replace(slice_cnode, allgather_cnode);
+  slice_allgathers->push_back(allgather_cnode);
+
+  std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimDepend), forward_node_user.first, slice_cnode};
+  auto depend_node = node->func_graph()->NewCNode(depend_inputs);
+  MS_EXCEPTION_IF_NULL(depend_node);
+  depend_node->set_abstract(forward_node_user.first->abstract()->Clone());
+  depend_node->AddAttr("slice_forward_depend", MakeValue(true));
+  manager->Replace(forward_node_user.first, depend_node);
+}
+
+void InsertAllGatherDepend(const FuncGraphPtr &graph, const std::vector<CNodePtr> &slice_allgathers) {
+  auto manager = graph->manager();
+  auto last_allgather = slice_allgathers.back();
+  auto next_cnodes = NextBpropNodes(last_allgather);
+  CNodePtr next_cnode = nullptr;
+  if (!next_cnodes.empty()) {
+    MS_LOG(INFO) << "The next_cnodes is not empty.";
+    std::list<CNodePtr> orders = graph->GetOrderedCnodes();
+    for (auto &cnode : orders) {
+      if (std::find(next_cnodes.begin(), next_cnodes.end(), cnode) != next_cnodes.end()) {
+        next_cnode = cnode;
+        break;
+      }
+    }
+  }
+
+  for (size_t i = slice_allgathers.size() - 1; i > 0; --i) {
+    std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimDepend), slice_allgathers[i - 1]->input(1),
+                                          slice_allgathers[i]};
+    auto depend_node = graph->NewCNode(depend_inputs);
+    MS_EXCEPTION_IF_NULL(depend_node);
+    depend_node->set_abstract(slice_allgathers[i - 1]->input(1)->abstract()->Clone());
+    depend_node->AddAttr("slice_allgather_depend", MakeValue(i));
+    manager->SetEdge(slice_allgathers[i - 1], 1, depend_node);
+  }
+
+  if (next_cnode == nullptr) {
+    MS_LOG(WARNING) << "Cannot find the bprop node within 100 iterations.";
+    return;
+  }
+  auto allgather_depend_node = BpropInput(next_cnode);
+  if (allgather_depend_node == nullptr) {
+    MS_LOG(WARNING) << "Cannot find the bprop input for the allgather to depend on.";
+    return;
+  }
+  MS_LOG(INFO) << "Insert depend for last slice allgather. The depend node is: "
+               << allgather_depend_node->DebugString();
+  std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimDepend), last_allgather->input(1),
+                                        allgather_depend_node};
+  auto depend_node = graph->NewCNode(depend_inputs);
+  MS_EXCEPTION_IF_NULL(depend_node);
+  depend_node->set_abstract(last_allgather->input(1)->abstract()->Clone());
+  depend_node->AddAttr("last_slice_allgather_depend", MakeValue(true));
+  manager->SetEdge(last_allgather, 1, depend_node);
+}
+}  // namespace
+
+void SliceRecomputedActivationNodes(const FuncGraphPtr &graph) {
+  if (parallel::ParallelContext::GetInstance()->parallel_mode() != parallel::SEMI_AUTO_PARALLEL &&
+      parallel::ParallelContext::GetInstance()->parallel_mode() != parallel::AUTO_PARALLEL) {
+    return;
+  }
+  MS_EXCEPTION_IF_NULL(graph);
+  auto manager = graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  std::list<CNodePtr> orders = graph->GetOrderedCnodes();
+  std::vector<CNodePtr> origin_nodes_topological(orders.begin(), orders.end());
+  std::vector<CNodePtr> slice_allgathers;
+  int64_t recompute_order_id = 0;
+  for (auto &node : origin_nodes_topological) {
+    if (!node->HasAttr(kAttrSliceActivation) || IsPrimitiveCNode(node, prim::kPrimTranspose) ||
+        !node->has_user_data<parallel::OperatorInfo>()) {
+      continue;
+    }
+    std::vector<std::pair<std::shared_ptr<AnfNode>, int>> duplicate_users;
+    std::vector<std::pair<std::shared_ptr<AnfNode>, int>> forward_users;
+    GroupingNextNodes(node, &duplicate_users, &forward_users);
+    if (duplicate_users.empty() || forward_users.empty()) {
+      continue;
+    }
+    InsertSliceAllGatherNode(duplicate_users, forward_users[0], node, &slice_allgathers, recompute_order_id);
+    recompute_order_id++;
+  }
+  if (slice_allgathers.empty()) {
+    return;
+  }
+  if (parallel::ParallelContext::GetInstance()->pipeline_stage_split_num() > 1) {
+    int64_t current_micro = -1;
+    std::vector<CNodePtr> stage_slice_allgathers;
+    for (auto &slice_allgather_node : slice_allgathers) {
+      if (!slice_allgather_node->HasPrimalAttr(parallel::MICRO)) {
+        MS_LOG(EXCEPTION) << "In pipeline parallel mode, cannot find the 'micro' attribute in the node.";
+      }
+      int64_t micro = GetValue<int64_t>(slice_allgather_node->GetPrimalAttr(parallel::MICRO));
+      if (micro > current_micro) {
+        if (current_micro != -1) {
+          MS_LOG(INFO) << "Insert allgather depends, micro is: " << current_micro;
+          InsertAllGatherDepend(graph, stage_slice_allgathers);
+        }
+        stage_slice_allgathers.clear();
+        stage_slice_allgathers.push_back(slice_allgather_node);
+        current_micro = micro;
+      } else if (micro == current_micro) {
+        stage_slice_allgathers.push_back(slice_allgather_node);
+      } else if (current_micro != -1) {
+        MS_LOG(EXCEPTION) << "The micro number does not match the execution order in pipeline parallel.";
+      }
+    }
+    MS_LOG(INFO) << "Insert last stage allgather depends, micro is: " << current_micro;
+    InsertAllGatherDepend(graph, stage_slice_allgathers);
+  } else {
+    InsertAllGatherDepend(graph, slice_allgathers);
+  }
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/optimizer/slice_activation_in_recompute.h b/mindspore/ccsrc/frontend/optimizer/slice_activation_in_recompute.h
new file mode 100644
index 00000000000..4cf86d4d7af
--- /dev/null
+++ b/mindspore/ccsrc/frontend/optimizer/slice_activation_in_recompute.h
@@ -0,0 +1,28 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_SLICE_ACTIVATION_IN_RECOMPUTE_H_
+#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_SLICE_ACTIVATION_IN_RECOMPUTE_H_
+
+#include "ir/anf.h"
+
+namespace mindspore {
+namespace opt {
+// Slice the activations of recomputed nodes and insert AllGather nodes to restore them before reuse.
+void SliceRecomputedActivationNodes(const FuncGraphPtr &graph);
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_SLICE_ACTIVATION_IN_RECOMPUTE_H_
diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.h
index 70126960790..9d913c02d58 100644
--- a/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.h
+++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.h
@@ -28,7 +28,7 @@
 namespace mindspore {
 namespace parallel {
 using Args = std::vector<std::int64_t>;
-
+Operator CreateStridedSliceOp(int64_t value, const Shape &begin, const Shape &end, const Shape &strides);
 class ConstructOperator {
  public:
   const int64_t DEFAULT = 0;
diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc
index 1d9e4317ed1..3275f4c6504 100644
--- a/mindspore/ccsrc/pipeline/jit/pass.cc
+++ b/mindspore/ccsrc/pipeline/jit/pass.cc
@@ -40,6 +40,7 @@
 #include "frontend/parallel/cache_embedding/cache_embedding.h"
 #include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
 #include "frontend/optimizer/recompute.h"
+#include "frontend/optimizer/slice_activation_in_recompute.h"
 #include "utils/log_adapter.h"
 #include "pipeline/jit/pipeline_split.h"
 #include "pipeline/pynative/pynative_execute.h"
@@ -594,6 +595,12 @@ bool AddRecomputationPass(const ResourcePtr &res) {
   return true;
 }
 
+bool SliceRecomputeActivationPass(const ResourcePtr &res) {
+  MS_EXCEPTION_IF_NULL(res);
+  opt::SliceRecomputedActivationNodes(res->func_graph());
+  return true;
+}
+
 bool AddCacheEmbeddingPass(const ResourcePtr &res) {
   MS_EXCEPTION_IF_NULL(res);
 #if ((defined ENABLE_CPU) && (!defined _WIN32))
@@ -734,7 +741,8 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru
{{"simplify_data_structures", SimplifyDataStru {"tuple_transform", OptPassTransformGraphGroup}, {"add_cache_embedding", AddCacheEmbeddingPass}, {"add_recomputation", AddRecomputationPass}, - {"cse_after_recomputation", OptAfterRecomputeGroup}}; + {"cse_after_recomputation", OptAfterRecomputeGroup}, + {"slice_recompute_activation", SliceRecomputeActivationPass}}; std::vector kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, {"opt_a", OptPassAGroup}, diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index a52c581e36b..e0397f07cc1 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -471,6 +471,7 @@ constexpr auto kAttrPadding = "padding"; constexpr auto kAttrNonTask = "non_task"; constexpr auto kAttrIsGrad = "is_grad"; constexpr auto kAttrRecompute = "recompute"; +constexpr auto kAttrSliceActivation = "slice_activation"; constexpr auto kAttrNeedCseAfterRecompute = "need_cse_after_recompute"; constexpr auto kAttrParallelDimInfo = "parallel_dim_info"; constexpr auto kAttrParallelFusionType = "parallel_fusion_type"; diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 8b047577bad..ae9e00bad71 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -1499,6 +1499,16 @@ class Cell(Cell_): for param in self.trainable_params(): param.parallel_optimizer_comm_recompute = parallel_optimizer_comm_recompute + def _recompute_slice_activation(self, slice_activation=False): + """ + Slice the cell output which would remains in memory. + """ + for _, value in self._primitives.items(): + if value: + value.add_prim_attr("slice_activation", slice_activation) + for cell in self.cells(): + cell._recompute_slice_activation(slice_activation) + def _recompute(self, mode=True, output_recompute=False): """ Set the cell recomputed. @@ -1554,9 +1564,11 @@ class Cell(Cell_): raise ValueError("Currently, the communication operator allgathers introduced by optimizer shard " "are not support recomputation in pipeline parallel.") self._parallel_optimizer_comm_recompute(kwargs['parallel_optimizer_comm_recompute']) + if 'recompute_slice_activation' in kwargs.keys(): + self._recompute_slice_activation(kwargs['recompute_slice_activation']) for key, _ in kwargs.items(): - if key not in ('mp_comm_recompute', 'parallel_optimizer_comm_recompute'): + if key not in ('mp_comm_recompute', 'parallel_optimizer_comm_recompute', 'recompute_slice_activation'): raise ValueError("Recompute keyword %s is not recognized!" 
                                  % key)
 
     def infer_param_pipeline_stage(self):
diff --git a/mindspore/parallel/nn/op_parallel_config.py b/mindspore/parallel/nn/op_parallel_config.py
index 98930939b79..997543910e4 100644
--- a/mindspore/parallel/nn/op_parallel_config.py
+++ b/mindspore/parallel/nn/op_parallel_config.py
@@ -78,7 +78,6 @@ class OpParallelConfig(_Config):
         Validator.check_positive_int(value, "model_parallel")
         self._model_parallel = value
 
-
 class _PipeLineConfig(_Config):
     r"""
         PPConfig for the setting data parallel, model parallel
diff --git a/mindspore/parallel/nn/transformer.py b/mindspore/parallel/nn/transformer.py
index 9833e9cfa50..d7cb0b5312a 100644
--- a/mindspore/parallel/nn/transformer.py
+++ b/mindspore/parallel/nn/transformer.py
@@ -48,7 +48,8 @@ __all__ = [
     "TransformerDecoderLayer",
     "Transformer",
     "TransformerOpParallelConfig",
-    "EmbeddingOpParallelConfig"]
+    "EmbeddingOpParallelConfig",
+    "TransformerRecomputeConfig"]
 
 
 class EmbeddingOpParallelConfig(_Config):
@@ -112,6 +113,76 @@ class EmbeddingOpParallelConfig(_Config):
         return self._dp_mp_config
 
 
+class TransformerRecomputeConfig(_Config):
+    r"""
+    TransformerRecomputeConfig for setting the recompute attributes of the encoder/decoder layers.
+
+    Args:
+        recompute (bool): Enable recomputation of the transformer block or not. Default: False.
+        parallel_optimizer_comm_recompute (bool): Specifies whether the communication operator allgathers
+            introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode.
+            Default: False.
+        mp_comm_recompute (bool): Specifies whether the model parallel communication operators
+            in the cell are recomputed in auto parallel or semi auto parallel mode. Default: True.
+        recompute_slice_activation (bool): Slice the cell output which remains in memory. Default: False.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> config=TransformerRecomputeConfig(recompute=True, parallel_optimizer_comm_recompute=True,
+        ...                                   mp_comm_recompute=True, recompute_slice_activation=True)
+    """
+
+    def __init__(self, recompute=False, parallel_optimizer_comm_recompute=False,
+                 mp_comm_recompute=True, recompute_slice_activation=False):
+        Validator.check_bool(recompute, "recompute")
+        Validator.check_bool(parallel_optimizer_comm_recompute, "parallel_optimizer_comm_recompute")
+        Validator.check_bool(mp_comm_recompute, "mp_comm_recompute")
+        Validator.check_bool(recompute_slice_activation, "recompute_slice_activation")
+        self._recompute = recompute
+        self._parallel_optimizer_comm_recompute = parallel_optimizer_comm_recompute
+        self._mp_comm_recompute = mp_comm_recompute
+        self._recompute_slice_activation = recompute_slice_activation
+
+    @property
+    def recompute(self):
+        return self._recompute
+
+    @recompute.setter
+    def recompute(self, value):
+        Validator.check_bool(value, "recompute")
+        self._recompute = value
+
+    @property
+    def parallel_optimizer_comm_recompute(self):
+        return self._parallel_optimizer_comm_recompute
+
+    @parallel_optimizer_comm_recompute.setter
+    def parallel_optimizer_comm_recompute(self, value):
+        Validator.check_bool(value, "parallel_optimizer_comm_recompute")
+        self._parallel_optimizer_comm_recompute = value
+
+    @property
+    def mp_comm_recompute(self):
+        return self._mp_comm_recompute
+
+    @mp_comm_recompute.setter
+    def mp_comm_recompute(self, value):
+        Validator.check_bool(value, "mp_comm_recompute")
+        self._mp_comm_recompute = value
+
+    @property
+    def recompute_slice_activation(self):
+        return self._recompute_slice_activation
+
+    @recompute_slice_activation.setter
+    def recompute_slice_activation(self, value):
+        Validator.check_bool(value, "recompute_slice_activation")
+        self._recompute_slice_activation = value
+
+
+_DEFAULT_TRANSFORMER_RECOMPUTE_CONFIG = TransformerRecomputeConfig()
+
 class TransformerOpParallelConfig(_Config):
     r"""
     TransformerOpParallelConfig for the setting global data parallel, model parallel and fusion group.
@@ -131,17 +202,21 @@ class TransformerOpParallelConfig(_Config):
         micro_batch_num (int): The microe size of the batches for the pipeline training. Default: 1.
         optimizer_shard (bool): Whether to enable optimizer shard. Default False.
         gradient_aggregation_group (int): The fusion group size of the optimizer state sharding. Default: 4.
-        recompute (bool): Enable recomputation of the transformer block or not. Default: False.
+        recompute (Union[TransformerRecomputeConfig, bool]): The configuration of recomputation for
+            the transformer block. Default: an instance of TransformerRecomputeConfig with default values.
         vocab_emb_dp (bool): Shard embedding in model parallel or data parallel. Default: True.
 
     Supported Platforms:
         ``Ascend`` ``GPU``
 
     Examples:
-        >>> config=TransformerOpParallelConfig(data_parallel=1, model_parallel=1)
+        >>> recompute_config=TransformerRecomputeConfig(recompute=True, parallel_optimizer_comm_recompute=True,
+        ...                                             mp_comm_recompute=True, recompute_slice_activation=True)
+        >>> config=TransformerOpParallelConfig(data_parallel=1, model_parallel=1, recompute=recompute_config)
     """
 
-    def __init__(self, data_parallel=1, model_parallel=1, pipeline_stage=1, micro_batch_num=1, recompute=False,
+    def __init__(self, data_parallel=1, model_parallel=1, pipeline_stage=1, micro_batch_num=1,
+                 recompute=_DEFAULT_TRANSFORMER_RECOMPUTE_CONFIG,
                  optimizer_shard=False, gradient_aggregation_group=4, vocab_emb_dp=True):
         self.recompute = recompute
         self.optimizer_shard = optimizer_shard
@@ -156,7 +231,10 @@ class TransformerOpParallelConfig(_Config):
 
     @recompute.setter
     def recompute(self, value):
-        Validator.check_bool(value, "recompute")
+        if not isinstance(value, TransformerRecomputeConfig) and not isinstance(value, bool):
+            raise TypeError(f"recompute should be a TransformerRecomputeConfig or a bool, but got {type(value).__name__}.")
+        if isinstance(value, bool):
+            logger.warning("TransformerRecomputeConfig is recommended as the recompute configuration type.")
         self._recompute = value
 
     @property
@@ -1726,8 +1804,15 @@ def _get_lambda_func(total_layer=None):
 
             dis = max(int(layers / parallel_config.gradient_aggregation_group), 1)
             network.set_comm_fusion(int((layer_id + offset) / dis) + 1)
         # Used for enabling recomputation of the block
-        if parallel_config.recompute:
-            network.recompute()
+        if isinstance(parallel_config.recompute, bool):
+            if parallel_config.recompute:
+                network.recompute()
+        else:
+            if parallel_config.recompute.recompute:
+                network.recompute(
+                    parallel_optimizer_comm_recompute=parallel_config.recompute.parallel_optimizer_comm_recompute,
+                    mp_comm_recompute=parallel_config.recompute.mp_comm_recompute,
+                    recompute_slice_activation=parallel_config.recompute.recompute_slice_activation)
 
     return _set_parallel_configure_for_layer
diff --git a/tests/ut/python/parallel/test_comm_not_recompute.py b/tests/ut/python/parallel/test_comm_not_recompute.py
index d7a1c4891df..96c8f902db8 100644
--- a/tests/ut/python/parallel/test_comm_not_recompute.py
+++ b/tests/ut/python/parallel/test_comm_not_recompute.py
@@ -27,8 +27,8 @@ class MatMulCell(nn.Cell):
     def __init__(self):
         super(MatMulCell, self).__init__()
         self.reshape = P.Reshape()
-        self.matmul0 = P.MatMul()
-        self.weight = Parameter(initializer("ones", [128, 64], ms.float32), name="weight")
+        self.matmul0 = P.MatMul(transpose_b=True)
+        self.weight = Parameter(initializer("ones", [64, 128], ms.float32), name="weight")
         self.relu = P.ReLU().shard(((1, 8),))
 
     def construct(self, x):
         x = self.matmul0(x, self.weight)
@@ -37,25 +37,26 @@ class MatMulCell(nn.Cell):
         return x
 
 
 class DenseMutMulNet(nn.Cell):
-    def __init__(self):
+    def __init__(self, mp_comm_recompute=True, recompute_slice_activation=False):
         super(DenseMutMulNet, self).__init__()
         self.fc1 = nn.Dense(128, 768, activation='relu')
         self.fc2 = nn.Dense(128, 768, activation='relu')
         self.fc3 = nn.Dense(128, 768, activation='relu')
         self.fc4 = nn.Dense(768, 768, activation='relu')
-        self.fc1.matmul.shard(((1, 1), (1, 8)))
-        self.fc2.matmul.shard(((1, 1), (1, 8)))
-        self.fc3.matmul.shard(((1, 1), (1, 8)))
+        self.fc1.matmul.shard(((1, 1), (8, 1)))
+        self.fc2.matmul.shard(((1, 1), (8, 1)))
+        self.fc3.matmul.shard(((1, 1), (8, 1)))
         self.relu4 = nn.ReLU()
         self.relu5 = nn.ReLU()
         self.transpose = P.Transpose()
         self.matmul1 = P.MatMul()
         self.matmul2 = P.MatMul()
         self.matmul_cell = MatMulCell()
-        self.fc1.recompute(mp_comm_recompute=False)
-        self.fc2.recompute(mp_comm_recompute=False)
-        self.fc3.recompute(mp_comm_recompute=False)
-        self.matmul_cell.recompute(mp_comm_recompute=False)
+        self.fc1.recompute(mp_comm_recompute=mp_comm_recompute, recompute_slice_activation=recompute_slice_activation)
+        self.fc2.recompute(mp_comm_recompute=mp_comm_recompute, recompute_slice_activation=recompute_slice_activation)
+        self.fc3.recompute(mp_comm_recompute=mp_comm_recompute, recompute_slice_activation=recompute_slice_activation)
+        self.matmul_cell.recompute(mp_comm_recompute=mp_comm_recompute,
+                                   recompute_slice_activation=recompute_slice_activation)
 
     def construct(self, x):
         x = self.matmul_cell(x)
@@ -68,16 +69,40 @@ class DenseMutMulNet(nn.Cell):
         s = self.fc4(s)
         return s
 
-
-def test_dmnet_train_step():
+def compile_net(mp_comm_recompute, recompute_slice_activation):
     context.reset_auto_parallel_context()
     _Context().set_backend_policy("vm")
     context.set_context(mode=context.GRAPH_MODE)
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8)
     input_ = Tensor(np.ones([64, 128]).astype(np.float32) * 0.01)
     label = Tensor(np.zeros([32, 768]).astype(np.float32))
-    net = train_step_with_loss_warp(DenseMutMulNet())
+    net = train_step_with_loss_warp(DenseMutMulNet(mp_comm_recompute=mp_comm_recompute,
+                                                   recompute_slice_activation=recompute_slice_activation))
     net.set_auto_parallel()
     net.set_train()
     _cell_graph_executor.compile(net, input_, label)
     _Context().set_backend_policy("ge")
+
+def test_dmnet_train_step_mp_recompute():
+    """
+    Feature: test recompute interface.
+    Description: test model parallel communication not recompute.
+    Expectation: compile without error.
+    """
+    compile_net(False, False)
+
+def test_dmnet_train_step_recompute_activation_slice():
+    """
+    Feature: test recompute interface.
+    Description: test slicing recompute cell output.
+    Expectation: compile without error.
+    """
+    compile_net(True, True)
+
+def test_dmnet_train_step_mp_recompute_recompute_activation_slice():
+    """
+    Feature: test recompute interface.
+    Description: test model parallel communication not recompute and slicing recompute cell output.
+    Expectation: compile without error.
+    """
+    compile_net(False, True)
diff --git a/tests/ut/python/parallel/test_parallel_transformer.py b/tests/ut/python/parallel/test_parallel_transformer.py
index ad34d68c652..6a03b0559c5 100644
--- a/tests/ut/python/parallel/test_parallel_transformer.py
+++ b/tests/ut/python/parallel/test_parallel_transformer.py
@@ -652,9 +652,9 @@ def test_transformer_parallel_config():
     with pytest.raises(TypeError):
         parallel_test_config.recompute = 1
 
-    parallel_test_config.recompute = False
+    parallel_test_config.recompute.recompute = False
 
-    assert not parallel_test_config.recompute
+    assert not parallel_test_config.recompute.recompute
 
 
 def test_parallel_config():
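
Usage note: the new configuration surface can be exercised from Python as below. This is a minimal sketch, not part of the applied diff: it assumes an 8-card semi-auto-parallel environment, and the Transformer constructor arguments and the encoder.blocks attribute layout follow this repository's current signatures, which may differ in other versions.

# Minimal usage sketch of the recompute configuration introduced in this patch.
# Assumes eight devices in semi-auto-parallel mode; adjust device_num and the
# Transformer arguments (assumed from this repository) to your environment.
from mindspore import context
from mindspore.parallel.nn import (Transformer, TransformerOpParallelConfig,
                                   TransformerRecomputeConfig)

context.set_context(mode=context.GRAPH_MODE)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8)

# Recompute the transformer blocks, and keep only one shard of each saved
# activation per repeated rank; the new pass inserts the StridedSlice/AllGather
# pairs that restore the full tensor right before the backward pass needs it.
recompute_config = TransformerRecomputeConfig(recompute=True,
                                              mp_comm_recompute=True,
                                              recompute_slice_activation=True)
parallel_config = TransformerOpParallelConfig(data_parallel=2, model_parallel=4,
                                              recompute=recompute_config)

net = Transformer(batch_size=4, src_seq_length=16, tgt_seq_length=16,
                  hidden_size=64, ffn_hidden_size=256,
                  encoder_layers=1, decoder_layers=1,
                  parallel_config=parallel_config)

# The same switch is also exposed per cell through the extended recompute
# keyword arguments added to Cell.recompute() in this patch.
net.encoder.blocks[0].recompute(recompute_slice_activation=True)

Slicing trades one extra AllGather per recomputed activation for holding only 1/N of the tensor between the forward and backward passes, which is worthwhile when the activation is replicated across N repeated ranks.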