!859 [Auto parallel] Support searching strategy for inference phase

Merge pull request !859 from Xiaoda/support-inferring-phase-parallel-strategy-searching
This commit is contained in:
mindspore-ci-bot 2020-05-07 17:38:14 +08:00 committed by Gitee
commit 6351f9c837
15 changed files with 255 additions and 45 deletions

View File

@ -23,8 +23,17 @@
namespace mindspore {
namespace parallel {
void Simplify(CostPtrList *clist_ptrs) {
// Sort the cost_list with the computation_cost_ increasing, and communication_cost decreasing order. This method
// excludes the cost with greater computation_cost_ and greater communication_cost.
if (RUN_PHASE == TRAINING_PHASE) {
// training phase
SimplifyForDecreasingCommunicationWithPartialPara(clist_ptrs);
} else {
// inference phase
SimplifyForDecreasingCommunicationForward(clist_ptrs);
}
}
void SimplifyForDecreasingCommunicationForward(CostPtrList *clist_ptrs) {
// Sort the cost_list with the computation_cost_ increasing, and communication_forward decreasing order. This method
// excludes the cost with greater computation_cost_ and greater communication_forward.
// E.g. clist_ptrs = {<100, 20>, <200, 10>, <300, 50>}. After this method, clist_ptrs = {<200, 10>, <100, 20>}
if (!COST_MODEL_SIMPLIFY_CALCULATION) {
return;
@ -37,14 +46,15 @@ void Simplify(CostPtrList *clist_ptrs) {
});
CostPtrList ret;
for (size_t i = 0; i < clist_ptrs->size(); ++i) {
if ((ret.size() == size_t(0)) || (clist_ptrs->at(id[i])->communication_cost_ < ret.back()->communication_cost_)) {
if ((ret.size() == size_t(0)) ||
(clist_ptrs->at(id[i])->communication_forward_ < ret.back()->communication_forward_)) {
ret.emplace_back(std::move(clist_ptrs->at(id[i])));
}
}
*clist_ptrs = std::move(ret);
}
void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) {
void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) {
// Sort the cost_list with the computation_cost_ increasing, and communication_with_partial_para_cost decreasing
// order. This method excludes the cost with greater computation_cost_ and greater communication_without_para_cost.
if (!COST_MODEL_SIMPLIFY_CALCULATION) {

View File

@ -51,18 +51,22 @@ struct Cost {
communication_with_partial_para_ = 0.0;
communication_redis_forward_ = 0.0;
communication_redis_backward_ = 0.0;
communication_forward_ = 0.0;
}
// 'memory_with_reuse_' calculates the peak memory usage in a training phase
double memory_with_reuse_;
// 'computation_cost_' models the training time of an iteration in a training phase
// 'computation_cost_' models the training time of an iteration in a training phase. Currently, this is calculated
// by ONLY forward phase
double computation_cost_;
// 'communication_cost_' includes communications from operators (forward and backward) and edges
// 'communication_cost_' includes communications from operators (forward and backward) and edges (redistribution)
double communication_cost_;
// communication_without_parameter_ = communication_cost_ - (backward communication from operators)
double communication_without_parameter_;
// communication_with_partial_para_ =
// communication_without_parameter_ + COST_MODEL_GAMMA * (communication_cost_ - communication_without_parameter_ )
double communication_with_partial_para_;
// communication_forward_ = communication cost from operators (only forward phase) and forward redistribution.
double communication_forward_;
double communication_redis_forward_;
double communication_redis_backward_;
std::shared_ptr<Decision> decision_ptr_;
@ -296,7 +300,8 @@ using FinalDecisionPtr = std::shared_ptr<FinalDecision>;
using FinalSingleDecisionPtr = std::shared_ptr<FinalSingleDecision>;
void Simplify(CostPtrList *clist);
void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList *clist);
void SimplifyForDecreasingCommunicationForward(CostPtrList *clist);
void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist);
void RefineForPracticalCost(const CostPtr &, bool is_redistribution);
} // namespace parallel
} // namespace mindspore

View File

@ -76,6 +76,7 @@ Status Edge::InitEdgeCost() {
<< ", communication_with_partial_para_: " << cost->communication_with_partial_para_ << ".";
// refine communication cost calculation for practice
RefineForPracticalCost(cost, true);
cost->communication_forward_ = cost->communication_redis_forward_;
CostPtrKey ck = {target_output_str, target_input_str};
CostPtrList cl;
cl.push_back(cost);
@ -160,8 +161,9 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr
(void)std::transform(edges.begin(), edges.end(), all_cost_list.begin(), LocalGetCostList);
CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
std::function<void(size_t, double, double, double, double)> recursive =
[&](size_t k, double computation, double memory, double communication, double communication_without_para) {
std::function<void(size_t, double, double, double, double, double)> recursive =
[&](size_t k, double computation, double memory, double communication, double communication_without_para,
double communication_forward) {
if (k == edges.size()) {
auto decision = std::make_shared<EdgeEliminationDecision>(selected_cost_list);
CostPtr new_cost = std::make_shared<Cost>(computation, communication);
@ -170,6 +172,7 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr
new_cost->communication_with_partial_para_ =
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
new_cost->memory_with_reuse_ = memory;
new_cost->communication_forward_ = communication_forward;
new_cost->decision_ptr_ = decision;
result.push_back(new_cost);
return;
@ -179,11 +182,12 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr
selected_cost_list[k] = c;
recursive(k + 1, computation + c->computation_cost_, memory + c->memory_with_reuse_,
communication + c->communication_cost_,
communication_without_para + c->communication_without_parameter_);
communication_without_para + c->communication_without_parameter_,
communication_forward + c->communication_forward_);
}
};
recursive(0, 0.0, 0.0, 0.0, 0.0);
SimplifyForDreasingCommunicationWithPartialPara(&result);
recursive(0, 0.0, 0.0, 0.0, 0.0, 0.0);
Simplify(&result);
return result;
}
@ -219,6 +223,8 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr
left_cost->computation_cost_ + middle_cost->computation_cost_ + right_cost->computation_cost_;
double communication =
left_cost->communication_cost_ + middle_cost->communication_cost_ + right_cost->communication_cost_;
double communication_forward =
left_cost->communication_forward_ + middle_cost->communication_forward_ + right_cost->communication_forward_;
double communication_without_para = left_cost->communication_without_parameter_ +
middle_cost->communication_without_parameter_ +
right_cost->communication_without_parameter_;
@ -232,6 +238,7 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr
cost->communication_with_partial_para_ =
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
cost->memory_with_reuse_ = memory_cost;
cost->communication_forward_ = communication_forward;
ret_cost_list->emplace_back(std::move(cost));
}
}
@ -251,7 +258,7 @@ CostPtrList Edge::CreateOpEliminationCostList(const EdgePtr &e1, const StrategyP
CreateOpEliminationSubCostList(middle_strategy, e1->GetCostList(output_st_ptr, middle_strategy),
op_strategy->cost_list, e2->GetCostList(middle_strategy, input_st_ptr), &result);
}
SimplifyForDreasingCommunicationWithPartialPara(&result);
Simplify(&result);
return result;
}

View File

@ -38,6 +38,8 @@ bool TENSOR_SLICE_ALIGNMENT_ENABLE = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE;
size_t TENSOR_SLICE_ALIGNMENT_SIZE = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE;
bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES;
bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS;
int32_t RUN_PHASE = DEFAULT_RUN_PHASE;
void CostGraph::SetDeviceMemoryAndCostParameter() {
MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());
@ -142,6 +144,23 @@ void CostGraph::SetDeviceMemoryAndCostParameter() {
} else {
MS_LOG(INFO) << "elementwise_op_strategy_follow: false.";
}
// MULTI_SUBGRAPHS
auto multi_subgraphs = CostModelContext::GetInstance()->is_multi_subgraphs();
MULTI_SUBGRAPHS = multi_subgraphs;
if (MULTI_SUBGRAPHS) {
MS_LOG(INFO) << "multi_subgraphs: true.";
} else {
MS_LOG(INFO) << "multi_subgraphs: false.";
}
// RUN_PHASE
auto phase = CostModelContext::GetInstance()->run_phase();
if (phase != 0 && phase != 1) {
MS_LOG(EXCEPTION) << "'run_phase' must be in {0, 1}";
}
RUN_PHASE = phase;
MS_LOG(INFO) << "run_phase: " << RUN_PHASE << ".";
}
void CostGraph::RemoveOperator(const OperatorInfoPtr &op) {
@ -249,19 +268,21 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std::
MS_EXCEPTION_IF_NULL(cost3);
double computation = cost1->computation_cost_ + cost2->computation_cost_ + cost3->computation_cost_;
double memory = cost1->memory_with_reuse_ + cost2->memory_with_reuse_ + cost3->memory_with_reuse_;
double commmunication =
cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_;
double communication = cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_;
double communication_forward =
cost1->communication_forward_ + cost2->communication_forward_ + cost3->communication_forward_;
double communication_without_para = cost1->communication_without_parameter_ +
cost2->communication_without_parameter_ +
cost3->communication_without_parameter_;
auto decision =
std::make_shared<FinalDecision>(u_strategy->strategy_ptr, v_strategy->strategy_ptr, cost1, cost2, cost3);
auto cost = std::make_shared<Cost>(computation, commmunication, decision);
auto cost = std::make_shared<Cost>(computation, communication, decision);
MS_EXCEPTION_IF_NULL(cost);
cost->communication_without_parameter_ = communication_without_para;
cost->communication_with_partial_para_ =
communication_without_para + COST_MODEL_GAMMA * (commmunication - communication_without_para);
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
cost->memory_with_reuse_ = memory;
cost->communication_forward_ = communication_forward;
ret.push_back(cost);
}
}
@ -269,7 +290,7 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std::
}
}
SimplifyForDreasingCommunicationWithPartialPara(&ret);
Simplify(&ret);
return ret;
}
@ -291,32 +312,67 @@ CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr &u) {
cost1->communication_without_parameter_ +
COST_MODEL_GAMMA * (cost1->communication_cost_ - cost1->communication_without_parameter_);
new_cost->memory_with_reuse_ = cost1->memory_with_reuse_;
new_cost->communication_forward_ = cost1->communication_forward_;
ret.push_back(new_cost);
}
}
SimplifyForDreasingCommunicationWithPartialPara(&ret);
Simplify(&ret);
return ret;
}
CostPtr CostGraph::SelectCostWithMemoryConstraint(const CostPtrList &cost_list, double memory) {
CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory) {
// Select the cost with minimum inference time. Currently, the inference time is modeled as =
// costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_forward_
if (cost_list.empty()) {
MS_LOG(ERROR) << "Final cost list is null.";
return nullptr;
}
CostPtrList after_mem_filter;
// Filter out the valid costs
double minimum_memory = DBL_MAX;
// Filter out the valid costs.
for (auto &a_cost : cost_list) {
if (a_cost->memory_with_reuse_ <= memory) {
after_mem_filter.emplace_back(std::move(a_cost));
} else if (a_cost->memory_with_reuse_ < minimum_memory) {
minimum_memory = a_cost->memory_with_reuse_;
}
}
if (after_mem_filter.empty()) {
MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory
<< ", the memory capacity is: " << memory << ".";
return nullptr;
}
// Init the returned value with first cost.
CostPtr ret = after_mem_filter[0];
std::function<CostPtr(CostPtr, const CostPtr &)> LocalCompare = [&](CostPtr init, const CostPtr &cost_x) {
MS_EXCEPTION_IF_NULL(cost_x);
if (init == nullptr || cost_x->computation_cost_ < memory) {
init = cost_x;
double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_forward_;
MS_LOG(INFO) << "Cost 0: "
<< "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_
<< ", communication_forward_: " << ret->communication_forward_
<< ", communication_with_partial_para_: " << ret->communication_with_partial_para_
<< ", communication_cost_: " << ret->communication_cost_
<< ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
MS_LOG(INFO) << "Cost 0: totoal_cost: " << minimum;
for (size_t i = 1; i < after_mem_filter.size(); ++i) {
MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
<< ", computation_cost_: " << after_mem_filter[i]->computation_cost_
<< ", communication_forward_: " << after_mem_filter[i]->communication_forward_
<< ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_
<< ", communication_cost_: " << after_mem_filter[i]->communication_cost_
<< ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_
<< ".";
auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ +
costmodel_beta_ * after_mem_filter[i]->communication_forward_;
MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp;
if (minimum > tmp) {
minimum = tmp;
ret = after_mem_filter[i];
MS_LOG(INFO) << "Selected: " << i;
}
return init;
};
CostPtr ret = nullptr;
return std::accumulate(after_mem_filter.begin(), after_mem_filter.end(), ret, LocalCompare);
}
return ret;
}
CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory) {
@ -524,12 +580,26 @@ Status CostGraph::SearchStrategy() {
});
if (alive_ops.size() > 2) {
return SearchStrategyForMultiNodeFinalGraph(alive_ops);
if (RUN_PHASE == TRAINING_PHASE) {
// training phase
return SearchStrategyForMultiNodeFinalGraph(alive_ops);
} else {
// inference phase
MS_LOG(EXCEPTION)
<< "Currently, searching strategy for the multi-node final graph in inference phase is not supported.";
}
} else if (alive_ops.size() == 1) {
MS_LOG(INFO) << "There are 1 single node in the final graph.";
OperatorInfoPtr u = alive_ops[0];
auto cost_list = CreateFinalSingleCostList(u);
auto cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
CostPtr cost = nullptr;
if (RUN_PHASE == TRAINING_PHASE) {
// training phase
cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
} else {
// inference phase
cost = SelectCostWithMinInferenceTime(cost_list, dev_memory_);
}
if (cost == nullptr) {
MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
return FAILED;
@ -575,7 +645,15 @@ Status CostGraph::SearchStrategy() {
auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]);
all_list.push_back(cost_list);
}
auto selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_);
CostPtrList selected_cost_list;
if (RUN_PHASE == TRAINING_PHASE) {
// training phase
selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_);
} else {
// inference phase
MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-separated-node final graph in the inference "
"phase is not supported.";
}
for (size_t k = 0; k < selected_cost_list.size(); ++k) {
auto selected_cost = selected_cost_list[k];
if (selected_cost == nullptr) {
@ -601,7 +679,14 @@ Status CostGraph::SearchStrategy() {
auto e = u->GetAliveSuccEdges()[0];
MS_EXCEPTION_IF_NULL(e);
auto cost_list = CreateFinalCostList(u, e, v);
auto cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
CostPtr cost = nullptr;
if (RUN_PHASE == TRAINING_PHASE) {
// training phase
cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
} else {
MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-connected-node final graph in the inference "
"phase is not supported.";
}
if (cost == nullptr) {
MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
return FAILED;
@ -841,6 +926,8 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const
double memory = op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
double communication =
op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
double communication_forward =
op_cost->communication_forward_ + edge_cost->communication_forward_ + tar_cost->communication_forward_;
double communication_without_para = op_cost->communication_without_parameter_ +
edge_cost->communication_without_parameter_ +
tar_cost->communication_without_parameter_;
@ -853,6 +940,7 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const
new_cost->communication_with_partial_para_ =
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
new_cost->memory_with_reuse_ = memory;
new_cost->communication_forward_ = communication_forward;
MS_EXCEPTION_IF_NULL(tar_cost_list_new);
tar_cost_list_new->emplace_back(std::move(new_cost));
}
@ -885,7 +973,7 @@ OperatorInfoPtr CostGraph::EliminationMerge(const OperatorInfoPtr &op) {
CreateMergeEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new);
}
SimplifyForDreasingCommunicationWithPartialPara(&tar_clist_new);
Simplify(&tar_clist_new);
// Set the new costlist w.r.t the strategy
tar_stra_cost->cost_list = tar_clist_new;
if ((!valid) && (!tar_clist_new.empty())) {
@ -922,6 +1010,8 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str
contract_op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
double communication =
contract_op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
double communication_forward = contract_op_cost->communication_forward_ + edge_cost->communication_forward_ +
tar_cost->communication_forward_;
double communication_without_para = contract_op_cost->communication_without_parameter_ +
edge_cost->communication_without_parameter_ +
tar_cost->communication_without_parameter_;
@ -933,6 +1023,7 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str
new_cost->communication_with_partial_para_ =
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
new_cost->memory_with_reuse_ = memory;
new_cost->communication_forward_ = communication_forward;
tar_cost_list_new->emplace_back(std::move(new_cost));
}
}
@ -962,7 +1053,7 @@ OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr &op) {
CreateContractEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new);
}
SimplifyForDreasingCommunicationWithPartialPara(&tar_clist_new);
Simplify(&tar_clist_new);
// Set the new costlist w.r.t the strategy
tar_stra_cost->cost_list = tar_clist_new;
if ((!valid) && (!tar_clist_new.empty())) {
@ -998,6 +1089,8 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra,
left_node_cost->memory_with_reuse_ + right_edge_cost->memory_with_reuse_;
double new_commu_cost = elimi_op_cost->communication_cost_ + left_edge_cost->communication_cost_ +
left_node_cost->communication_cost_ + right_edge_cost->communication_cost_;
double new_commu_forward = elimi_op_cost->communication_forward_ + left_edge_cost->communication_forward_ +
left_node_cost->communication_forward_ + right_edge_cost->communication_forward_;
double new_commu_without =
elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ +
left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_;
@ -1009,6 +1102,7 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra,
new_cost->communication_with_partial_para_ =
new_commu_without + COST_MODEL_GAMMA * (new_commu_cost - new_commu_without);
new_cost->memory_with_reuse_ = new_memory;
new_cost->communication_forward_ = new_commu_forward;
left_node_clist_new->emplace_back(std::move(new_cost));
}
}
@ -1079,7 +1173,7 @@ OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr &elimi_op,
&left_node_clist_new);
}
}
SimplifyForDreasingCommunicationWithPartialPara(&left_node_clist_new);
Simplify(&left_node_clist_new);
// Set the new costlist w.r.t the strategy
left_node_stra_cost->cost_list = left_node_clist_new;
if ((!valid) && (!left_node_clist_new.empty())) {
@ -1112,19 +1206,22 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n
double computation_cost = merged_node_cost->computation_cost_,
memory_cost = merged_node_cost->memory_with_reuse_, commu_cost = merged_node_cost->communication_cost_,
commu_without = merged_node_cost->communication_without_parameter_;
commu_without = merged_node_cost->communication_without_parameter_,
commu_forward = merged_node_cost->communication_forward_;
for (size_t i = 0; i < succ_nodes_stras.size(); ++i) {
MS_EXCEPTION_IF_NULL(succ_edges_costs[i]);
if (i == 0) {
computation_cost += succ_edges_costs[i]->computation_cost_ + succ_nodes_costs[i]->computation_cost_;
memory_cost += succ_edges_costs[i]->memory_with_reuse_ + succ_nodes_costs[i]->memory_with_reuse_;
commu_cost += succ_edges_costs[i]->communication_cost_ + succ_nodes_costs[i]->communication_cost_;
commu_forward += succ_edges_costs[i]->communication_forward_ + succ_nodes_costs[i]->communication_forward_;
commu_without += succ_edges_costs[i]->communication_without_parameter_ +
succ_nodes_costs[i]->communication_without_parameter_;
} else {
computation_cost += succ_edges_costs[i]->computation_cost_;
memory_cost += succ_edges_costs[i]->memory_with_reuse_;
commu_cost += succ_edges_costs[i]->communication_cost_;
commu_forward += succ_edges_costs[i]->communication_forward_;
commu_without += succ_edges_costs[i]->communication_without_parameter_;
}
}
@ -1135,6 +1232,7 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n
new_cost->communication_without_parameter_ = commu_without;
new_cost->communication_with_partial_para_ = commu_without + COST_MODEL_GAMMA * (commu_cost - commu_without);
new_cost->memory_with_reuse_ = memory_cost;
new_cost->communication_forward_ = commu_forward;
first_succ_node_clist_new->emplace_back(std::move(new_cost));
}
}
@ -1220,7 +1318,7 @@ std::vector<std::shared_ptr<Edge>> CostGraph::EliminationStar(const OperatorInfo
CreateStarEliminationCostList(succ_edges, first_succ_node_stra, first_succ_node_clist, first_succ_edge_clist,
merged_op_stra, merged_op_clist, &first_succ_node_clist_new);
}
SimplifyForDreasingCommunicationWithPartialPara(&first_succ_node_clist_new);
Simplify(&first_succ_node_clist_new);
// Set the new costlist w.r.t the strategy
first_succ_node_stra_cost->cost_list = first_succ_node_clist_new;
if ((!valid) && (!first_succ_node_clist_new.empty())) {

View File

@ -45,6 +45,9 @@ namespace parallel {
#define DEFAULT_FULLY_USE_DEVICES true
#define DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW false
#define DEFAULT_IS_MULTI_SUBGRAPHS false
#define DEFAULT_RUN_PHASE 0
#define TRAINING_PHASE 0
#define INFERENCE_PHASE 1
class CostGraph;
using CostGraphPtr = std::shared_ptr<CostGraph>;
@ -60,6 +63,8 @@ extern bool TENSOR_SLICE_ALIGNMENT_ENABLE;
extern size_t TENSOR_SLICE_ALIGNMENT_SIZE;
extern bool FULLY_USE_DEVICES;
extern bool ELEMENTWISE_OP_STRA_FOLLOW;
extern bool MULTI_SUBGRAPHS;
extern int32_t RUN_PHASE;
class CostGraph {
// 'CostGraph' consists of Operators and edges between them. An edge is created between two Operators if they have
@ -98,7 +103,7 @@ class CostGraph {
CostPtrList CreateFinalCostList(const OperatorInfoPtr &u, const EdgePtr &e, const OperatorInfoPtr &v);
CostPtrList CreateFinalSingleCostList(const OperatorInfoPtr &u);
CostPtr SelectCostWithMemoryConstraint(const CostPtrList &cost_list, double memory);
CostPtr SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory);
CostPtr SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory);
CostPtrList SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList> &all_costlist, double memory);
Status SearchStrategyForMultiNodeFinalGraph(const std::vector<OperatorInfoPtr> &);

View File

@ -47,6 +47,7 @@ void CostModelContext::ResetCostModel() {
costmodel_communi_const_ = DEFAULT_COST_MODEL_COMMUNI_CONST;
costmodel_communi_bias_ = DEFAULT_COST_MODEL_COMMUNI_BIAS;
is_multi_subgraphs_ = DEFAULT_IS_MULTI_SUBGRAPHS;
run_phase_ = DEFAULT_RUN_PHASE;
costmodel_allreduce_fusion_algorithm_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM;
costmodel_allreduce_fusion_times_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES;
costmodel_allreduce_fusion_tail_percent_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT;
@ -125,5 +126,7 @@ void CostModelContext::set_fully_use_device(bool fully_use) { fully_use_device_
void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) {
elementwise_stra_follow_ = elementwise_follow;
}
void CostModelContext::set_run_phase(int32_t phase) { run_phase_ = phase; }
} // namespace parallel
} // namespace mindspore

View File

@ -113,6 +113,9 @@ class CostModelContext {
void set_elementwise_stra_follow(bool);
bool elementwise_stra_follow() const { return elementwise_stra_follow_; }
void set_run_phase(int32_t);
int32_t run_phase() const { return run_phase_; }
private:
CostModelContext();
static std::shared_ptr<CostModelContext> cm_context_inst_;
@ -141,8 +144,11 @@ class CostModelContext {
// COST_MODEL_COMMUNI_BIAS
double costmodel_communi_bias_;
// MULTI_SUBGRAPHS
bool is_multi_subgraphs_;
int32_t run_phase_; // 0: 'training', 1: 'inference'
int32_t costmodel_allreduce_fusion_algorithm_;
int32_t costmodel_allreduce_fusion_times_;

View File

@ -610,6 +610,7 @@ Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &
<< ", communication_with_partial_para_: " << result->communication_with_partial_para_;
// refine communication cost calculation for practice
RefineForPracticalCost(result, false);
result->communication_forward_ = result->communication_without_parameter_;
std::shared_ptr<StrategyWithCost> swc =
std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);

View File

@ -1049,6 +1049,7 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) {
BreakingTiesForPerferringDataParallel(strategy, result);
// refine communication cost calculation for practice
RefineForPracticalCost(result, false);
result->communication_forward_ = result->communication_without_parameter_;
std::shared_ptr<StrategyWithCost> swc =
std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);

View File

@ -69,16 +69,16 @@ class TensorRedistribution {
RankList dev_list_;
OperatorList operator_list_;
bool reshape_flag_;
// communication cost
// communication cost, which is the sum of forward communication cost and backward communication cost
double comm_cost_;
// forward communication cost
double forward_comm_cost_;
// backward communication cost
double backward_comm_cost_;
// computation_cost models the time spending on computing in this tensor redistribution, which is calculated by the
// inputs.
// inputs. This is calculated ONLY for forward phase.
double computation_cost_;
// memory_cost models the PEAK memory cost in a traning iteration contributed by this tensor redistribution, which is
// memory_cost models the PEAK memory cost in a training iteration contributed by this tensor redistribution, which is
// calculated by the outputs.
double memory_cost_;
bool construct_op_flag_;

View File

@ -228,6 +228,8 @@ PYBIND11_MODULE(_c_expression, m) {
"Get the parameter cost_model_communi_bias of the DP algorithm.")
.def("set_multi_subgraphs", &CostModelContext::set_multi_subgraphs, "Set the parameter is_multi_subgraphs.")
.def("get_multi_subgraphs", &CostModelContext::is_multi_subgraphs, "Get the parameter is_multi_subgraphs.")
.def("set_run_phase", &CostModelContext::set_run_phase, "Set the flag run_phase.")
.def("get_run_phase", &CostModelContext::run_phase, "Get the flag run_phase.")
.def("set_costmodel_allreduce_fusion_algorithm", &CostModelContext::set_costmodel_allreduce_fusion_algorithm,
"Set the parameter gradient AllReduce fusion algorithm.")
.def("get_costmodel_allreduce_fusion_algorithm", &CostModelContext::costmodel_allreduce_fusion_algorithm,

View File

@ -239,6 +239,33 @@ class _CostModelContext:
raise ValueError("Context handle is none in context!!!")
return self._context_handle.get_multi_subgraphs()
def set_run_phase(self, phase):
"""
Set the flag of running phase: training (0) or inference (1)
Args:
phase (int): A parameter indicating which phase is running.
Raises:
ValueError: If context handle is none, or phase is not in {0, 1}.
"""
if self._context_handle is None:
raise ValueError("Context handle is none in context!!!")
if phase not in (0, 1):
raise ValueError("The argument of set_run_phase() must be '0' or '1', but got {}".format(phase))
self._context_handle.set_run_phase(phase)
def get_run_phase(self):
"""
Get the flag of running phase.
Raises:
ValueError: If context handle is none.
"""
if self._context_handle is None:
raise ValueError("Context handle is none in context!!!")
return self._context_handle.get_run_phase()
def set_costmodel_allreduce_fusion_algorithm(self, algorithm):
"""
Set costmodel allreduce fusion algorithm.
@ -453,6 +480,7 @@ set_cost_model_context_func_map = {
"costmodel_communi_const": cost_model_context().set_costmodel_communi_const,
"costmodel_communi_bias": cost_model_context().set_costmodel_communi_bias,
"multi_subgraphs": cost_model_context().set_multi_subgraphs,
"run_phase": cost_model_context().set_run_phase,
"costmodel_allreduce_fusion_algorithm": cost_model_context().set_costmodel_allreduce_fusion_algorithm,
"costmodel_allreduce_fusion_times": cost_model_context().set_costmodel_allreduce_fusion_times,
"costmodel_allreduce_fusion_tail_percent": cost_model_context().set_costmodel_allreduce_fusion_tail_percent,
@ -473,7 +501,8 @@ get_cost_model_context_func_map = {
"costmodel_communi_threshold": cost_model_context().get_costmodel_communi_threshold,
"costmodel_communi_const": cost_model_context().get_costmodel_communi_const,
"costmodel_communi_bias": cost_model_context().get_costmodel_communi_bias,
"multi_subgraphs": cost_model_context().get_multi_subgraphs(),
"multi_subgraphs": cost_model_context().get_multi_subgraphs,
"run_phase": cost_model_context().get_run_phase,
"costmodel_allreduce_fusion_algorithm": cost_model_context().get_costmodel_allreduce_fusion_algorithm,
"costmodel_allreduce_fusion_times": cost_model_context().get_costmodel_allreduce_fusion_times,
"costmodel_allreduce_fusion_tail_percent": cost_model_context().get_costmodel_allreduce_fusion_tail_percent,
@ -488,7 +517,7 @@ get_cost_model_context_func_map = {
@args_type_check(device_memory_capacity=float, costmodel_alpha=float, costmodel_beta=float, costmodel_gamma=float,
costmodel_communi_threshold=float, costmodel_communi_const=float, costmodel_communi_bias=float,
multi_subgraphs=bool,
multi_subgraphs=bool, run_phase=int,
costmodel_allreduce_fusion_algorithm=int, costmodel_allreduce_fusion_times=int,
costmodel_allreduce_fusion_tail_percent=float, costmodel_allreduce_fusion_tail_time=float,
costmodel_allreduce_fusion_allreduce_inherent_time=float,
@ -510,6 +539,7 @@ def set_cost_model_context(**kwargs):
costmodel_communi_const (float): A parameter used in adjusting communication calculation for practice.
costmodel_communi_bias (float): A parameter used in adjusting communication calculation for practice.
multi_subgraphs (bool): A parameter used in marking the flag of ANF graph containing multiple subgraphs.
run_phase (int): A parameter indicating which phase is running: training (0) or inference (1). Default: 0.
costmodel_allreduce_fusion_algorithm (int): The allreduce fusion algorithm.
0: bypass allreduce fusion;
1: only use backward computation time to group allreduce;

View File

@ -371,7 +371,7 @@ TEST_F(TestCostGraph, test_CreateFinalCostList_AND_Select) {
ASSERT_EQ(edge_m1_m2->InitEdgeCost(), SUCCESS);
cost_graph.AddEdge(matmul1, matmul2, edge_m1_m2);
auto cost_list = cost_graph.CreateFinalCostList(matmul1, edge_m1_m2, matmul2);
cost_graph.SelectCostWithMemoryConstraint(cost_list, cost_graph.GetDeviceMemory());
cost_graph.SelectCostWithMinInferenceTime(cost_list, cost_graph.GetDeviceMemory());
}
TEST_F(TestCostGraph, test_EliminationOp) {

View File

@ -14,15 +14,21 @@
import mindspore.context as context
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.parallel._cost_model_context import reset_cost_model_context
from mindspore.parallel.algo_parameter_config import reset_algo_parameters
from mindspore.parallel._utils import _reset_op_id
def setup_module(module):
auto_parallel_context().set_enable_all_reduce_fusion(enable_all_reduce_fusion=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
reset_cost_model_context()
reset_algo_parameters()
_reset_op_id()
def teardown_module():
context.reset_auto_parallel_context()
reset_cost_model_context()
reset_algo_parameters()
_reset_op_id()

View File

@ -0,0 +1,36 @@
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import operations as P
from mindspore.nn import WithLossCell, TrainOneStepCell
from mindspore.nn import Momentum
from mindspore.parallel._cost_model_context import set_cost_model_context
class Net(nn.Cell):
def __init__(self, input_ch, out_ch):
super(Net, self).__init__()
self.dense = nn.Dense(input_ch, out_ch)
self.relu = P.ReLU()
def construct(self, x):
x = self.dense(x)
x = self.relu(x)
return x
def test_inference_phase():
context.set_auto_parallel_context(device_num=8, global_rank=0)
context.set_auto_parallel_context(parallel_mode="auto_parallel")
set_cost_model_context(run_phase=1)
net = Net(512, 128)
predict = Tensor(np.ones([64, 512]).astype(np.float32) * 0.001)
label = Tensor(np.ones([64, 128]).astype(np.float32))
loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer)
train_network.set_train()
output = train_network(predict, label)