forked from mindspore-Ecosystem/mindspore
!25409 dataset_strategy_repeat_cal_fix
Merge pull request !25409 from yao_yf/dataset_strategy_repeat_cal_fix
commit 0effb635e9
@@ -131,6 +131,10 @@ class ParallelContext {
   bool sharding_propagation() const { return sharding_propagation_; }
   void set_enable_all2all(const bool);
   bool enable_all2all() const { return enable_all2all_; }
+  void set_dataset_repeat_dim_right(const bool dataset_repeat_dim_right) {
+    dataset_repeat_dim_right_ = dataset_repeat_dim_right;
+  }
+  bool dataset_repeat_dim_right() const { return dataset_repeat_dim_right_; }
 
   void Reset();
   void ParallelParameterContextInitShape(const FuncGraphPtr &func_graph);
@@ -173,6 +177,7 @@ class ParallelContext {
   // Enable AllToAll or not. If false, use AllGather and Split.
   bool enable_all2all_;
   std::vector<std::vector<int64_t>> dataset_strategy_;
+  bool dataset_repeat_dim_right_ = false;
 };
 
 }  // namespace parallel
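The accessor pair added to ParallelContext is the whole cross-module contract of this fix. Below is a minimal sketch of how a caller would toggle and read the flag through the singleton; the wrapper function and the include path are assumptions for illustration, and only the two accessors themselves come from this diff.

#include "frontend/parallel/context.h"  // assumed header location

// Hypothetical caller, for illustration only.
void MarkDatasetRepeatDimRight() {
  auto ctx = mindspore::parallel::ParallelContext::GetInstance();
  ctx->set_dataset_repeat_dim_right(true);  // default is false, per the header diff above
  if (ctx->dataset_repeat_dim_right()) {
    // The operator Init() paths below will now keep the repeat dim on the right.
  }
}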
@@ -30,19 +30,19 @@
 namespace mindspore {
 namespace parallel {
 Status GetNextInfo::InferTensorMap() {
-  auto slice_dim_iter = std::find(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), shard_num_);
-  if (slice_dim_iter == dev_matrix_shape_.end()) {
+  auto slice_dim_iter = std::find(dev_matrix_shape_origin_.begin(), dev_matrix_shape_origin_.end(), shard_num_);
+  if (slice_dim_iter == dev_matrix_shape_origin_.end()) {
     MS_LOG(ERROR) << name_ << ": The dataset shard strategy only support shard in one dim.";
     return FAILED;
   }
-  size_t slice_dim = size_t(slice_dim_iter - dev_matrix_shape_.begin());
+  size_t slice_dim = size_t(slice_dim_iter - dev_matrix_shape_origin_.begin());
   for (size_t i = 0; i < dataset_strategy_.size(); i++) {
     Shape tensor_map_index;
     for (auto dim : dataset_strategy_[i]) {
       if (dim == 1) {
         tensor_map_index.push_back(MAP_NONE);
       } else if (dim == shard_num_) {
-        tensor_map_index.push_back(dev_matrix_shape_.size() - 1 - slice_dim);
+        tensor_map_index.push_back(dev_matrix_shape_origin_.size() - 1 - slice_dim);
       } else {
         MS_LOG(ERROR) << name_ << ": The dataset shard strategy only support fully shard in one dim.";
         return FAILED;
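To see the arithmetic concretely, here is a standalone toy version of the loop above with assumed values (MAP_NONE stands for -1 in the parallel code; all concrete numbers are hypothetical): with dev_matrix_shape_origin_ = {4, 2}, shard_num_ = 4 and one input whose strategy is {4, 1}, slice_dim is 0, so the sharded axis maps to 2 - 1 - 0 = 1 and the unsharded axis maps to MAP_NONE.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  constexpr int64_t kMapNone = -1;               // stand-in for MAP_NONE
  std::vector<int64_t> dev_mat_origin = {4, 2};  // 4-way dataset shard, repeat factor 2
  int64_t shard_num = 4;
  std::vector<std::vector<int64_t>> dataset_strategy = {{4, 1}};

  size_t slice_dim = 0;  // position of shard_num in dev_mat_origin
  for (const auto &stra : dataset_strategy) {
    std::vector<int64_t> tensor_map;
    for (auto dim : stra) {
      if (dim == 1) {
        tensor_map.push_back(kMapNone);
      } else if (dim == shard_num) {
        // Tensor-map indices count device-matrix axes from the right.
        tensor_map.push_back(static_cast<int64_t>(dev_mat_origin.size() - 1 - slice_dim));
      }
    }
    for (auto m : tensor_map) std::cout << m << ' ';  // prints: 1 -1
    std::cout << '\n';
  }
  return 0;
}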
@@ -95,11 +95,15 @@ Status GetNextInfo::InferDevMatrixShape() {
   if (shard_num_iter != dev_matrix_shape_.end()) {
     shard_num_ = *shard_num_iter;
   }
+  dev_matrix_shape_origin_ = dev_matrix_shape_;
   return SUCCESS;
 }
 
 Status GetNextInfo::Init(const StrategyPtr &strategy) {
+  repeated_num_in_dev_matrix_right_ = false;
+  if (ParallelContext::GetInstance()->dataset_repeat_dim_right()) {
+    repeated_num_in_dev_matrix_right_ = true;
+  }
   if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
     MS_LOG(ERROR) << name_ << " : Init failed";
     return FAILED;
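GetNextInfo::Init (and VirtualDatasetInfo::Init further down) feeds the context flag into repeated_num_in_dev_matrix_right_ before InitWithAutoRepeatCalc runs, which is presumably where the shared repeat-dimension placement happens. A self-contained sketch of what the flag means for the device matrix, with hypothetical numbers (8 devices, a 2-way dataset shard, hence a repeat factor of 4):

#include <cstdint>
#include <vector>

// Illustration only, not code from the patch: where the repeat factor lands
// in the device matrix depending on the flag.
std::vector<int64_t> BuildDevMatrix(int64_t shard_num, int64_t device_num, bool repeat_dim_right) {
  int64_t repeat = device_num / shard_num;  // devices not consumed by the shard
  return repeat_dim_right ? std::vector<int64_t>{shard_num, repeat}   // repeat dim on the right
                          : std::vector<int64_t>{repeat, shard_num};  // repeat dim on the left
}
// BuildDevMatrix(2, 8, true)  -> {2, 4}
// BuildDevMatrix(2, 8, false) -> {4, 2}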
@@ -63,6 +63,7 @@ class GetNextInfo : public OperatorInfo {
   int64_t shard_num_ = 1;
   std::string shared_name_;
   Strategys dataset_strategy_;
+  Shape dev_matrix_shape_origin_;
 };
 }  // namespace parallel
 }  // namespace mindspore
@@ -226,6 +226,8 @@ constexpr char IS_TRAINING[] = "is_training";
 constexpr char EPSILON[] = "epsilon";
 constexpr char MOMENTUM[] = "momentum";
 constexpr char DEVICE_NUM[] = "device_num";
+constexpr char REPEAT_DIM_DIRECT[] = "repeat_dim_direct";
+constexpr char RIGHT[] = "right";
 
 // Operator
 constexpr char VIRTUAL_DIV[] = "_VirtualDiv";
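These two constants pair with the attribute check added to SetVirtualDatasetStrategy at the end of this diff. As a hedged sketch of the producing side (the wrapper function is hypothetical, and AddAttr/MakeValue are the usual MindSpore primitive helpers whose exact headers are assumed here), the primitive would be tagged roughly like:

#include <string>

#include "ir/primitive.h"  // assumed header for Primitive/PrimitivePtr
#include "ir/value.h"      // assumed header for MakeValue

// Hypothetical producer of the attribute that SetVirtualDatasetStrategy reads.
void TagRepeatDimRight(const mindspore::PrimitivePtr &prim) {
  // Same string values as the REPEAT_DIM_DIRECT and RIGHT constants above.
  prim->AddAttr("repeat_dim_direct", MakeValue(std::string("right")));
}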
@@ -94,12 +94,13 @@ Status VirtualDatasetInfo::InferMirrorOps() { return SUCCESS; }
 Status VirtualDatasetInfo::InferForwardCommunication() { return SUCCESS; }
 
 Status VirtualDatasetInfo::InferTensorMap() {
-  auto slice_dim_iter = std::find(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), shard_num_);
-  if (slice_dim_iter == dev_matrix_shape_.end()) {
+  auto dev_mat_origin = strategy_->GetInputDim()[max_size_strategy_dim_];
+  auto slice_dim_iter = std::find(dev_mat_origin.begin(), dev_mat_origin.end(), shard_num_);
+  if (slice_dim_iter == dev_mat_origin.end()) {
     MS_LOG(ERROR) << name_ << ": The dataset shard strategy only support shard in one dim.";
     return FAILED;
   }
-  size_t slice_dim = size_t(slice_dim_iter - dev_matrix_shape_.begin());
+  size_t slice_dim = size_t(slice_dim_iter - dev_mat_origin.begin());
   auto stra = strategy_->GetInputDim();
   for (size_t i = 0; i < stra.size(); i++) {
     Shape tensor_map_index;
@@ -107,7 +108,7 @@ Status VirtualDatasetInfo::InferTensorMap() {
       if (dim == 1) {
         tensor_map_index.push_back(MAP_NONE);
       } else if (dim == shard_num_) {
-        tensor_map_index.push_back(dev_matrix_shape_.size() - 1 - slice_dim);
+        tensor_map_index.push_back(dev_mat_origin.size() - 1 - slice_dim);
       } else {
         MS_LOG(ERROR) << name_ << ": The dataset shard strategy only support shard in one dim.";
         return FAILED;
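Note the asymmetry with GetNextInfo above: VirtualDatasetInfo takes its reference device matrix not from the inferred dev_matrix_shape_ but from the strategy of its largest input (strategy_->GetInputDim()[max_size_strategy_dim_]). Either way the intent appears to be the same: once the repeat factor has been appended to dev_matrix_shape_, tensor-map indices must still be computed against the original, repeat-free layout, which is exactly what the dev_matrix_shape_origin_ snapshot accomplishes for GetNextInfo.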
@@ -123,6 +124,9 @@ Status VirtualDatasetInfo::GetAttrs() { return SUCCESS; }
 
 Status VirtualDatasetInfo::Init(const StrategyPtr &strategy) {
+  repeated_num_in_dev_matrix_right_ = false;
+  if (ParallelContext::GetInstance()->dataset_repeat_dim_right()) {
+    repeated_num_in_dev_matrix_right_ = true;
+  }
   if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": Init failed.";
     return FAILED;
@@ -1762,6 +1762,10 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
   ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
   attrs_temp[STRATEGY] = strategy;
   (void)prim->SetAttrs(attrs_temp);
+  if (prim->HasAttr(REPEAT_DIM_DIRECT) && GetValue<std::string>(prim->GetAttr(REPEAT_DIM_DIRECT)) == RIGHT) {
+    ParallelContext::GetInstance()->set_dataset_repeat_dim_right(true);
+    MS_LOG(INFO) << "dataset repeat dim is right";
+  }
   return;
 }
 int64_t dev_num;
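End to end, the fix threads one bit of information through three layers: the _VirtualDataset primitive carries a repeat_dim_direct = "right" attribute; SetVirtualDatasetStrategy copies it into ParallelContext via set_dataset_repeat_dim_right(true); and the two dataset operators, GetNextInfo and VirtualDatasetInfo, read it in Init() to keep the repeat dimension on the right of the device matrix while computing their tensor maps against the original, repeat-free device matrix.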