From facb0995f4f9c69e8adbdddc4f822aea0b5bdc1b Mon Sep 17 00:00:00 2001 From: yao_yf Date: Fri, 26 Nov 2021 09:22:35 +0800 Subject: [PATCH] slice recompute state fix --- .../optimizer/slice_activation_in_recompute.cc | 10 +++++++--- .../parallel_strategy_checkpoint.cc | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/mindspore/ccsrc/frontend/optimizer/slice_activation_in_recompute.cc b/mindspore/ccsrc/frontend/optimizer/slice_activation_in_recompute.cc index 11463a30952..231e1b91d89 100644 --- a/mindspore/ccsrc/frontend/optimizer/slice_activation_in_recompute.cc +++ b/mindspore/ccsrc/frontend/optimizer/slice_activation_in_recompute.cc @@ -57,7 +57,7 @@ CNodePtr CreateAllGatherCNode(const AnfNodePtr &node, std::string group) { return new_node; } -parallel::Group InferRepeatedRankList(const CNodePtr &cnode) { +std::vector InferRepeatedRankList(const CNodePtr &cnode) { OperatorInfoPtr operator_info = cnode->user_data(); std::vector output_info = operator_info->outputs_tensor_info(); if (output_info.size() != 1) { @@ -67,7 +67,7 @@ parallel::Group InferRepeatedRankList(const CNodePtr &cnode) { auto tensor_map = tensor_layout.origin_tensor_map(); std::vector groups; operator_info->CreateGroupByTensorMap(tensor_map.array(), &groups); - return groups[0]; + return groups; } bool IsDuplicateNode(const AnfNodePtr &node) { @@ -160,7 +160,11 @@ void InsertSliceAllGatherNode(const std::vectorDeviceNum(); int64_t stage_device_num = device_num / stage_num; int64_t local_rank_id = global_rank_id % stage_device_num; - auto group = InferRepeatedRankList(node); + auto groups = InferRepeatedRankList(node); + if (groups.empty()) { + return; + } + auto group = groups[0]; if (out_shape_element[0] % group.GetDevNum() != 0) { MS_LOG(WARNING) << "The output_shape first dim:" << out_shape_element[0] << " cannot be divisible by the repeated size: " << group.GetDevNum() diff --git a/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc b/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc index a14da663e69..65281afda55 100644 --- a/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc +++ b/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc @@ -45,7 +45,7 @@ bool StrategyCheckpoint::CheckPath(const std::string path) const { MS_LOG(ERROR) << "The checkpoit path " << path << " is too long"; return false; } - auto realpath = Common::CreatePrefixPath(path); + auto realpath = Common::CreatePrefixPath(path, true); if (!realpath.has_value()) { MS_LOG(ERROR) << "Get real path failed, path=" << path; return false;