!25409 dataset_strategy_repeat_cal_fix

Merge pull request !25409 from yao_yf/dataset_strategy_repeat_cal_fix
i-robot 2021-10-29 01:21:29 +00:00 committed by Gitee
commit 0effb635e9
6 changed files with 28 additions and 8 deletions

View File

@@ -131,6 +131,10 @@ class ParallelContext {
bool sharding_propagation() const { return sharding_propagation_; }
void set_enable_all2all(const bool);
bool enable_all2all() const { return enable_all2all_; }
+void set_dataset_repeat_dim_right(const bool dataset_repeat_dim_right) {
+  dataset_repeat_dim_right_ = dataset_repeat_dim_right;
+}
+bool dataset_repeat_dim_right() const { return dataset_repeat_dim_right_; }
void Reset();
void ParallelParameterContextInitShape(const FuncGraphPtr &func_graph);
@@ -173,6 +177,7 @@ class ParallelContext {
// Enable AllToAll or not. If false, use AllGather and Split.
bool enable_all2all_;
std::vector<std::vector<int64_t>> dataset_strategy_;
+bool dataset_repeat_dim_right_ = false;
};
} // namespace parallel
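
The flag itself is a plain boolean on the ParallelContext singleton, defaulting to false. A minimal usage sketch of the new accessor pair (the real call sites appear later in this diff; the local variable name ctx is illustrative only):

auto ctx = ParallelContext::GetInstance();
// Flipped once by SetVirtualDatasetStrategy() when the _VirtualDataset
// primitive requests the repeat dim on the right.
ctx->set_dataset_repeat_dim_right(true);
if (ctx->dataset_repeat_dim_right()) {
  // GetNextInfo / VirtualDatasetInfo then place the repeat factor on the
  // right of their device matrices (repeated_num_in_dev_matrix_right_ = true).
}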

View File

@@ -30,19 +30,19 @@
namespace mindspore {
namespace parallel {
Status GetNextInfo::InferTensorMap() {
-auto slice_dim_iter = std::find(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), shard_num_);
-if (slice_dim_iter == dev_matrix_shape_.end()) {
+auto slice_dim_iter = std::find(dev_matrix_shape_origin_.begin(), dev_matrix_shape_origin_.end(), shard_num_);
+if (slice_dim_iter == dev_matrix_shape_origin_.end()) {
MS_LOG(ERROR) << name_ << ": The dataset shard strategy only support shard in one dim.";
return FAILED;
}
-size_t slice_dim = size_t(slice_dim_iter - dev_matrix_shape_.begin());
+size_t slice_dim = size_t(slice_dim_iter - dev_matrix_shape_origin_.begin());
for (size_t i = 0; i < dataset_strategy_.size(); i++) {
Shape tensor_map_index;
for (auto dim : dataset_strategy_[i]) {
if (dim == 1) {
tensor_map_index.push_back(MAP_NONE);
} else if (dim == shard_num_) {
-tensor_map_index.push_back(dev_matrix_shape_.size() - 1 - slice_dim);
+tensor_map_index.push_back(dev_matrix_shape_origin_.size() - 1 - slice_dim);
} else {
MS_LOG(ERROR) << name_ << ": The dataset shard strategy only support fully shard in one dim.";
return FAILED;
@@ -95,11 +95,15 @@ Status GetNextInfo::InferDevMatrixShape() {
if (shard_num_iter != dev_matrix_shape_.end()) {
shard_num_ = *shard_num_iter;
}
+dev_matrix_shape_origin_ = dev_matrix_shape_;
return SUCCESS;
}
Status GetNextInfo::Init(const StrategyPtr &strategy) {
repeated_num_in_dev_matrix_right_ = false;
+if (ParallelContext::GetInstance()->dataset_repeat_dim_right()) {
+  repeated_num_in_dev_matrix_right_ = true;
+}
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Init failed";
return FAILED;
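
The reason for snapshotting the device matrix: when the repeat factor is placed on the right of the device matrix (the option this PR enables), the extended dev_matrix_shape_ no longer matches the matrix derived from the dataset strategy, so both the std::find and the size() - 1 - slice_dim arithmetic can shift. Searching the saved origin shape keeps the tensor map independent of where the repeat factor lands. A self-contained toy illustration with made-up numbers (not taken from the diff):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Toy values: the dataset strategy shards one dimension 8 ways, and repeat
  // calculation later appends a repeat factor of 4 to the device matrix.
  std::vector<int64_t> origin = {8};       // dev_matrix_shape_origin_
  std::vector<int64_t> extended = {8, 4};  // dev_matrix_shape_ with the repeat factor on the right
  int64_t shard_num = 8;

  // Searching the origin shape: slice_dim = 0, index = origin.size() - 1 - 0 = 0.
  size_t slice_dim = static_cast<size_t>(std::find(origin.begin(), origin.end(), shard_num) - origin.begin());
  std::printf("index from origin matrix   = %zu\n", origin.size() - 1 - slice_dim);

  // Searching the extended shape instead would yield a different index (1 here),
  // which is why the fix keeps the original matrix around.
  size_t shifted = static_cast<size_t>(std::find(extended.begin(), extended.end(), shard_num) - extended.begin());
  std::printf("index from extended matrix = %zu\n", extended.size() - 1 - shifted);
  return 0;
}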

View File

@@ -63,6 +63,7 @@ class GetNextInfo : public OperatorInfo {
int64_t shard_num_ = 1;
std::string shared_name_;
Strategys dataset_strategy_;
+Shape dev_matrix_shape_origin_;
};
} // namespace parallel
} // namespace mindspore

View File

@@ -226,6 +226,8 @@ constexpr char IS_TRAINING[] = "is_training";
constexpr char EPSILON[] = "epsilon";
constexpr char MOMENTUM[] = "momentum";
constexpr char DEVICE_NUM[] = "device_num";
+constexpr char REPEAT_DIM_DIRECT[] = "repeat_dim_direct";
+constexpr char RIGHT[] = "right";
// Operator
constexpr char VIRTUAL_DIV[] = "_VirtualDiv";

View File

@@ -94,12 +94,13 @@ Status VirtualDatasetInfo::InferMirrorOps() { return SUCCESS; }
Status VirtualDatasetInfo::InferForwardCommunication() { return SUCCESS; }
Status VirtualDatasetInfo::InferTensorMap() {
-auto slice_dim_iter = std::find(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), shard_num_);
-if (slice_dim_iter == dev_matrix_shape_.end()) {
+auto dev_mat_origin = strategy_->GetInputDim()[max_size_strategy_dim_];
+auto slice_dim_iter = std::find(dev_mat_origin.begin(), dev_mat_origin.end(), shard_num_);
+if (slice_dim_iter == dev_mat_origin.end()) {
MS_LOG(ERROR) << name_ << ": The dataset shard strategy only support shard in one dim.";
return FAILED;
}
-size_t slice_dim = size_t(slice_dim_iter - dev_matrix_shape_.begin());
+size_t slice_dim = size_t(slice_dim_iter - dev_mat_origin.begin());
auto stra = strategy_->GetInputDim();
for (size_t i = 0; i < stra.size(); i++) {
Shape tensor_map_index;
@@ -107,7 +108,7 @@ Status VirtualDatasetInfo::InferTensorMap() {
if (dim == 1) {
tensor_map_index.push_back(MAP_NONE);
} else if (dim == shard_num_) {
-tensor_map_index.push_back(dev_matrix_shape_.size() - 1 - slice_dim);
+tensor_map_index.push_back(dev_mat_origin.size() - 1 - slice_dim);
} else {
MS_LOG(ERROR) << name_ << ": The dataset shard strategy only support shard in one dim.";
return FAILED;
@@ -123,6 +124,9 @@ Status VirtualDatasetInfo::GetAttrs() { return SUCCESS; }
Status VirtualDatasetInfo::Init(const StrategyPtr &strategy) {
repeated_num_in_dev_matrix_right_ = false;
+if (ParallelContext::GetInstance()->dataset_repeat_dim_right()) {
+  repeated_num_in_dev_matrix_right_ = true;
+}
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
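
VirtualDatasetInfo takes a slightly different route than GetNextInfo: instead of caching dev_matrix_shape_, it rebuilds the original device matrix from the user strategy of the input indexed by max_size_strategy_dim_ (presumably the input whose strategy has the most dimensions). A small self-contained sketch of that selection with toy strategies; the selection loop is an assumption, only the indexing expression appears in the diff:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Toy strategies for three dataset inputs, each fully sharded 8 ways in one dim.
  std::vector<std::vector<int64_t>> stra = {{8, 1}, {8, 1, 1}, {8}};
  // Pick the input with the largest-rank strategy as the "origin" device matrix.
  size_t max_size_strategy_dim = 0;
  for (size_t i = 1; i < stra.size(); ++i) {
    if (stra[i].size() > stra[max_size_strategy_dim].size()) max_size_strategy_dim = i;
  }
  std::vector<int64_t> dev_mat_origin = stra[max_size_strategy_dim];  // {8, 1, 1}
  std::printf("origin device matrix rank = %zu\n", dev_mat_origin.size());
  return 0;
}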

View File

@@ -1762,6 +1762,10 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
attrs_temp[STRATEGY] = strategy;
(void)prim->SetAttrs(attrs_temp);
+if (prim->HasAttr(REPEAT_DIM_DIRECT) && GetValue<std::string>(prim->GetAttr(REPEAT_DIM_DIRECT)) == RIGHT) {
+  ParallelContext::GetInstance()->set_dataset_repeat_dim_right(true);
+  MS_LOG(INFO) << "dataset repeat dim is right";
+}
return;
}
int64_t dev_num;
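
Taken together, the change is a one-way switch threaded through three layers. A condensed restatement of the flow (all calls are taken from the hunks above; nothing new is introduced):

// 1. step_parallel.cc: if the _VirtualDataset primitive carries the attribute
//    "repeat_dim_direct" == "right", enable the global context flag.
ParallelContext::GetInstance()->set_dataset_repeat_dim_right(true);

// 2. GetNextInfo::Init / VirtualDatasetInfo::Init: mirror the flag, so the
//    repeat factor is appended on the right of the operator's device matrix.
repeated_num_in_dev_matrix_right_ = true;

// 3. InferTensorMap(): the tensor map is derived from the original device
//    matrix (dev_matrix_shape_origin_ / dev_mat_origin), so the appended
//    repeat factor no longer shifts the slice dimension index.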