From 0d6f8e0619b8a8592906482c0205bf32fd0ec288 Mon Sep 17 00:00:00 2001
From: yao_yf
Date: Tue, 26 Oct 2021 10:00:41 +0800
Subject: [PATCH] dataset shard strategy fix

---
 mindspore/ccsrc/frontend/parallel/context.h            |  5 +++++
 .../frontend/parallel/ops_info/get_next_info.cc        | 12 ++++++++----
 .../ccsrc/frontend/parallel/ops_info/get_next_info.h   |  1 +
 .../ccsrc/frontend/parallel/ops_info/ops_utils.h       |  2 ++
 .../parallel/ops_info/virtual_dataset_info.cc          | 12 ++++++++----
 mindspore/ccsrc/frontend/parallel/step_parallel.cc     |  4 ++++
 6 files changed, 28 insertions(+), 8 deletions(-)

diff --git a/mindspore/ccsrc/frontend/parallel/context.h b/mindspore/ccsrc/frontend/parallel/context.h
index 00b1ef237c4..d447c4081fd 100644
--- a/mindspore/ccsrc/frontend/parallel/context.h
+++ b/mindspore/ccsrc/frontend/parallel/context.h
@@ -131,6 +131,10 @@ class ParallelContext {
   bool sharding_propagation() const { return sharding_propagation_; }
   void set_enable_all2all(const bool);
   bool enable_all2all() const { return enable_all2all_; }
+  void set_dataset_repeat_dim_right(const bool dataset_repeat_dim_right) {
+    dataset_repeat_dim_right_ = dataset_repeat_dim_right;
+  }
+  bool dataset_repeat_dim_right() const { return dataset_repeat_dim_right_; }
 
   void Reset();
   void ParallelParameterContextInitShape(const FuncGraphPtr &func_graph);
@@ -173,6 +177,7 @@ class ParallelContext {
   // Enable AllToAll or not. If false, use AllGather and Split.
   bool enable_all2all_;
   std::vector<std::vector<int64_t>> dataset_strategy_;
+  bool dataset_repeat_dim_right_ = false;
 };
 
 }  // namespace parallel
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc
index 5e7f103304d..793df4bbe48 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc
@@ -30,19 +30,19 @@
 namespace mindspore {
 namespace parallel {
 Status GetNextInfo::InferTensorMap() {
-  auto slice_dim_iter = std::find(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), shard_num_);
-  if (slice_dim_iter == dev_matrix_shape_.end()) {
+  auto slice_dim_iter = std::find(dev_matrix_shape_origin_.begin(), dev_matrix_shape_origin_.end(), shard_num_);
+  if (slice_dim_iter == dev_matrix_shape_origin_.end()) {
     MS_LOG(ERROR) << name_ << ": The dataset shard strategy only support shard in one dim.";
     return FAILED;
   }
-  size_t slice_dim = size_t(slice_dim_iter - dev_matrix_shape_.begin());
+  size_t slice_dim = size_t(slice_dim_iter - dev_matrix_shape_origin_.begin());
   for (size_t i = 0; i < dataset_strategy_.size(); i++) {
     Shape tensor_map_index;
     for (auto dim : dataset_strategy_[i]) {
       if (dim == 1) {
         tensor_map_index.push_back(MAP_NONE);
       } else if (dim == shard_num_) {
-        tensor_map_index.push_back(dev_matrix_shape_.size() - 1 - slice_dim);
+        tensor_map_index.push_back(dev_matrix_shape_origin_.size() - 1 - slice_dim);
       } else {
         MS_LOG(ERROR) << name_ << ": The dataset shard strategy only support fully shard in one dim.";
         return FAILED;
@@ -95,11 +95,15 @@ Status GetNextInfo::InferDevMatrixShape() {
   if (shard_num_iter != dev_matrix_shape_.end()) {
     shard_num_ = *shard_num_iter;
   }
+  dev_matrix_shape_origin_ = dev_matrix_shape_;
   return SUCCESS;
 }
 
 Status GetNextInfo::Init(const StrategyPtr &strategy) {
   repeated_num_in_dev_matrix_right_ = false;
+  if (ParallelContext::GetInstance()->dataset_repeat_dim_right()) {
+    repeated_num_in_dev_matrix_right_ = true;
+  }
   if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
     MS_LOG(ERROR) << name_ << " : Init failed";
     return FAILED;
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.h
index 44da2c7ff06..7287735fcae 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.h
@@ -63,6 +63,7 @@ class GetNextInfo : public OperatorInfo {
   int64_t shard_num_ = 1;
   std::string shared_name_;
   Strategys dataset_strategy_;
+  Shape dev_matrix_shape_origin_;
 };
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
index 95e061cae90..21042067689 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
@@ -226,6 +226,8 @@ constexpr char IS_TRAINING[] = "is_training";
 constexpr char EPSILON[] = "epsilon";
 constexpr char MOMENTUM[] = "momentum";
 constexpr char DEVICE_NUM[] = "device_num";
+constexpr char REPEAT_DIM_DIRECT[] = "repeat_dim_direct";
+constexpr char RIGHT[] = "right";
 
 // Operator
 constexpr char VIRTUAL_DIV[] = "_VirtualDiv";
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc
index 006117fb886..8be0a3658ff 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc
@@ -94,12 +94,13 @@ Status VirtualDatasetInfo::InferMirrorOps() { return SUCCESS; }
 Status VirtualDatasetInfo::InferForwardCommunication() { return SUCCESS; }
 
 Status VirtualDatasetInfo::InferTensorMap() {
-  auto slice_dim_iter = std::find(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), shard_num_);
-  if (slice_dim_iter == dev_matrix_shape_.end()) {
+  auto dev_mat_origin = strategy_->GetInputDim()[max_size_strategy_dim_];
+  auto slice_dim_iter = std::find(dev_mat_origin.begin(), dev_mat_origin.end(), shard_num_);
+  if (slice_dim_iter == dev_mat_origin.end()) {
     MS_LOG(ERROR) << name_ << ": The dataset shard strategy only support shard in one dim.";
     return FAILED;
   }
-  size_t slice_dim = size_t(slice_dim_iter - dev_matrix_shape_.begin());
+  size_t slice_dim = size_t(slice_dim_iter - dev_mat_origin.begin());
   auto stra = strategy_->GetInputDim();
   for (size_t i = 0; i < stra.size(); i++) {
     Shape tensor_map_index;
@@ -107,7 +108,7 @@ Status VirtualDatasetInfo::InferTensorMap() {
       if (dim == 1) {
         tensor_map_index.push_back(MAP_NONE);
       } else if (dim == shard_num_) {
-        tensor_map_index.push_back(dev_matrix_shape_.size() - 1 - slice_dim);
+        tensor_map_index.push_back(dev_mat_origin.size() - 1 - slice_dim);
       } else {
         MS_LOG(ERROR) << name_ << ": The dataset shard strategy only support shard in one dim.";
         return FAILED;
@@ -123,6 +124,9 @@ Status VirtualDatasetInfo::GetAttrs() { return SUCCESS; }
 
 Status VirtualDatasetInfo::Init(const StrategyPtr &strategy) {
   repeated_num_in_dev_matrix_right_ = false;
+  if (ParallelContext::GetInstance()->dataset_repeat_dim_right()) {
+    repeated_num_in_dev_matrix_right_ = true;
+  }
   if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": Init failed.";
     return FAILED;
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index 13392282b77..71dbdf86541 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -1762,6 +1762,10 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
     ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
     attrs_temp[STRATEGY] = strategy;
     (void)prim->SetAttrs(attrs_temp);
+    if (prim->HasAttr(REPEAT_DIM_DIRECT) && GetValue<std::string>(prim->GetAttr(REPEAT_DIM_DIRECT)) == RIGHT) {
+      ParallelContext::GetInstance()->set_dataset_repeat_dim_right(true);
+      MS_LOG(INFO) << "dataset repeat dim is right";
+    }
     return;
   }
   int64_t dev_num;
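
A minimal standalone sketch of the tensor-map rule the patched InferTensorMap applies: the shard dimension is now looked up in the strategy-derived device matrix (dev_matrix_shape_origin_ / dev_mat_origin) rather than in the device matrix that may already carry the repeated number. The code below is illustrative only; the helper name, shapes, and values are assumptions for the example and are not MindSpore source.

// Standalone illustration (assumed names and values, not MindSpore code).
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

constexpr int64_t MAP_NONE = -1;

// Map each input dimension of the dataset strategy: the dimension sharded by
// shard_num points at the matching axis of the strategy-derived device matrix,
// while unsharded dimensions map to MAP_NONE.
std::vector<int64_t> InferDatasetTensorMap(const std::vector<int64_t> &dev_mat_origin,
                                           const std::vector<int64_t> &input_strategy,
                                           int64_t shard_num) {
  auto it = std::find(dev_mat_origin.begin(), dev_mat_origin.end(), shard_num);
  auto slice_dim = static_cast<size_t>(it - dev_mat_origin.begin());
  std::vector<int64_t> tensor_map;
  for (auto dim : input_strategy) {
    if (dim == 1) {
      tensor_map.push_back(MAP_NONE);
    } else {  // dim == shard_num: the one fully sharded dimension
      tensor_map.push_back(static_cast<int64_t>(dev_mat_origin.size() - 1 - slice_dim));
    }
  }
  return tensor_map;
}

int main() {
  // Example: 8 devices with dataset strategy (4, 1): the batch dim is sharded by 4,
  // so a factor of 2 remains as the repeated number.
  std::vector<int64_t> dev_mat_origin = {4, 1};
  auto tensor_map = InferDatasetTensorMap(dev_mat_origin, {4, 1}, 4);
  for (auto m : tensor_map) {
    std::cout << m << " ";  // prints "1 -1": the batch dim maps to the size-4 axis
  }
  std::cout << std::endl;
  return 0;
}

In this example the repeated factor of 2 is the part of the device matrix that the new dataset_repeat_dim_right flag controls: when the _VirtualDataset primitive carries repeat_dim_direct = "right", GetNextInfo and VirtualDatasetInfo keep that repeated number on the right side of the device matrix instead of the left.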