forked from mindspore-Ecosystem/mindspore
Add stage information for ops and strategy
This commit is contained in: parent 55be3c42a5, commit 4ef439e27b
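At the Python level this surfaces as a new `pipeline_stages` option of `set_auto_parallel_context` and a new `Primitive.set_stage` method. A minimal usage sketch, assuming only the API added in this commit and mirroring the new unit tests further down:

    from mindspore import context

    # 8 devices split into 2 pipeline stages of 4 devices each, i.e. stages [4, 4]:
    # stage 0 owns ranks 0-3 and stage 1 owns ranks 4-7.
    context.set_auto_parallel_context(device_num=8, global_rank=0,
                                      parallel_mode="semi_auto_parallel", pipeline_stages=2)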
@@ -63,6 +63,8 @@ void ParallelContext::Reset() {
all_reduce_fusion_split_indices_.clear();
all_reduce_fusion_split_sizes_.clear();
strategy_search_mode_ = DYNAMIC_PROGRAMMING;
stages_.clear();
pipeline_stage_split_num_ = 0;
}

void ParallelContext::set_device_num(int32_t device_num) {

@@ -83,6 +85,10 @@ void ParallelContext::set_gradient_fp32_sync(bool gradient_fp32_sync) { gradient

void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; }

void ParallelContext::set_pipeline_stage_split_num(const int32_t stage_num) { pipeline_stage_split_num_ = stage_num; }

void ParallelContext::set_stage(const std::vector<int32_t> &stages) { stages_ = stages; }

bool ParallelContext::set_parallel_mode(const std::string &parallel_mode) {
auto iter = std::find(PARALLEL_MODE_LIST.begin(), PARALLEL_MODE_LIST.end(), parallel_mode);
if (iter == PARALLEL_MODE_LIST.end()) {

@@ -67,6 +67,12 @@ class ParallelContext {
void set_device_num(int32_t device_num);
int32_t device_num() const { return device_num_; }

void set_pipeline_stage_split_num(const int32_t stages);
int32_t pipeline_stage_split_num() const { return pipeline_stage_split_num_; }

void set_stage(const std::vector<int32_t> &stages);
std::vector<int32_t> stage() const { return stages_; }

void set_global_rank(int32_t global_rank);
int32_t global_rank() const { return global_rank_; }

@@ -115,6 +121,8 @@ class ParallelContext {
int32_t global_rank_;
std::string parallel_mode_;
std::string strategy_search_mode_;
std::vector<int32_t> stages_;
int32_t pipeline_stage_split_num_;
bool parameter_broadcast_;
bool device_num_is_set_;
bool global_rank_is_set_;

@@ -36,7 +36,8 @@ Stage::Stage(const std::vector<mindspore::parallel::Device> &devices, int num, i
// NOTE: '-1' indicates ERROR
int Stage::global_rank(Group *g) const { return ((g == nullptr) ? rank_ : -1); }

bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &backend) {
bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &backend,
const std::vector<int32_t> &stage) {
if (device_num <= 0) {
MS_LOG(ERROR) << "'device_num' must be positive.";
return false;

@@ -68,7 +69,30 @@ bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &back
devices.push_back(i);
}

stage_map.push_back(device_num);
if (stage.size()) {
int32_t summed_value = 0;
for (auto begin = stage.begin(); begin != stage.end(); ++begin) {
if (*begin <= 0) {
MS_LOG(ERROR) << "The value in the pipeline stages should be a positive value";
return false;
}
summed_value += *begin;
stage_map.push_back(*begin);
}

if (summed_value != device_num) {
MS_LOG(ERROR) << "The sum of the pipeline stages: " << summed_value << " is not equal to the device_num "
<< device_num;
return false;
}
} else {
stage_map.push_back(device_num);
}

for (auto &y : stage_map) {
MS_LOG(DEBUG) << "Obtained stage id: " << y;
}

g_device_manager = std::make_shared<DeviceManager>();
if (g_device_manager->Init(devices, global_rank, stage_map, backend) == SUCCESS) {
MS_LOG(INFO) << "Device initialization succeeds.";

@@ -70,7 +70,7 @@ class Stage {

// This method is used for initializing the global DeviceManager 'g_device_manager',
// arguments including 'device_num' and 'global_rank'
bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &backend);
bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &backend, const std::vector<int32_t> &stage);

void CheckGlobalDeviceManager();

@@ -126,9 +126,22 @@ Status DeviceMatrix::GetDevicesByTensorMap(const Shape &tensor_map, RankList *ra
}
}

Shape current_rank_coordinate = ConvertRankToCoordinate(rank_, dev_shape_);
// Convert the global rank to the local rank (the index of the array) to compute the coordinate
uint32_t local_rank = 0;
for (auto &tmp_rank : dev_list_) {
Shape tmp_rank_coordinate = ConvertRankToCoordinate(tmp_rank, dev_shape_);
if (tmp_rank == rank_) {
break;
}
++local_rank;
}
if (local_rank == dev_list_.size()) {
MS_LOG(ERROR) << "Rank id: " << local_rank << " is not in the device list.";
return FAILED;
}

Shape current_rank_coordinate = ConvertRankToCoordinate((int32_t)local_rank, dev_shape_);
for (uint32_t loop_local_rank = 0; loop_local_rank < dev_list_.size(); ++loop_local_rank) {
Shape tmp_rank_coordinate = ConvertRankToCoordinate(loop_local_rank, dev_shape_);
bool matched = true;
for (auto &map : tensor_map) {
if (map == MAP_NONE) {

@@ -141,7 +154,7 @@ Status DeviceMatrix::GetDevicesByTensorMap(const Shape &tensor_map, RankList *ra
}
}
if (matched) {
rank_list->push_back(tmp_rank);
rank_list->push_back(dev_list_[loop_local_rank]);
}
}

@@ -43,7 +43,7 @@ Status DropoutInfo::CheckStrategy(const StrategyPtr &strategy) {
// dropout doesn't support repeated calculation
CheckGlobalDeviceManager();
auto input_strategy = strategy->GetInputDim().at(0);
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size();
auto product_p = std::accumulate(input_strategy.begin(), input_strategy.end(), 1, std::multiplies<int>());
if (IntToSize(product_p) != dev_num) {
MS_LOG(ERROR) << name_ << ": Invalid strategy. Don't support repeated calc.";

@@ -196,7 +196,7 @@ Status GatherV2PInfo::CheckManualSplit(const Strategys &strategy) {

// Don't support repeated calc
CheckGlobalDeviceManager();
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size();
auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>());
if (IntToSize(product_p) < dev_num) {
MS_LOG(ERROR) << name_ << ": Manual split doesn't support repeated calc";

@@ -269,7 +269,7 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {

// param_strategy(axis) != 1, Don't support repeated calc
CheckGlobalDeviceManager();
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size();
auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>());
if (IntToSize(product_p) != dev_num && param_strategy.at(IntToSize(axis_)) != 1) {
MS_LOG(DEBUG) << name_ << ": Invalid strategy. Don't support repeated calc.";

@@ -346,7 +346,7 @@ Status GatherV2PInfo::InferDevMatrixShape() {
out_dev_matrix_shape_ = dev_matrix_shape_;
}
CheckGlobalDeviceManager();
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size();
auto param_product = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>());
auto index_product = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies<int>());
if (param_product * index_product < SizeToInt(dev_num)) {

@@ -516,10 +516,11 @@ Status GatherV2PInfo::InferGroup() {
if (param_strategy.at(IntToSize(axis_)) != 1 && inputs_shape_.at(0).size() == 2) {
dim = (axis_ + 1) % 2;
}

CheckGlobalDeviceManager();
MS_EXCEPTION_IF_NULL(g_device_manager);
RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id_);
int32_t rank = g_device_manager->global_rank();
RankList dev_list = g_device_manager->GetDeviceListByStageId(0);
DeviceMatrix dev_matrix(rank, dev_list, dev_matrix_shape_);
RankList group_devices;
if (dev_matrix.GetDevicesAlongDim(SizeToUint(dim), &group_devices) != SUCCESS) {

@@ -162,7 +162,8 @@ class OperatorInfo {
void set_type(const std::string &type) { type_ = type; }
const std::string &type() const { return type_; }
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }

void set_stage_id(int32_t stage_id) { stage_id_ = stage_id; }
int32_t stage_id() const { return stage_id_; }
// Key for user data.
constexpr static char key[] = "OpInfo";

@@ -205,6 +206,7 @@ class OperatorInfo {
std::vector<ValuePtr> input_value_;
TypePtr outputs_dtype_;

int32_t stage_id_ = 0;
StrategyPtr strategy_;
std::vector<TensorInfo> inputs_tensor_info_;
std::vector<TensorInfo> outputs_tensor_info_;

@@ -55,6 +55,7 @@ constexpr char AUTO_PARALLEL_RUN_ONCE_ONLY[] = "auto_parallel_run_once_only";
constexpr char SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY[] = "semi_auto_parallel_run_once_only";
constexpr char CHECK_SET_STRATEGY_VALID_ONCE_ONLY[] = "check_set_strategy_valid_once_only";
constexpr char STRATEGY[] = "strategy";
constexpr char STAGE_ATTR[] = "stage";
constexpr char GEN_STRATEGY[] = "gen_strategy";
constexpr char REDUCE_OP_SUM[] = "sum";
constexpr char REDUCE_OP_MAX[] = "max";

@@ -133,9 +133,9 @@ Status ReduceMethod::InferTensorMap() {
return SUCCESS;
}

bool IsDataParallelStrategy(const Dimensions &strategy) {
bool IsDataParallelStrategy(const Dimensions &strategy, int32_t stage_id) {
CheckGlobalDeviceManager();
size_t total_dev_num = g_device_manager->GetDeviceListByStageId(0).size();
size_t total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
if (strategy.empty()) {
MS_LOG(EXCEPTION) << "IsDataParallelStrategy: strategy is empty";
}

@@ -145,7 +145,7 @@ bool IsDataParallelStrategy(const Dimensions &strategy) {

Status ReduceMethod::InferForwardCommunication() {
Dimensions stra = strategy_->GetInputDim().at(0);
if (cross_batch_ && IsDataParallelStrategy(stra)) {
if (cross_batch_ && IsDataParallelStrategy(stra, stage_id_)) {
MS_LOG(INFO) << name_ << ": cross_batch is True, don't need to InferForwardCommunication";
return SUCCESS;
}

@@ -211,7 +211,7 @@ ForwardOp CreatReduceMeanForwardOp(const std::vector<Group> &forward_group, cons

Status ReduceMeanInfo::InferForwardCommunication() {
Dimensions stra = strategy_->GetInputDim().at(0);
if (cross_batch_ && IsDataParallelStrategy(stra)) {
if (cross_batch_ && IsDataParallelStrategy(stra, stage_id_)) {
MS_LOG(INFO) << name_ << ": cross_batch is True, don't need to InferForwardCommunication";
return SUCCESS;
}

@@ -998,6 +998,17 @@ OperatorInfoPtr NewOperatorInstance(const PrimitivePtr &prim, const PrimitiveAtt
StrategyPtr ExtractStrategy(std::unordered_map<std::string, ValuePtr> attrs) {
ValueTuplePtr var = attrs[STRATEGY]->cast<ValueTuplePtr>();
StrategyPtr strategyPtr;
std::vector<int32_t> stages = ParallelContext::GetInstance()->stage();
auto res = attrs.find(STAGE_ATTR);
int32_t stage_id = 0;
if (res != attrs.end()) {
stage_id = GetValue<int>(res->second);
}
if (stage_id && stages.empty()) {
MS_LOG(ERROR) << "Found stage id: " << stage_id << " but pipeline_stages is 0.";
return nullptr;
}

MS_LOG(INFO) << "Extract information: strategy " << attrs[STRATEGY]->ToString();
if (var == nullptr) {
MS_LOG(EXCEPTION) << "Strategy value is nullptr";

@@ -1016,13 +1027,13 @@ StrategyPtr ExtractStrategy(std::unordered_map<std::string, ValuePtr> attrs) {
});
strategy.push_back(dim);
} else {
MS_LOG(EXCEPTION) << "Failure:Strategy's format is wrong! Need ValueSequeue";
MS_LOG(EXCEPTION) << "Failure:Strategy's format is wrong! Need ValueSequence";
}
}
if (strategy.empty()) {
MS_LOG(EXCEPTION) << "ExtractStrategy:failed to extract strategy";
}
strategyPtr = NewStrategy(0, strategy);
strategyPtr = NewStrategy(stage_id, strategy);
}

return strategyPtr;

@@ -1420,6 +1431,30 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
(void)prim->SetAttrs(attrs_temp);
}
}
// This function checks whether the rank and stage of an operation are valid.
// If the rank is not valid for the given stage, we choose not to init the strategy of the operation.
// For example, if the stages are [4, 4], the group_list is [[0,1,2,3],[4,5,6,7]];
// for stage 0, we require the rank_id to be in [0,1,2,3].
Status ValidRankCheck(int32_t global_rank, int32_t strategy_stage) {
RankList local_group_list = g_device_manager->GetDeviceListByStageId(strategy_stage);
int32_t target = global_rank;
if (std::any_of(local_group_list.begin(), local_group_list.end(), [target](int32_t a) { return a == target; })) {
return Status::SUCCESS;
}

return Status::FAILED;
}

Status ValidStageCheck(const std::vector<int32_t> &stages, int32_t strategy_stage) {
if (stages.size() > 0) {
if (strategy_stage >= 0 && strategy_stage < (int32_t)stages.size()) {
return Status::SUCCESS;
}
return Status::FAILED;
} else {
return Status::SUCCESS;
}
}

void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
// load strategy map from checkpoint

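The stage and rank checks introduced above can be illustrated with a small, self-contained Python sketch (hypothetical helper names, not part of the commit): with stages [4, 4] the per-stage groups are [[0,1,2,3],[4,5,6,7]], so global rank 5 is valid for stage 1 but not for stage 0.

    def stage_groups(stages):
        # Expand per-stage device counts into rank lists, e.g. [4, 4] -> [[0, 1, 2, 3], [4, 5, 6, 7]].
        groups, start = [], 0
        for n in stages:
            groups.append(list(range(start, start + n)))
            start += n
        return groups

    def valid_stage(stages, stage_id):
        # Mirrors ValidStageCheck: an empty stage list accepts any stage id.
        return not stages or 0 <= stage_id < len(stages)

    def valid_rank(global_rank, stage_id, stages):
        # Mirrors ValidRankCheck: the rank must belong to the stage's device group.
        return global_rank in stage_groups(stages)[stage_id]

    assert valid_stage([4, 4], 1)
    assert valid_rank(5, 1, [4, 4])
    assert not valid_rank(5, 0, [4, 4])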
@@ -1429,6 +1464,11 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
}
}

// Get global rank after the checkpoint?
int32_t global_rank = ParallelContext::GetInstance()->global_rank();
std::vector<int32_t> stages = ParallelContext::GetInstance()->stage();

for (auto &node : all_nodes) {
auto cnode = node->cast<CNodePtr>();
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {

@@ -1501,7 +1541,18 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
strategyPtr = ExtractStrategy(attrs);
}
if (strategyPtr != nullptr) {
if (operator_->Init(strategyPtr) == FAILED) {
(*operator_).set_stage_id(strategyPtr->GetInputStage());
MS_LOG(INFO) << "Extracted stage id for op " << prim->name() << " is " << (*operator_).stage_id();
if (ValidStageCheck(stages, (*operator_).stage_id()) == FAILED) {
MS_LOG(ERROR) << "Found stage " << strategyPtr->GetInputStage() << " for operator " << prim->name()
<< " exceeds the global stage size " << stages.size() << '.';
return;
}
// If the strategy is not valid for the given global rank, then we skip the Init of the strategy
if (ValidRankCheck(global_rank, (*operator_).stage_id()) == FAILED) {
MS_LOG(INFO) << "The global rank is out of the range of the stage, skip the strategy init for operator "
<< prim->name();
} else if (operator_->Init(strategyPtr) == FAILED) {
MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed";
}
cnode->set_user_data<OperatorInfo>(operator_);

@@ -2416,6 +2467,9 @@ Status ParallelInit() {
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
int32_t device_num = ParallelContext::GetInstance()->device_num();
int32_t global_rank = ParallelContext::GetInstance()->global_rank();
int32_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
std::vector<int32_t> stages = ParallelContext::GetInstance()->stage();
std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);

@@ -2431,6 +2485,26 @@ Status ParallelInit() {
MS_LOG(EXCEPTION) << "Invalid communication backend: " << backend;
}

if (device_num <= 0) {
MS_LOG(ERROR) << "Invalid device num " << device_num << ", expected a positive device number";
return FAILED;
}
if (split_stage_num > 0) {
if (device_num % split_stage_num != 0) {
MS_LOG(ERROR) << "Device num " << device_num << " can't be divided by stage num " << split_stage_num
<< ", as only exact division is supported now";
return FAILED;
}
for (int i = 0; i < split_stage_num; i++) {
stages.push_back(device_num / split_stage_num);
}
} else if (split_stage_num < 0) {
MS_LOG(ERROR) << "Invalid stage num " << split_stage_num << ", expected a positive stage number";
return FAILED;
}

ParallelContext::GetInstance()->set_stage(stages);

uint32_t world_rank_size = 0;
if (!ParallelContext::GetInstance()->device_num_is_set()) {
if (!CommManager::GetInstance().GetRankSize(world_group, &world_rank_size)) {

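The stage list built in ParallelInit is just the device count split evenly across pipeline_stages; a quick sketch of the arithmetic (plain Python, illustrative only):

    device_num, pipeline_stages = 8, 2
    assert device_num % pipeline_stages == 0      # only exact division is supported
    stages = [device_num // pipeline_stages] * pipeline_stages
    # stages == [4, 4]: stage 0 holds ranks 0-3, stage 1 holds ranks 4-7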
@@ -2449,7 +2523,12 @@ Status ParallelInit() {
MS_LOG(INFO) << "Get global rank from communication model, the global rank is " << global_rank;
}

if (!InitDevice(device_num, global_rank, communication_backend)) {
if (!stages.empty() && parallel_mode != SEMI_AUTO_PARALLEL) {
MS_LOG(ERROR) << "To enable the pipeline parallel, please set the parallel mode to " << SEMI_AUTO_PARALLEL;
return FAILED;
}

if (!InitDevice(device_num, global_rank, communication_backend, stages)) {
MS_LOG(ERROR) << "Init device failed";
return FAILED;
}

@@ -2457,6 +2536,7 @@ Status ParallelInit() {
MS_LOG(INFO) << "The parallel context: dev num: " << device_num << ", global rank: " << global_rank
<< ", backend: " << backend << ", gradients_mean: " << ParallelContext::GetInstance()->gradients_mean()
<< ", gradient_fp32_sync: " << ParallelContext::GetInstance()->gradient_fp32_sync();

return SUCCESS;
}

@@ -152,6 +152,9 @@ PYBIND11_MODULE(_c_expression, m) {
"Set strategy checkpoint save file.")
.def("get_strategy_ckpt_load_file", &ParallelContext::strategy_ckpt_load_file, "Get strategy checkpoint load file.")
.def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.")
.def("set_pipeline_stage_split_num", &ParallelContext::set_pipeline_stage_split_num,
"Set pipeline stage split num.")
.def("get_pipeline_stage_split_num", &ParallelContext::pipeline_stage_split_num, "Get pipeline stage split num.")
.def("set_full_batch", &ParallelContext::set_full_batch, "Set whether load full batch on each device.")
.def("get_full_batch", &ParallelContext::full_batch, "Get whether load full batch on each device.")
.def("set_enable_parallel_optimizer", &ParallelContext::set_enable_parallel_optimizer,

@@ -331,7 +331,7 @@ def _context():
@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, parallel_mode=str,
                 auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
                 strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
                 all_reduce_fusion_config=list)
                 all_reduce_fusion_config=list, pipeline_stages=int)
def set_auto_parallel_context(**kwargs):
    """
    Set auto parallel context.

@@ -357,6 +357,7 @@ def set_auto_parallel_context(**kwargs):
    parallel_mode               strategy_ckpt_load_file
    all_reduce_fusion_config    strategy_ckpt_save_file
                                full_batch
                                pipeline_stages
    =========================== =========================== =================

    Args:

@@ -399,6 +400,10 @@ def set_auto_parallel_context(**kwargs):
            the fusion is closed.
        all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. Only support ReduceOp.SUM
            and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP. No Default, if it is not set, the fusion is closed.
        pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how
            the devices are distributed along the pipeline. The total devices will be divided into
            'pipeline_stages' stages. Currently this can only be used when
            the parallel mode semi_auto_parallel is enabled.

    Raises:
        ValueError: If input key is not attribute in auto parallel context.

@@ -416,10 +421,10 @@ def set_auto_parallel_context(**kwargs):
        >>> context.set_auto_parallel_context(full_batch=True)
        >>> context.set_auto_parallel_context(enable_parallel_optimizer=False)
        >>> context.set_auto_parallel_context(all_reduce_fusion_config=[8, 160])
        >>> context.set_auto_parallel_context(pipeline_stages=2)
    """
    _set_auto_parallel_context(**kwargs)


def get_auto_parallel_context(attr_key):
    """
    Gets auto parallel context attribute value according to the key.

@@ -102,6 +102,20 @@ class Primitive(Primitive_):
        self.add_attr(name, value)
        return self

    def set_stage(self, stage):
        """
        Add stage id to primitive attribute.

        Note:
            It is valid only in semi auto parallel.
            In other parallel modes, please set it to be 0.

        Args:
            stage (int): The stage id for the current operation
        """
        self.add_prim_attr("stage", stage)
        return self

    def shard(self, strategy):
        """
        Add strategies to primitive attribute.

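For reference, the per-operator side of the API is exercised by the new tests further down: set_stage is used together with shard on a primitive. An abridged sketch of that test code (not a complete network):

    from mindspore.ops import operations as P

    gatherv2 = P.GatherV2().shard(((1, 2), (1, 1)))
    gatherv2.set_stage(1)   # place this op in pipeline stage 1
    mul = P.Mul().shard(((2, 1, 1), (2, 1, 1)))
    mul.set_stage(1)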
@@ -95,6 +95,16 @@ class _AutoParallelContext:
        self.check_context_handle()
        return self._context_handle.get_global_rank()

    def set_pipeline_stages(self, stages):
        """Set the stages of the pipeline"""
        self.check_context_handle()
        self._context_handle.set_pipeline_stage_split_num(stages)

    def get_pipeline_stages(self):
        """Get the stages of the pipeline"""
        self.check_context_handle()
        return self._context_handle.get_pipeline_stage_split_num()

    def set_gradients_mean(self, gradients_mean):
        """
        Set gradients_mean flag.

@@ -463,6 +473,7 @@ _set_auto_parallel_context_func_map = {
    "gradients_mean": auto_parallel_context().set_gradients_mean,
    "gradient_fp32_sync": auto_parallel_context().set_gradient_fp32_sync,
    "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
    "pipeline_stages": auto_parallel_context().set_pipeline_stages,
    "parallel_mode": auto_parallel_context().set_parallel_mode,
    "auto_parallel_search_mode": auto_parallel_context().set_strategy_search_mode,
    "parameter_broadcast": auto_parallel_context().set_parameter_broadcast,

@@ -479,6 +490,7 @@ _get_auto_parallel_context_func_map = {
    "gradients_mean": auto_parallel_context().get_gradients_mean,
    "gradient_fp32_sync": auto_parallel_context().get_gradient_fp32_sync,
    "loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean,
    "pipeline_stages": auto_parallel_context().get_pipeline_stages,
    "parallel_mode": auto_parallel_context().get_parallel_mode,
    "auto_parallel_search_mode": auto_parallel_context().get_strategy_search_mode,
    "parameter_broadcast": auto_parallel_context().get_parameter_broadcast,

@@ -566,7 +578,6 @@ def _get_auto_parallel_context(attr_key):
    get_func = _get_auto_parallel_context_func_map[attr_key]
    return get_func()


def _reset_auto_parallel_context():
    """
    Reset auto parallel context attributes to the default values:

@@ -581,5 +592,6 @@ def _reset_auto_parallel_context():
    - strategy_ckpt_save_file: ""
    - enable_parallel_optimizer: False
    - auto_parallel_search_mode: dynamic_programming
    - pipeline_stages: 0
    """
    auto_parallel_context().reset()

@@ -83,6 +83,39 @@ TEST_F(TestDeviceMatrix, TestCornerCaseGetAlongDim) {
EXPECT_THROW({ DeviceMatrix arr(3, dev_list, shape); }, std::runtime_error);
}

TEST_F(TestDeviceMatrix, TestGetDeviceByTensorMapRandomOrderSliceOne) {
RankList dev_list = {10, 3, 2, 9, 11, 100, 1, 0};
Shape tensor_map = {-1, 0};
RankList rank_list;
Shape shape = {4, 2};
DeviceMatrix arr(0, dev_list, shape);
arr.GetDevicesByTensorMap(tensor_map, &rank_list);
RankList rank_list_except = {3, 9, 100, 0};
ASSERT_EQ(rank_list, rank_list_except);
}

TEST_F(TestDeviceMatrix, TestGetDeviceByTensorMapRandomOrderSliceTwo) {
RankList dev_list = {10, 3, 2, 9, 11, 100, 1, 0};
Shape tensor_map = {1, 0};
RankList rank_list;
Shape shape = {4, 2};
DeviceMatrix arr(0, dev_list, shape);
arr.GetDevicesByTensorMap(tensor_map, &rank_list);
RankList rank_list_except = {0};
ASSERT_EQ(rank_list, rank_list_except);
}

TEST_F(TestDeviceMatrix, TestGetDeviceByTensorMapNoramalOrder2D) {
RankList dev_list = {0, 1, 2, 3, 4, 5, 6, 7};
Shape tensor_map = {-1, 0};
RankList rank_list;
Shape shape = {4, 2};
DeviceMatrix arr(6, dev_list, shape);
arr.GetDevicesByTensorMap(tensor_map, &rank_list);
RankList rank_list_except = {0, 2, 4, 6};
ASSERT_EQ(rank_list, rank_list_except);
}

TEST_F(TestDeviceMatrix, TestCornerCase2GetAlongDim) {
// Rank is out of range
RankList dev_list = {0, 1, 2, 3, 4, 5, 6, 7};

@@ -0,0 +1,89 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from tests.ut.python.ops.test_math_ops import VirtualLoss


grad_all = C.GradOperation(get_all=True)


class NetWithLoss(nn.Cell):
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, y):
        predict = self.network(x, y)
        return self.loss(predict)


class GradWrap(nn.Cell):
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x, y):
        return grad_all(self.network)(x, y)


class Net(nn.Cell):
    def __init__(self, axis=0, stage1=0, stage2=0, strategy1=None, strategy2=None, shape=None, target=""):
        super().__init__()
        if shape is None:
            shape = [64, 64]
        self.gatherv2 = P.GatherV2().shard(strategy1).add_prim_attr("primitive_target", target)
        self.mul = P.Mul().shard(strategy2)
        self.index = Tensor(np.ones(shape), dtype=ms.int32)
        self.gatherv2.set_stage(stage1)
        self.mul.set_stage(stage2)
        self.axis = axis

    def construct(self, x, y):
        out = self.gatherv2(x, self.index, self.axis)
        out = self.mul(out, y)
        return out


def test_gatherv2_semi_samestage1():
    context.set_auto_parallel_context(device_num=8, global_rank=0, \
                                      parallel_mode="semi_auto_parallel", pipeline_stages=2)
    strategy1 = ((1, 2), (1, 1))
    strategy2 = ((2, 1, 1), (2, 1, 1))
    net = GradWrap(NetWithLoss(Net(0, 0, 0, strategy1, strategy2)))
    net.set_auto_parallel()

    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
    _executor.compile(net, x, y)

def test_gatherv2_semi_samestage2():
    context.set_auto_parallel_context(device_num=8, global_rank=5, \
                                      parallel_mode="semi_auto_parallel", pipeline_stages=2)
    strategy1 = ((1, 2), (1, 1))
    strategy2 = ((2, 1, 1), (2, 1, 1))
    net = GradWrap(NetWithLoss(Net(0, 1, 1, strategy1, strategy2)))
    net.set_auto_parallel()

    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
    _executor.compile(net, x, y)

@@ -81,6 +81,11 @@ def test_set_auto_parallel_context():
    assert context.get_auto_parallel_context("enable_parallel_optimizer")
    assert not auto_parallel_context().get_all_reduce_fusion_split_indices()

def test_pipeline_parallel_context():
    context.set_auto_parallel_context(device_num=8, global_rank=4,
                                      parallel_mode="semi_auto_parallel", pipeline_stages=2)
    stage = auto_parallel_context().get_pipeline_stages()
    assert stage == 2

def test_reset_auto_parallel_context():
    context.reset_auto_parallel_context()

@@ -92,6 +97,8 @@ def test_reset_auto_parallel_context():
    parameter_broadcast = context.get_auto_parallel_context("parameter_broadcast")
    device_num_is_set = auto_parallel_context().get_device_num_is_set()
    parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set()
    stage = auto_parallel_context().get_pipeline_stages()

    assert device_num == 1
    assert global_rank == 0
    assert not gradients_mean

@@ -100,3 +107,4 @@ def test_reset_auto_parallel_context():
    assert not parameter_broadcast
    assert not device_num_is_set
    assert not parameter_broadcast_is_set
    assert not stage