forked from mindspore-Ecosystem/mindspore
!859 [Auto parallel] Support searching strategy for inference phase
Merge pull request !859 from Xiaoda/support-inferring-phase-parallel-strategy-searching
This commit is contained in:
commit
6351f9c837
|
@ -23,8 +23,17 @@
|
|||
namespace mindspore {
|
||||
namespace parallel {
|
||||
void Simplify(CostPtrList *clist_ptrs) {
|
||||
// Sort the cost_list with the computation_cost_ increasing, and communication_cost decreasing order. This method
|
||||
// excludes the cost with greater computation_cost_ and greater communication_cost.
|
||||
if (RUN_PHASE == TRAINING_PHASE) {
|
||||
// training phase
|
||||
SimplifyForDecreasingCommunicationWithPartialPara(clist_ptrs);
|
||||
} else {
|
||||
// inference phase
|
||||
SimplifyForDecreasingCommunicationForward(clist_ptrs);
|
||||
}
|
||||
}
|
||||
void SimplifyForDecreasingCommunicationForward(CostPtrList *clist_ptrs) {
|
||||
// Sort the cost_list with the computation_cost_ increasing, and communication_forward decreasing order. This method
|
||||
// excludes the cost with greater computation_cost_ and greater communication_forward.
|
||||
// E.g. clist_ptrs = {<100, 20>, <200, 10>, <300, 50>}. After this method, clist_ptrs = {<200, 10>, <100, 20>}
|
||||
if (!COST_MODEL_SIMPLIFY_CALCULATION) {
|
||||
return;
|
||||
|
@ -37,14 +46,15 @@ void Simplify(CostPtrList *clist_ptrs) {
|
|||
});
|
||||
CostPtrList ret;
|
||||
for (size_t i = 0; i < clist_ptrs->size(); ++i) {
|
||||
if ((ret.size() == size_t(0)) || (clist_ptrs->at(id[i])->communication_cost_ < ret.back()->communication_cost_)) {
|
||||
if ((ret.size() == size_t(0)) ||
|
||||
(clist_ptrs->at(id[i])->communication_forward_ < ret.back()->communication_forward_)) {
|
||||
ret.emplace_back(std::move(clist_ptrs->at(id[i])));
|
||||
}
|
||||
}
|
||||
*clist_ptrs = std::move(ret);
|
||||
}
|
||||
|
||||
void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) {
|
||||
void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) {
|
||||
// Sort the cost_list with the computation_cost_ increasing, and communication_with_partial_para_cost decreasing
|
||||
// order. This method excludes the cost with greater computation_cost_ and greater communication_without_para_cost.
|
||||
if (!COST_MODEL_SIMPLIFY_CALCULATION) {
|
||||
|
|
|
@ -51,18 +51,22 @@ struct Cost {
|
|||
communication_with_partial_para_ = 0.0;
|
||||
communication_redis_forward_ = 0.0;
|
||||
communication_redis_backward_ = 0.0;
|
||||
communication_forward_ = 0.0;
|
||||
}
|
||||
// 'memory_with_reuse_' calculates the peak memory usage in a training phase
|
||||
double memory_with_reuse_;
|
||||
// 'computation_cost_' models the training time of an iteration in a training phase
|
||||
// 'computation_cost_' models the training time of an iteration in a training phase. Currently, this is calculated
|
||||
// by ONLY forward phase
|
||||
double computation_cost_;
|
||||
// 'communication_cost_' includes communications from operators (forward and backward) and edges
|
||||
// 'communication_cost_' includes communications from operators (forward and backward) and edges (redistribution)
|
||||
double communication_cost_;
|
||||
// communication_without_parameter_ = communication_cost_ - (backward communication from operators)
|
||||
double communication_without_parameter_;
|
||||
// communication_with_partial_para_ =
|
||||
// communication_without_parameter_ + COST_MODEL_GAMMA * (communication_cost_ - communication_without_parameter_ )
|
||||
double communication_with_partial_para_;
|
||||
// communication_forward_ = communication cost from operators (only forward phase) and forward redistribution.
|
||||
double communication_forward_;
|
||||
double communication_redis_forward_;
|
||||
double communication_redis_backward_;
|
||||
std::shared_ptr<Decision> decision_ptr_;
|
||||
|
@ -296,7 +300,8 @@ using FinalDecisionPtr = std::shared_ptr<FinalDecision>;
|
|||
using FinalSingleDecisionPtr = std::shared_ptr<FinalSingleDecision>;
|
||||
|
||||
void Simplify(CostPtrList *clist);
|
||||
void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList *clist);
|
||||
void SimplifyForDecreasingCommunicationForward(CostPtrList *clist);
|
||||
void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist);
|
||||
void RefineForPracticalCost(const CostPtr &, bool is_redistribution);
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -76,6 +76,7 @@ Status Edge::InitEdgeCost() {
|
|||
<< ", communication_with_partial_para_: " << cost->communication_with_partial_para_ << ".";
|
||||
// refine communication cost calculation for practice
|
||||
RefineForPracticalCost(cost, true);
|
||||
cost->communication_forward_ = cost->communication_redis_forward_;
|
||||
CostPtrKey ck = {target_output_str, target_input_str};
|
||||
CostPtrList cl;
|
||||
cl.push_back(cost);
|
||||
|
@ -160,8 +161,9 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr
|
|||
(void)std::transform(edges.begin(), edges.end(), all_cost_list.begin(), LocalGetCostList);
|
||||
|
||||
CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
|
||||
std::function<void(size_t, double, double, double, double)> recursive =
|
||||
[&](size_t k, double computation, double memory, double communication, double communication_without_para) {
|
||||
std::function<void(size_t, double, double, double, double, double)> recursive =
|
||||
[&](size_t k, double computation, double memory, double communication, double communication_without_para,
|
||||
double communication_forward) {
|
||||
if (k == edges.size()) {
|
||||
auto decision = std::make_shared<EdgeEliminationDecision>(selected_cost_list);
|
||||
CostPtr new_cost = std::make_shared<Cost>(computation, communication);
|
||||
|
@ -170,6 +172,7 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr
|
|||
new_cost->communication_with_partial_para_ =
|
||||
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
|
||||
new_cost->memory_with_reuse_ = memory;
|
||||
new_cost->communication_forward_ = communication_forward;
|
||||
new_cost->decision_ptr_ = decision;
|
||||
result.push_back(new_cost);
|
||||
return;
|
||||
|
@ -179,11 +182,12 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr
|
|||
selected_cost_list[k] = c;
|
||||
recursive(k + 1, computation + c->computation_cost_, memory + c->memory_with_reuse_,
|
||||
communication + c->communication_cost_,
|
||||
communication_without_para + c->communication_without_parameter_);
|
||||
communication_without_para + c->communication_without_parameter_,
|
||||
communication_forward + c->communication_forward_);
|
||||
}
|
||||
};
|
||||
recursive(0, 0.0, 0.0, 0.0, 0.0);
|
||||
SimplifyForDreasingCommunicationWithPartialPara(&result);
|
||||
recursive(0, 0.0, 0.0, 0.0, 0.0, 0.0);
|
||||
Simplify(&result);
|
||||
return result;
|
||||
}
|
||||
|
||||
|
@ -219,6 +223,8 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr
|
|||
left_cost->computation_cost_ + middle_cost->computation_cost_ + right_cost->computation_cost_;
|
||||
double communication =
|
||||
left_cost->communication_cost_ + middle_cost->communication_cost_ + right_cost->communication_cost_;
|
||||
double communication_forward =
|
||||
left_cost->communication_forward_ + middle_cost->communication_forward_ + right_cost->communication_forward_;
|
||||
double communication_without_para = left_cost->communication_without_parameter_ +
|
||||
middle_cost->communication_without_parameter_ +
|
||||
right_cost->communication_without_parameter_;
|
||||
|
@ -232,6 +238,7 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr
|
|||
cost->communication_with_partial_para_ =
|
||||
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
|
||||
cost->memory_with_reuse_ = memory_cost;
|
||||
cost->communication_forward_ = communication_forward;
|
||||
ret_cost_list->emplace_back(std::move(cost));
|
||||
}
|
||||
}
|
||||
|
@ -251,7 +258,7 @@ CostPtrList Edge::CreateOpEliminationCostList(const EdgePtr &e1, const StrategyP
|
|||
CreateOpEliminationSubCostList(middle_strategy, e1->GetCostList(output_st_ptr, middle_strategy),
|
||||
op_strategy->cost_list, e2->GetCostList(middle_strategy, input_st_ptr), &result);
|
||||
}
|
||||
SimplifyForDreasingCommunicationWithPartialPara(&result);
|
||||
Simplify(&result);
|
||||
return result;
|
||||
}
|
||||
|
||||
|
|
|
@ -38,6 +38,8 @@ bool TENSOR_SLICE_ALIGNMENT_ENABLE = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE;
|
|||
size_t TENSOR_SLICE_ALIGNMENT_SIZE = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE;
|
||||
bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES;
|
||||
bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
|
||||
bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS;
|
||||
int32_t RUN_PHASE = DEFAULT_RUN_PHASE;
|
||||
|
||||
void CostGraph::SetDeviceMemoryAndCostParameter() {
|
||||
MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());
|
||||
|
@ -142,6 +144,23 @@ void CostGraph::SetDeviceMemoryAndCostParameter() {
|
|||
} else {
|
||||
MS_LOG(INFO) << "elementwise_op_strategy_follow: false.";
|
||||
}
|
||||
|
||||
// MULTI_SUBGRAPHS
|
||||
auto multi_subgraphs = CostModelContext::GetInstance()->is_multi_subgraphs();
|
||||
MULTI_SUBGRAPHS = multi_subgraphs;
|
||||
if (MULTI_SUBGRAPHS) {
|
||||
MS_LOG(INFO) << "multi_subgraphs: true.";
|
||||
} else {
|
||||
MS_LOG(INFO) << "multi_subgraphs: false.";
|
||||
}
|
||||
|
||||
// RUN_PHASE
|
||||
auto phase = CostModelContext::GetInstance()->run_phase();
|
||||
if (phase != 0 && phase != 1) {
|
||||
MS_LOG(EXCEPTION) << "'run_phase' must be in {0, 1}";
|
||||
}
|
||||
RUN_PHASE = phase;
|
||||
MS_LOG(INFO) << "run_phase: " << RUN_PHASE << ".";
|
||||
}
|
||||
|
||||
void CostGraph::RemoveOperator(const OperatorInfoPtr &op) {
|
||||
|
@ -249,19 +268,21 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std::
|
|||
MS_EXCEPTION_IF_NULL(cost3);
|
||||
double computation = cost1->computation_cost_ + cost2->computation_cost_ + cost3->computation_cost_;
|
||||
double memory = cost1->memory_with_reuse_ + cost2->memory_with_reuse_ + cost3->memory_with_reuse_;
|
||||
double commmunication =
|
||||
cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_;
|
||||
double communication = cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_;
|
||||
double communication_forward =
|
||||
cost1->communication_forward_ + cost2->communication_forward_ + cost3->communication_forward_;
|
||||
double communication_without_para = cost1->communication_without_parameter_ +
|
||||
cost2->communication_without_parameter_ +
|
||||
cost3->communication_without_parameter_;
|
||||
auto decision =
|
||||
std::make_shared<FinalDecision>(u_strategy->strategy_ptr, v_strategy->strategy_ptr, cost1, cost2, cost3);
|
||||
auto cost = std::make_shared<Cost>(computation, commmunication, decision);
|
||||
auto cost = std::make_shared<Cost>(computation, communication, decision);
|
||||
MS_EXCEPTION_IF_NULL(cost);
|
||||
cost->communication_without_parameter_ = communication_without_para;
|
||||
cost->communication_with_partial_para_ =
|
||||
communication_without_para + COST_MODEL_GAMMA * (commmunication - communication_without_para);
|
||||
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
|
||||
cost->memory_with_reuse_ = memory;
|
||||
cost->communication_forward_ = communication_forward;
|
||||
ret.push_back(cost);
|
||||
}
|
||||
}
|
||||
|
@ -269,7 +290,7 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std::
|
|||
}
|
||||
}
|
||||
|
||||
SimplifyForDreasingCommunicationWithPartialPara(&ret);
|
||||
Simplify(&ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
@ -291,32 +312,67 @@ CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr &u) {
|
|||
cost1->communication_without_parameter_ +
|
||||
COST_MODEL_GAMMA * (cost1->communication_cost_ - cost1->communication_without_parameter_);
|
||||
new_cost->memory_with_reuse_ = cost1->memory_with_reuse_;
|
||||
new_cost->communication_forward_ = cost1->communication_forward_;
|
||||
ret.push_back(new_cost);
|
||||
}
|
||||
}
|
||||
|
||||
SimplifyForDreasingCommunicationWithPartialPara(&ret);
|
||||
Simplify(&ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
CostPtr CostGraph::SelectCostWithMemoryConstraint(const CostPtrList &cost_list, double memory) {
|
||||
CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory) {
|
||||
// Select the cost with minimum inference time. Currently, the inference time is modeled as =
|
||||
// costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_forward_
|
||||
if (cost_list.empty()) {
|
||||
MS_LOG(ERROR) << "Final cost list is null.";
|
||||
return nullptr;
|
||||
}
|
||||
CostPtrList after_mem_filter;
|
||||
// Filter out the valid costs
|
||||
double minimum_memory = DBL_MAX;
|
||||
// Filter out the valid costs.
|
||||
for (auto &a_cost : cost_list) {
|
||||
if (a_cost->memory_with_reuse_ <= memory) {
|
||||
after_mem_filter.emplace_back(std::move(a_cost));
|
||||
} else if (a_cost->memory_with_reuse_ < minimum_memory) {
|
||||
minimum_memory = a_cost->memory_with_reuse_;
|
||||
}
|
||||
}
|
||||
if (after_mem_filter.empty()) {
|
||||
MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory
|
||||
<< ", the memory capacity is: " << memory << ".";
|
||||
return nullptr;
|
||||
}
|
||||
// Init the returned value with first cost.
|
||||
CostPtr ret = after_mem_filter[0];
|
||||
|
||||
std::function<CostPtr(CostPtr, const CostPtr &)> LocalCompare = [&](CostPtr init, const CostPtr &cost_x) {
|
||||
MS_EXCEPTION_IF_NULL(cost_x);
|
||||
if (init == nullptr || cost_x->computation_cost_ < memory) {
|
||||
init = cost_x;
|
||||
double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_forward_;
|
||||
MS_LOG(INFO) << "Cost 0: "
|
||||
<< "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_
|
||||
<< ", communication_forward_: " << ret->communication_forward_
|
||||
<< ", communication_with_partial_para_: " << ret->communication_with_partial_para_
|
||||
<< ", communication_cost_: " << ret->communication_cost_
|
||||
<< ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
|
||||
MS_LOG(INFO) << "Cost 0: totoal_cost: " << minimum;
|
||||
for (size_t i = 1; i < after_mem_filter.size(); ++i) {
|
||||
MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
|
||||
MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
|
||||
<< ", computation_cost_: " << after_mem_filter[i]->computation_cost_
|
||||
<< ", communication_forward_: " << after_mem_filter[i]->communication_forward_
|
||||
<< ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_
|
||||
<< ", communication_cost_: " << after_mem_filter[i]->communication_cost_
|
||||
<< ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_
|
||||
<< ".";
|
||||
auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ +
|
||||
costmodel_beta_ * after_mem_filter[i]->communication_forward_;
|
||||
MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp;
|
||||
if (minimum > tmp) {
|
||||
minimum = tmp;
|
||||
ret = after_mem_filter[i];
|
||||
MS_LOG(INFO) << "Selected: " << i;
|
||||
}
|
||||
return init;
|
||||
};
|
||||
CostPtr ret = nullptr;
|
||||
return std::accumulate(after_mem_filter.begin(), after_mem_filter.end(), ret, LocalCompare);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory) {
|
||||
|
@ -524,12 +580,26 @@ Status CostGraph::SearchStrategy() {
|
|||
});
|
||||
|
||||
if (alive_ops.size() > 2) {
|
||||
return SearchStrategyForMultiNodeFinalGraph(alive_ops);
|
||||
if (RUN_PHASE == TRAINING_PHASE) {
|
||||
// training phase
|
||||
return SearchStrategyForMultiNodeFinalGraph(alive_ops);
|
||||
} else {
|
||||
// inference phase
|
||||
MS_LOG(EXCEPTION)
|
||||
<< "Currently, searching strategy for the multi-node final graph in inference phase is not supported.";
|
||||
}
|
||||
} else if (alive_ops.size() == 1) {
|
||||
MS_LOG(INFO) << "There are 1 single node in the final graph.";
|
||||
OperatorInfoPtr u = alive_ops[0];
|
||||
auto cost_list = CreateFinalSingleCostList(u);
|
||||
auto cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
|
||||
CostPtr cost = nullptr;
|
||||
if (RUN_PHASE == TRAINING_PHASE) {
|
||||
// training phase
|
||||
cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
|
||||
} else {
|
||||
// inference phase
|
||||
cost = SelectCostWithMinInferenceTime(cost_list, dev_memory_);
|
||||
}
|
||||
if (cost == nullptr) {
|
||||
MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
|
||||
return FAILED;
|
||||
|
@ -575,7 +645,15 @@ Status CostGraph::SearchStrategy() {
|
|||
auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]);
|
||||
all_list.push_back(cost_list);
|
||||
}
|
||||
auto selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_);
|
||||
CostPtrList selected_cost_list;
|
||||
if (RUN_PHASE == TRAINING_PHASE) {
|
||||
// training phase
|
||||
selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_);
|
||||
} else {
|
||||
// inference phase
|
||||
MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-separated-node final graph in the inference "
|
||||
"phase is not supported.";
|
||||
}
|
||||
for (size_t k = 0; k < selected_cost_list.size(); ++k) {
|
||||
auto selected_cost = selected_cost_list[k];
|
||||
if (selected_cost == nullptr) {
|
||||
|
@ -601,7 +679,14 @@ Status CostGraph::SearchStrategy() {
|
|||
auto e = u->GetAliveSuccEdges()[0];
|
||||
MS_EXCEPTION_IF_NULL(e);
|
||||
auto cost_list = CreateFinalCostList(u, e, v);
|
||||
auto cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
|
||||
CostPtr cost = nullptr;
|
||||
if (RUN_PHASE == TRAINING_PHASE) {
|
||||
// training phase
|
||||
cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-connected-node final graph in the inference "
|
||||
"phase is not supported.";
|
||||
}
|
||||
if (cost == nullptr) {
|
||||
MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
|
||||
return FAILED;
|
||||
|
@ -841,6 +926,8 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const
|
|||
double memory = op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
|
||||
double communication =
|
||||
op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
|
||||
double communication_forward =
|
||||
op_cost->communication_forward_ + edge_cost->communication_forward_ + tar_cost->communication_forward_;
|
||||
double communication_without_para = op_cost->communication_without_parameter_ +
|
||||
edge_cost->communication_without_parameter_ +
|
||||
tar_cost->communication_without_parameter_;
|
||||
|
@ -853,6 +940,7 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const
|
|||
new_cost->communication_with_partial_para_ =
|
||||
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
|
||||
new_cost->memory_with_reuse_ = memory;
|
||||
new_cost->communication_forward_ = communication_forward;
|
||||
MS_EXCEPTION_IF_NULL(tar_cost_list_new);
|
||||
tar_cost_list_new->emplace_back(std::move(new_cost));
|
||||
}
|
||||
|
@ -885,7 +973,7 @@ OperatorInfoPtr CostGraph::EliminationMerge(const OperatorInfoPtr &op) {
|
|||
|
||||
CreateMergeEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new);
|
||||
}
|
||||
SimplifyForDreasingCommunicationWithPartialPara(&tar_clist_new);
|
||||
Simplify(&tar_clist_new);
|
||||
// Set the new costlist w.r.t the strategy
|
||||
tar_stra_cost->cost_list = tar_clist_new;
|
||||
if ((!valid) && (!tar_clist_new.empty())) {
|
||||
|
@ -922,6 +1010,8 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str
|
|||
contract_op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
|
||||
double communication =
|
||||
contract_op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
|
||||
double communication_forward = contract_op_cost->communication_forward_ + edge_cost->communication_forward_ +
|
||||
tar_cost->communication_forward_;
|
||||
double communication_without_para = contract_op_cost->communication_without_parameter_ +
|
||||
edge_cost->communication_without_parameter_ +
|
||||
tar_cost->communication_without_parameter_;
|
||||
|
@ -933,6 +1023,7 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str
|
|||
new_cost->communication_with_partial_para_ =
|
||||
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
|
||||
new_cost->memory_with_reuse_ = memory;
|
||||
new_cost->communication_forward_ = communication_forward;
|
||||
tar_cost_list_new->emplace_back(std::move(new_cost));
|
||||
}
|
||||
}
|
||||
|
@ -962,7 +1053,7 @@ OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr &op) {
|
|||
|
||||
CreateContractEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new);
|
||||
}
|
||||
SimplifyForDreasingCommunicationWithPartialPara(&tar_clist_new);
|
||||
Simplify(&tar_clist_new);
|
||||
// Set the new costlist w.r.t the strategy
|
||||
tar_stra_cost->cost_list = tar_clist_new;
|
||||
if ((!valid) && (!tar_clist_new.empty())) {
|
||||
|
@ -998,6 +1089,8 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra,
|
|||
left_node_cost->memory_with_reuse_ + right_edge_cost->memory_with_reuse_;
|
||||
double new_commu_cost = elimi_op_cost->communication_cost_ + left_edge_cost->communication_cost_ +
|
||||
left_node_cost->communication_cost_ + right_edge_cost->communication_cost_;
|
||||
double new_commu_forward = elimi_op_cost->communication_forward_ + left_edge_cost->communication_forward_ +
|
||||
left_node_cost->communication_forward_ + right_edge_cost->communication_forward_;
|
||||
double new_commu_without =
|
||||
elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ +
|
||||
left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_;
|
||||
|
@ -1009,6 +1102,7 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra,
|
|||
new_cost->communication_with_partial_para_ =
|
||||
new_commu_without + COST_MODEL_GAMMA * (new_commu_cost - new_commu_without);
|
||||
new_cost->memory_with_reuse_ = new_memory;
|
||||
new_cost->communication_forward_ = new_commu_forward;
|
||||
left_node_clist_new->emplace_back(std::move(new_cost));
|
||||
}
|
||||
}
|
||||
|
@ -1079,7 +1173,7 @@ OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr &elimi_op,
|
|||
&left_node_clist_new);
|
||||
}
|
||||
}
|
||||
SimplifyForDreasingCommunicationWithPartialPara(&left_node_clist_new);
|
||||
Simplify(&left_node_clist_new);
|
||||
// Set the new costlist w.r.t the strategy
|
||||
left_node_stra_cost->cost_list = left_node_clist_new;
|
||||
if ((!valid) && (!left_node_clist_new.empty())) {
|
||||
|
@ -1112,19 +1206,22 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n
|
|||
|
||||
double computation_cost = merged_node_cost->computation_cost_,
|
||||
memory_cost = merged_node_cost->memory_with_reuse_, commu_cost = merged_node_cost->communication_cost_,
|
||||
commu_without = merged_node_cost->communication_without_parameter_;
|
||||
commu_without = merged_node_cost->communication_without_parameter_,
|
||||
commu_forward = merged_node_cost->communication_forward_;
|
||||
for (size_t i = 0; i < succ_nodes_stras.size(); ++i) {
|
||||
MS_EXCEPTION_IF_NULL(succ_edges_costs[i]);
|
||||
if (i == 0) {
|
||||
computation_cost += succ_edges_costs[i]->computation_cost_ + succ_nodes_costs[i]->computation_cost_;
|
||||
memory_cost += succ_edges_costs[i]->memory_with_reuse_ + succ_nodes_costs[i]->memory_with_reuse_;
|
||||
commu_cost += succ_edges_costs[i]->communication_cost_ + succ_nodes_costs[i]->communication_cost_;
|
||||
commu_forward += succ_edges_costs[i]->communication_forward_ + succ_nodes_costs[i]->communication_forward_;
|
||||
commu_without += succ_edges_costs[i]->communication_without_parameter_ +
|
||||
succ_nodes_costs[i]->communication_without_parameter_;
|
||||
} else {
|
||||
computation_cost += succ_edges_costs[i]->computation_cost_;
|
||||
memory_cost += succ_edges_costs[i]->memory_with_reuse_;
|
||||
commu_cost += succ_edges_costs[i]->communication_cost_;
|
||||
commu_forward += succ_edges_costs[i]->communication_forward_;
|
||||
commu_without += succ_edges_costs[i]->communication_without_parameter_;
|
||||
}
|
||||
}
|
||||
|
@ -1135,6 +1232,7 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n
|
|||
new_cost->communication_without_parameter_ = commu_without;
|
||||
new_cost->communication_with_partial_para_ = commu_without + COST_MODEL_GAMMA * (commu_cost - commu_without);
|
||||
new_cost->memory_with_reuse_ = memory_cost;
|
||||
new_cost->communication_forward_ = commu_forward;
|
||||
first_succ_node_clist_new->emplace_back(std::move(new_cost));
|
||||
}
|
||||
}
|
||||
|
@ -1220,7 +1318,7 @@ std::vector<std::shared_ptr<Edge>> CostGraph::EliminationStar(const OperatorInfo
|
|||
CreateStarEliminationCostList(succ_edges, first_succ_node_stra, first_succ_node_clist, first_succ_edge_clist,
|
||||
merged_op_stra, merged_op_clist, &first_succ_node_clist_new);
|
||||
}
|
||||
SimplifyForDreasingCommunicationWithPartialPara(&first_succ_node_clist_new);
|
||||
Simplify(&first_succ_node_clist_new);
|
||||
// Set the new costlist w.r.t the strategy
|
||||
first_succ_node_stra_cost->cost_list = first_succ_node_clist_new;
|
||||
if ((!valid) && (!first_succ_node_clist_new.empty())) {
|
||||
|
|
|
@ -45,6 +45,9 @@ namespace parallel {
|
|||
#define DEFAULT_FULLY_USE_DEVICES true
|
||||
#define DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW false
|
||||
#define DEFAULT_IS_MULTI_SUBGRAPHS false
|
||||
#define DEFAULT_RUN_PHASE 0
|
||||
#define TRAINING_PHASE 0
|
||||
#define INFERENCE_PHASE 1
|
||||
|
||||
class CostGraph;
|
||||
using CostGraphPtr = std::shared_ptr<CostGraph>;
|
||||
|
@ -60,6 +63,8 @@ extern bool TENSOR_SLICE_ALIGNMENT_ENABLE;
|
|||
extern size_t TENSOR_SLICE_ALIGNMENT_SIZE;
|
||||
extern bool FULLY_USE_DEVICES;
|
||||
extern bool ELEMENTWISE_OP_STRA_FOLLOW;
|
||||
extern bool MULTI_SUBGRAPHS;
|
||||
extern int32_t RUN_PHASE;
|
||||
|
||||
class CostGraph {
|
||||
// 'CostGraph' consists of Operators and edges between them. An edge is created between two Operators if they have
|
||||
|
@ -98,7 +103,7 @@ class CostGraph {
|
|||
|
||||
CostPtrList CreateFinalCostList(const OperatorInfoPtr &u, const EdgePtr &e, const OperatorInfoPtr &v);
|
||||
CostPtrList CreateFinalSingleCostList(const OperatorInfoPtr &u);
|
||||
CostPtr SelectCostWithMemoryConstraint(const CostPtrList &cost_list, double memory);
|
||||
CostPtr SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory);
|
||||
CostPtr SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory);
|
||||
CostPtrList SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList> &all_costlist, double memory);
|
||||
Status SearchStrategyForMultiNodeFinalGraph(const std::vector<OperatorInfoPtr> &);
|
||||
|
|
|
@ -47,6 +47,7 @@ void CostModelContext::ResetCostModel() {
|
|||
costmodel_communi_const_ = DEFAULT_COST_MODEL_COMMUNI_CONST;
|
||||
costmodel_communi_bias_ = DEFAULT_COST_MODEL_COMMUNI_BIAS;
|
||||
is_multi_subgraphs_ = DEFAULT_IS_MULTI_SUBGRAPHS;
|
||||
run_phase_ = DEFAULT_RUN_PHASE;
|
||||
costmodel_allreduce_fusion_algorithm_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM;
|
||||
costmodel_allreduce_fusion_times_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES;
|
||||
costmodel_allreduce_fusion_tail_percent_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT;
|
||||
|
@ -125,5 +126,7 @@ void CostModelContext::set_fully_use_device(bool fully_use) { fully_use_device_
|
|||
void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) {
|
||||
elementwise_stra_follow_ = elementwise_follow;
|
||||
}
|
||||
|
||||
void CostModelContext::set_run_phase(int32_t phase) { run_phase_ = phase; }
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -113,6 +113,9 @@ class CostModelContext {
|
|||
void set_elementwise_stra_follow(bool);
|
||||
bool elementwise_stra_follow() const { return elementwise_stra_follow_; }
|
||||
|
||||
void set_run_phase(int32_t);
|
||||
int32_t run_phase() const { return run_phase_; }
|
||||
|
||||
private:
|
||||
CostModelContext();
|
||||
static std::shared_ptr<CostModelContext> cm_context_inst_;
|
||||
|
@ -141,8 +144,11 @@ class CostModelContext {
|
|||
// COST_MODEL_COMMUNI_BIAS
|
||||
double costmodel_communi_bias_;
|
||||
|
||||
// MULTI_SUBGRAPHS
|
||||
bool is_multi_subgraphs_;
|
||||
|
||||
int32_t run_phase_; // 0: 'training', 1: 'inference'
|
||||
|
||||
int32_t costmodel_allreduce_fusion_algorithm_;
|
||||
|
||||
int32_t costmodel_allreduce_fusion_times_;
|
||||
|
|
|
@ -610,6 +610,7 @@ Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &
|
|||
<< ", communication_with_partial_para_: " << result->communication_with_partial_para_;
|
||||
// refine communication cost calculation for practice
|
||||
RefineForPracticalCost(result, false);
|
||||
result->communication_forward_ = result->communication_without_parameter_;
|
||||
|
||||
std::shared_ptr<StrategyWithCost> swc =
|
||||
std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);
|
||||
|
|
|
@ -1049,6 +1049,7 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) {
|
|||
BreakingTiesForPerferringDataParallel(strategy, result);
|
||||
// refine communication cost calculation for practice
|
||||
RefineForPracticalCost(result, false);
|
||||
result->communication_forward_ = result->communication_without_parameter_;
|
||||
|
||||
std::shared_ptr<StrategyWithCost> swc =
|
||||
std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);
|
||||
|
|
|
@ -69,16 +69,16 @@ class TensorRedistribution {
|
|||
RankList dev_list_;
|
||||
OperatorList operator_list_;
|
||||
bool reshape_flag_;
|
||||
// communication cost
|
||||
// communication cost, which is the sum of forward communication cost and backward communication cost
|
||||
double comm_cost_;
|
||||
// forward communication cost
|
||||
double forward_comm_cost_;
|
||||
// backward communication cost
|
||||
double backward_comm_cost_;
|
||||
// computation_cost models the time spending on computing in this tensor redistribution, which is calculated by the
|
||||
// inputs.
|
||||
// inputs. This is calculated ONLY for forward phase.
|
||||
double computation_cost_;
|
||||
// memory_cost models the PEAK memory cost in a traning iteration contributed by this tensor redistribution, which is
|
||||
// memory_cost models the PEAK memory cost in a training iteration contributed by this tensor redistribution, which is
|
||||
// calculated by the outputs.
|
||||
double memory_cost_;
|
||||
bool construct_op_flag_;
|
||||
|
|
|
@ -228,6 +228,8 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
"Get the parameter cost_model_communi_bias of the DP algorithm.")
|
||||
.def("set_multi_subgraphs", &CostModelContext::set_multi_subgraphs, "Set the parameter is_multi_subgraphs.")
|
||||
.def("get_multi_subgraphs", &CostModelContext::is_multi_subgraphs, "Get the parameter is_multi_subgraphs.")
|
||||
.def("set_run_phase", &CostModelContext::set_run_phase, "Set the flag run_phase.")
|
||||
.def("get_run_phase", &CostModelContext::run_phase, "Get the flag run_phase.")
|
||||
.def("set_costmodel_allreduce_fusion_algorithm", &CostModelContext::set_costmodel_allreduce_fusion_algorithm,
|
||||
"Set the parameter gradient AllReduce fusion algorithm.")
|
||||
.def("get_costmodel_allreduce_fusion_algorithm", &CostModelContext::costmodel_allreduce_fusion_algorithm,
|
||||
|
|
|
@ -239,6 +239,33 @@ class _CostModelContext:
|
|||
raise ValueError("Context handle is none in context!!!")
|
||||
return self._context_handle.get_multi_subgraphs()
|
||||
|
||||
def set_run_phase(self, phase):
|
||||
"""
|
||||
Set the flag of running phase: training (0) or inference (1)
|
||||
|
||||
Args:
|
||||
phase (int): A parameter indicating which phase is running.
|
||||
|
||||
Raises:
|
||||
ValueError: If context handle is none, or phase is not in {0, 1}.
|
||||
"""
|
||||
if self._context_handle is None:
|
||||
raise ValueError("Context handle is none in context!!!")
|
||||
if phase not in (0, 1):
|
||||
raise ValueError("The argument of set_run_phase() must be '0' or '1', but got {}".format(phase))
|
||||
self._context_handle.set_run_phase(phase)
|
||||
|
||||
def get_run_phase(self):
|
||||
"""
|
||||
Get the flag of running phase.
|
||||
|
||||
Raises:
|
||||
ValueError: If context handle is none.
|
||||
"""
|
||||
if self._context_handle is None:
|
||||
raise ValueError("Context handle is none in context!!!")
|
||||
return self._context_handle.get_run_phase()
|
||||
|
||||
def set_costmodel_allreduce_fusion_algorithm(self, algorithm):
|
||||
"""
|
||||
Set costmodel allreduce fusion algorithm.
|
||||
|
@ -453,6 +480,7 @@ set_cost_model_context_func_map = {
|
|||
"costmodel_communi_const": cost_model_context().set_costmodel_communi_const,
|
||||
"costmodel_communi_bias": cost_model_context().set_costmodel_communi_bias,
|
||||
"multi_subgraphs": cost_model_context().set_multi_subgraphs,
|
||||
"run_phase": cost_model_context().set_run_phase,
|
||||
"costmodel_allreduce_fusion_algorithm": cost_model_context().set_costmodel_allreduce_fusion_algorithm,
|
||||
"costmodel_allreduce_fusion_times": cost_model_context().set_costmodel_allreduce_fusion_times,
|
||||
"costmodel_allreduce_fusion_tail_percent": cost_model_context().set_costmodel_allreduce_fusion_tail_percent,
|
||||
|
@ -473,7 +501,8 @@ get_cost_model_context_func_map = {
|
|||
"costmodel_communi_threshold": cost_model_context().get_costmodel_communi_threshold,
|
||||
"costmodel_communi_const": cost_model_context().get_costmodel_communi_const,
|
||||
"costmodel_communi_bias": cost_model_context().get_costmodel_communi_bias,
|
||||
"multi_subgraphs": cost_model_context().get_multi_subgraphs(),
|
||||
"multi_subgraphs": cost_model_context().get_multi_subgraphs,
|
||||
"run_phase": cost_model_context().get_run_phase,
|
||||
"costmodel_allreduce_fusion_algorithm": cost_model_context().get_costmodel_allreduce_fusion_algorithm,
|
||||
"costmodel_allreduce_fusion_times": cost_model_context().get_costmodel_allreduce_fusion_times,
|
||||
"costmodel_allreduce_fusion_tail_percent": cost_model_context().get_costmodel_allreduce_fusion_tail_percent,
|
||||
|
@ -488,7 +517,7 @@ get_cost_model_context_func_map = {
|
|||
|
||||
@args_type_check(device_memory_capacity=float, costmodel_alpha=float, costmodel_beta=float, costmodel_gamma=float,
|
||||
costmodel_communi_threshold=float, costmodel_communi_const=float, costmodel_communi_bias=float,
|
||||
multi_subgraphs=bool,
|
||||
multi_subgraphs=bool, run_phase=int,
|
||||
costmodel_allreduce_fusion_algorithm=int, costmodel_allreduce_fusion_times=int,
|
||||
costmodel_allreduce_fusion_tail_percent=float, costmodel_allreduce_fusion_tail_time=float,
|
||||
costmodel_allreduce_fusion_allreduce_inherent_time=float,
|
||||
|
@ -510,6 +539,7 @@ def set_cost_model_context(**kwargs):
|
|||
costmodel_communi_const (float): A parameter used in adjusting communication calculation for practice.
|
||||
costmodel_communi_bias (float): A parameter used in adjusting communication calculation for practice.
|
||||
multi_subgraphs (bool): A parameter used in marking the flag of ANF graph containing multiple subgraphs.
|
||||
run_phase (int): A parameter indicating which phase is running: training (0) or inference (1). Default: 0.
|
||||
costmodel_allreduce_fusion_algorithm (int): The allreduce fusion algorithm.
|
||||
0: bypass allreduce fusion;
|
||||
1: only use backward computation time to group allreduce;
|
||||
|
|
|
@ -371,7 +371,7 @@ TEST_F(TestCostGraph, test_CreateFinalCostList_AND_Select) {
|
|||
ASSERT_EQ(edge_m1_m2->InitEdgeCost(), SUCCESS);
|
||||
cost_graph.AddEdge(matmul1, matmul2, edge_m1_m2);
|
||||
auto cost_list = cost_graph.CreateFinalCostList(matmul1, edge_m1_m2, matmul2);
|
||||
cost_graph.SelectCostWithMemoryConstraint(cost_list, cost_graph.GetDeviceMemory());
|
||||
cost_graph.SelectCostWithMinInferenceTime(cost_list, cost_graph.GetDeviceMemory());
|
||||
}
|
||||
|
||||
TEST_F(TestCostGraph, test_EliminationOp) {
|
||||
|
|
|
@ -14,15 +14,21 @@
|
|||
|
||||
import mindspore.context as context
|
||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
from mindspore.parallel._cost_model_context import reset_cost_model_context
|
||||
from mindspore.parallel.algo_parameter_config import reset_algo_parameters
|
||||
from mindspore.parallel._utils import _reset_op_id
|
||||
|
||||
|
||||
def setup_module(module):
|
||||
auto_parallel_context().set_enable_all_reduce_fusion(enable_all_reduce_fusion=True)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
|
||||
reset_cost_model_context()
|
||||
reset_algo_parameters()
|
||||
_reset_op_id()
|
||||
|
||||
|
||||
def teardown_module():
|
||||
context.reset_auto_parallel_context()
|
||||
reset_cost_model_context()
|
||||
reset_algo_parameters()
|
||||
_reset_op_id()
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, context
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.nn import WithLossCell, TrainOneStepCell
|
||||
from mindspore.nn import Momentum
|
||||
from mindspore.parallel._cost_model_context import set_cost_model_context
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, input_ch, out_ch):
|
||||
super(Net, self).__init__()
|
||||
self.dense = nn.Dense(input_ch, out_ch)
|
||||
self.relu = P.ReLU()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.dense(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
def test_inference_phase():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel")
|
||||
set_cost_model_context(run_phase=1)
|
||||
|
||||
net = Net(512, 128)
|
||||
predict = Tensor(np.ones([64, 512]).astype(np.float32) * 0.001)
|
||||
label = Tensor(np.ones([64, 128]).astype(np.float32))
|
||||
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits()
|
||||
optimizer = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
net_with_loss = WithLossCell(net, loss)
|
||||
train_network = TrainOneStepCell(net_with_loss, optimizer)
|
||||
train_network.set_train()
|
||||
|
||||
output = train_network(predict, label)
|
Loading…
Reference in New Issue