forked from mindspore-Ecosystem/mindspore
!232 [Auto parallel] Model the memory_cost in cost model
Merge pull request !232 from Xiaoda/model-memory-cost-in-auto-parallel
commit 1ab430072e
@@ -207,15 +207,13 @@ struct ContractEliminationDecision : public Decision {
*/
struct TriangleEliminationDecision : public Decision {
TriangleEliminationDecision(StrategyPtr elimi_stra, CostPtr elimi_op_cost, CostPtr l_edge_cost, CostPtr r_edge_cost,
StrategyPtr left_stra, CostPtr l_node_cost, StrategyPtr right_stra, CostPtr r_node_cost)
StrategyPtr left_stra, CostPtr l_node_cost)
: eliminated_op_strategy_(std::move(elimi_stra)),
eliminated_op_cost_(std::move(elimi_op_cost)),
left_edge_cost_(std::move(l_edge_cost)),
right_edge_cost_(std::move(r_edge_cost)),
left_node_strategy_(std::move(left_stra)),
left_node_cost_(std::move(l_node_cost)),
right_node_strategy_(std::move(right_stra)),
right_node_cost_(std::move(r_node_cost)) {
left_node_cost_(std::move(l_node_cost)) {
type_ = DecisionType::TRIANGLE_ELIMINATION;
}

@@ -225,8 +223,6 @@ struct TriangleEliminationDecision : public Decision {
CostPtr right_edge_cost_;
StrategyPtr left_node_strategy_;
CostPtr left_node_cost_;
StrategyPtr right_node_strategy_;
CostPtr right_node_cost_;
MS_DECLARE_PARENT(TriangleEliminationDecision, Decision);
};
@@ -76,7 +76,6 @@ Status GetStrategy(const CostGraphPtr& graph) {
auto l_r_edge = triangle_pair.second;

auto left_node = l_r_edge->prev_operator();
auto right_node = l_r_edge->next_operator();
auto left_edge = eliminated_node->GetAliveSuccEdges()[0];
auto right_edge = eliminated_node->GetAliveSuccEdges()[1];
MS_EXCEPTION_IF_NULL(left_edge);
@@ -86,8 +85,7 @@ Status GetStrategy(const CostGraphPtr& graph) {
right_edge = tmp;
}
auto left_node_cpy = graph->EliminationTriangle(eliminated_node, l_r_edge);
auto elimi =
std::make_shared<TriangleElimination>(eliminated_node, left_edge, left_node_cpy, right_edge, right_node);
auto elimi = std::make_shared<TriangleElimination>(eliminated_node, left_edge, left_node_cpy, right_edge);
eliminations.emplace_back(std::move(elimi));
}
auto star_center = graph->CheckStarElimination();
@@ -183,14 +181,13 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
auto left_edge = elimination->left_edge_;
auto eliminated_node = elimination->eliminated_node_;
auto right_edge = elimination->right_edge_;
auto right_node = elimination->right_node_;
auto decision = left_node->selected_cost()->decision_ptr_->cast<TriangleEliminationDecisionPtr>();

eliminated_node->SetSelectedStrategyAndCost(decision->eliminated_op_strategy_, decision->eliminated_op_cost_);
left_edge->set_selected_cost(decision->left_edge_cost_);
right_edge->set_selected_cost(decision->right_edge_cost_);
// Since Triangle is eliminated into 'left_node', only 'left_node' is needed to recover the strategy.
left_node->SetSelectedStrategyAndCost(decision->left_node_strategy_, decision->left_node_cost_);
right_node->SetSelectedStrategyAndCost(decision->right_node_strategy_, decision->right_node_cost_);
MS_LOG(INFO) << "Recover triangleElimination succeeded.";
} else if ((*rit)->isa<StarElimination>()) {
auto elimination = (*rit)->cast<StarEliminationPtr>();
@@ -204,9 +201,11 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
for (size_t i = 0; i < succ_edges.size(); ++i) {
succ_edges[i]->set_selected_cost(decision->succ_edges_cost_list_[i]);
}
for (size_t j = 0; j < succ_nodes.size(); ++j) {
succ_nodes[j]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[j], decision->succ_ops_cost_list_[j]);
}
MS_EXCEPTION_IF_NULL(succ_nodes[0]);
MS_EXCEPTION_IF_NULL(decision->succ_ops_stra_list_[0]);
MS_EXCEPTION_IF_NULL(decision->succ_ops_cost_list_[0]);
// Since Star is eliminated into 'succ_nodes[0]', only 'succ_nodes[0]' is needed to recover the strategy.
succ_nodes[0]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[0], decision->succ_ops_cost_list_[0]);
MS_LOG(INFO) << "Recover starElimination succeeded.";
} else {
MS_LOG(ERROR) << "Unknown Elimination type.";
@@ -102,20 +102,17 @@ struct ContractElimination : public Elimination {

// Triangle Elimination
struct TriangleElimination : public Elimination {
TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge,
OperatorInfoPtr r_node)
TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge)
: Elimination(nullptr, Elimination::EliminationType::TRIANGLE),
eliminated_node_(std::move(elim_node)),
left_edge_(std::move(l_edge)),
left_node_(std::move(l_node)),
right_edge_(std::move(r_edge)),
right_node_(std::move(r_node)) {}
right_edge_(std::move(r_edge)) {}

OperatorInfoPtr eliminated_node_;
EdgePtr left_edge_;
OperatorInfoPtr left_node_;
EdgePtr right_edge_;
OperatorInfoPtr right_node_;
MS_DECLARE_PARENT(TriangleElimination, Elimination);
};
@@ -119,6 +119,7 @@ Status Edge::GetRedistributionCost(const TensorLayout& prev_op_output_layout, co
double forward_comm_cost = tensor_redistribution.forward_comm_cost();
double backward_comm_cost = tensor_redistribution.backward_comm_cost();
double computation_cost = tensor_redistribution.computation_cost();
double mem_cost = tensor_redistribution.memory_cost();

// Now AllGather, ReduceScatter, AlltoAll don't support bool type
MS_EXCEPTION_IF_NULL(type);
@@ -134,6 +135,7 @@ Status Edge::GetRedistributionCost(const TensorLayout& prev_op_output_layout, co
COST_MODEL_GAMMA * ((*cost)->communication_cost_ - (*cost)->communication_without_parameter_);
(*cost)->communication_redis_forward_ = type_length * forward_comm_cost;
(*cost)->communication_redis_backward_ = type_length * backward_comm_cost;
(*cost)->memory_with_reuse_ = mem_cost;
return Status::SUCCESS;
}

@@ -158,8 +160,8 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr
(void)std::transform(edges.begin(), edges.end(), all_cost_list.begin(), LocalGetCostList);

CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
std::function<void(size_t, double, double, double)> recursive =
[&](size_t k, double computation, double communication, double communication_without_para) {
std::function<void(size_t, double, double, double, double)> recursive =
[&](size_t k, double computation, double memory, double communication, double communication_without_para) {
if (k == edges.size()) {
auto decision = std::make_shared<EdgeEliminationDecision>(selected_cost_list);
CostPtr new_cost = std::make_shared<Cost>(computation, communication);
@@ -167,6 +169,7 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr
new_cost->communication_without_parameter_ = communication_without_para;
new_cost->communication_with_partial_para_ =
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
new_cost->memory_with_reuse_ = memory;
new_cost->decision_ptr_ = decision;
result.push_back(new_cost);
return;
@@ -174,11 +177,12 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr
for (auto& c : all_cost_list[k]) {
MS_EXCEPTION_IF_NULL(c);
selected_cost_list[k] = c;
recursive(k + 1, computation + c->computation_cost_, communication + c->communication_cost_,
recursive(k + 1, computation + c->computation_cost_, memory + c->memory_with_reuse_,
communication + c->communication_cost_,
communication_without_para + c->communication_without_parameter_);
}
};
recursive(0, 0, 0, 0);
recursive(0, 0.0, 0.0, 0.0, 0.0);
SimplifyForDreasingCommunicationWithPartialPara(&result);
return result;
}
@@ -218,6 +222,8 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr
double communication_without_para = left_cost->communication_without_parameter_ +
middle_cost->communication_without_parameter_ +
right_cost->communication_without_parameter_;
double memory_cost =
left_cost->memory_with_reuse_ + middle_cost->memory_with_reuse_ + right_cost->memory_with_reuse_;

auto decision = std::make_shared<OpEliminationDecision>(op_strategy, left_cost, middle_cost, right_cost);
auto cost = std::make_shared<Cost>(computation, communication, decision);
@@ -225,6 +231,7 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr
cost->communication_without_parameter_ = communication_without_para;
cost->communication_with_partial_para_ =
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
cost->memory_with_reuse_ = memory_cost;
ret_cost_list->emplace_back(std::move(cost));
}
}
@@ -267,5 +274,24 @@ void Edge::OpEliminationSetNewCost(const EdgePtr& e1, const OperatorInfoPtr& op,
MS_LOG(EXCEPTION) << "Creating edge: " << edge_name_ << " failed.";
}
}

Status Edge::CalculateMemoryCost() {
if (is_output_parameter_involve_ == -1) {
MS_LOG(ERROR) << "is_output_parameter_involve_ is unset.";
return FAILED;
}
if (is_output_parameter_involve_ == 0) {
// In this case, it is certain that the tensor redistribution along this edge is NOT parameter-involved, thus it is
// unnecessary to keep them in memory.
for (auto& cost_kv : cost_map_) {
auto& cost_v = cost_kv.second;
if (!cost_v.empty()) {
cost_v[0]->memory_with_reuse_ = 0;
}
}
}

return SUCCESS;
}
} // namespace parallel
} // namespace mindspore
@@ -133,7 +133,7 @@ class Edge {
void set_parameter_involve(int para_invol) { is_output_parameter_involve_ = para_invol; }
// When the input of an operator contains WEIGHT or an output from other operators involving WEIGHT, these inputs
// should stay in memory until they are used in the backward phase, i.e., they are kept in memory at the end of the forward phase.
Status CalculateMemoryCost() const { return SUCCESS; }
Status CalculateMemoryCost();

private:
std::string edge_name_;
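The rule that the new Edge::CalculateMemoryCost enforces can be read in isolation as a minimal standalone sketch (SimpleCost and the free function below are hypothetical names, not MindSpore APIs): when the tensor carried by the edge is not parameter-involved, it does not have to stay resident until the backward phase, so its redistribution memory is dropped from the model.

#include <map>
#include <utility>
#include <vector>

// Hypothetical stand-in for the per-strategy-pair cost entries kept on an edge.
struct SimpleCost {
  double memory_with_reuse = 0.0;
};

// -1: unset, 0: not parameter-involved, 1: parameter-involved (mirrors is_output_parameter_involve_).
void ZeroEdgeMemoryIfNotParameterInvolved(int is_output_parameter_involve,
                                          std::map<std::pair<size_t, size_t>, std::vector<SimpleCost>>* cost_map) {
  if (is_output_parameter_involve != 0) {
    return;  // unset or parameter-involved: keep the modeled redistribution memory
  }
  for (auto& kv : *cost_map) {
    if (!kv.second.empty()) {
      kv.second[0].memory_with_reuse = 0.0;  // the redistributed tensor need not be kept for backward
    }
  }
}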
@@ -248,6 +248,7 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr& u, const std::
MS_EXCEPTION_IF_NULL(cost2);
MS_EXCEPTION_IF_NULL(cost3);
double computation = cost1->computation_cost_ + cost2->computation_cost_ + cost3->computation_cost_;
double memory = cost1->memory_with_reuse_ + cost2->memory_with_reuse_ + cost3->memory_with_reuse_;
double commmunication =
cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_;
double communication_without_para = cost1->communication_without_parameter_ +
@@ -260,6 +261,7 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr& u, const std::
cost->communication_without_parameter_ = communication_without_para;
cost->communication_with_partial_para_ =
communication_without_para + COST_MODEL_GAMMA * (commmunication - communication_without_para);
cost->memory_with_reuse_ = memory;
ret.push_back(cost);
}
}
@@ -288,6 +290,7 @@ CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr& u) {
new_cost->communication_with_partial_para_ =
cost1->communication_without_parameter_ +
COST_MODEL_GAMMA * (cost1->communication_cost_ - cost1->communication_without_parameter_);
new_cost->memory_with_reuse_ = cost1->memory_with_reuse_;
ret.push_back(new_cost);
}
}
@@ -297,9 +300,14 @@ CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr& u) {
}

CostPtr CostGraph::SelectCostWithMemoryConstraint(const CostPtrList& cost_list, double memory) {
if (cost_list.empty() || cost_list[0]->computation_cost_ >= memory) {
return nullptr;
CostPtrList after_mem_filter;
// Filter out the valid costs
for (auto& a_cost : cost_list) {
if (a_cost->memory_with_reuse_ <= memory) {
after_mem_filter.emplace_back(std::move(a_cost));
}
}

std::function<CostPtr(CostPtr, const CostPtr&)> LocalCompare = [&](CostPtr init, const CostPtr& cost_x) {
MS_EXCEPTION_IF_NULL(cost_x);
if (init == nullptr || cost_x->computation_cost_ < memory) {
@@ -308,7 +316,7 @@ CostPtr CostGraph::SelectCostWithMemoryConstraint(const CostPtrList& cost_list,
return init;
};
CostPtr ret = nullptr;
return std::accumulate(cost_list.begin(), cost_list.end(), ret, LocalCompare);
return std::accumulate(after_mem_filter.begin(), after_mem_filter.end(), ret, LocalCompare);
}

CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList& cost_list, double memory) {
@@ -318,36 +326,46 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList& cost_list, d
MS_LOG(ERROR) << "Final cost list is null.";
return nullptr;
}
CostPtr ret = cost_list[0];
MS_EXCEPTION_IF_NULL(ret);
if (ret->computation_cost_ >= memory) {
MS_LOG(ERROR) << "No available cost; the minimum cost is " << ret->computation_cost_
CostPtrList after_mem_filter;
double minimum_memory = DBL_MAX;
// Filter out the valid costs.
for (auto& a_cost : cost_list) {
if (a_cost->memory_with_reuse_ <= memory) {
after_mem_filter.emplace_back(std::move(a_cost));
} else if (a_cost->memory_with_reuse_ < minimum_memory) {
minimum_memory = a_cost->memory_with_reuse_;
}
}
if (after_mem_filter.empty()) {
MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory
<< ", the memory capacity is: " << memory << ".";
return nullptr;
}
// Init the returned value with the first cost.
CostPtr ret = after_mem_filter[0];

double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_with_partial_para_;
MS_LOG(INFO) << "minimum: " << minimum << ", computation_cost_: " << ret->computation_cost_
MS_LOG(INFO) << "Cost 0: "
<< "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_
<< ", communication_with_partial_para_: " << ret->communication_with_partial_para_
<< ", communication_cost_: " << ret->communication_cost_
<< ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
for (size_t i = 1; i < cost_list.size(); ++i) {
MS_EXCEPTION_IF_NULL(cost_list[i]);
if (cost_list[i]->computation_cost_ >= memory) {
MS_LOG(INFO) << "cost_list " << i << " computation_cost_: " << cost_list[i]->computation_cost_
<< ", is larger than the memory capacity: " << memory << ".";
break;
}
MS_LOG(INFO) << "cost_list " << i << " computation_cost_: " << cost_list[i]->computation_cost_
<< ", communication_with_partial_para_: " << cost_list[i]->communication_with_partial_para_
<< ", communication_cost_: " << cost_list[i]->communication_cost_
<< ", communication_without_parameter_: " << cost_list[i]->communication_without_parameter_ << ".";
auto tmp = costmodel_alpha_ * cost_list[i]->computation_cost_ +
costmodel_beta_ * cost_list[i]->communication_with_partial_para_;
MS_LOG(INFO) << "tmp: " << tmp;
MS_LOG(INFO) << "Cost 0: total_cost: " << minimum;
for (size_t i = 1; i < after_mem_filter.size(); ++i) {
MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
<< ", computation_cost_: " << after_mem_filter[i]->computation_cost_
<< ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_
<< ", communication_cost_: " << after_mem_filter[i]->communication_cost_
<< ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_
<< ".";
auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ +
costmodel_beta_ * after_mem_filter[i]->communication_with_partial_para_;
MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp;
if (minimum > tmp) {
minimum = tmp;
ret = cost_list[i];
MS_LOG(INFO) << "selected: " << i;
ret = after_mem_filter[i];
MS_LOG(INFO) << "Selected: " << i;
}
}
return ret;
@@ -356,17 +374,21 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList& cost_list, d
CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList>& all_cost_list,
double available_memory) {
CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
double minimum = 0.0, total_memory = 0.0;
double minimum = DBL_MAX, total_memory = 0.0;
CostPtrList ret(all_cost_list.size(), nullptr);
// Check whether valid costs exist.
for (size_t i = 0; i < all_cost_list.size(); ++i) {
if (all_cost_list[i][0] == nullptr) {
MS_LOG(ERROR) << "The cost list " << i << " is empty.";
return ret;
} else {
total_memory += all_cost_list[i][0]->computation_cost_;
minimum += costmodel_alpha_ * all_cost_list[i][0]->computation_cost_ +
costmodel_beta_ * all_cost_list[i][0]->communication_with_partial_para_;
ret[i] = all_cost_list[i][0];
double memory_i_cost = DBL_MAX;
for (size_t j = 0; j < all_cost_list[i].size(); ++j) {
if (all_cost_list[i][j]->memory_with_reuse_ < memory_i_cost) {
memory_i_cost = all_cost_list[i][j]->memory_with_reuse_;
}
}
total_memory += memory_i_cost;
}
}
if (total_memory >= available_memory) {
@@ -381,7 +403,7 @@ CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vect
double tmp_memory = 0.0, tmp_minimum = 0.0;
for (size_t i = 0; i < selected_cost_list.size(); ++i) {
MS_EXCEPTION_IF_NULL(selected_cost_list[i]);
tmp_memory += selected_cost_list[i]->computation_cost_;
tmp_memory += selected_cost_list[i]->memory_with_reuse_;
tmp_minimum += costmodel_alpha_ * selected_cost_list[i]->computation_cost_ +
costmodel_beta_ * selected_cost_list[i]->communication_with_partial_para_;
}
@@ -816,6 +838,7 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const
auto& tar_cost = tar_cost_list[k];
MS_EXCEPTION_IF_NULL(tar_cost);
double computation = op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_;
double memory = op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
double communication =
op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
double communication_without_para = op_cost->communication_without_parameter_ +
@@ -829,6 +852,7 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const
new_cost->communication_without_parameter_ = communication_without_para;
new_cost->communication_with_partial_para_ =
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
new_cost->memory_with_reuse_ = memory;
MS_EXCEPTION_IF_NULL(tar_cost_list_new);
tar_cost_list_new->emplace_back(std::move(new_cost));
}
@@ -894,6 +918,8 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str
MS_EXCEPTION_IF_NULL(tar_cost);
double computation =
contract_op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_;
double memory =
contract_op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
double communication =
contract_op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
double communication_without_para = contract_op_cost->communication_without_parameter_ +
@@ -906,6 +932,7 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str
new_cost->communication_without_parameter_ = communication_without_para;
new_cost->communication_with_partial_para_ =
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
new_cost->memory_with_reuse_ = memory;
tar_cost_list_new->emplace_back(std::move(new_cost));
}
}
@@ -966,23 +993,22 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra,
for (auto& left_node_cost : left_node_clist_origin) {
MS_EXCEPTION_IF_NULL(left_node_cost);
double new_computation = elimi_op_cost->computation_cost_ + left_edge_cost->computation_cost_ +
left_node_cost->computation_cost_ + right_edge_cost->computation_cost_ +
right_op_cost->computation_cost_;
left_node_cost->computation_cost_ + right_edge_cost->computation_cost_;
double new_memory = elimi_op_cost->memory_with_reuse_ + left_edge_cost->memory_with_reuse_ +
left_node_cost->memory_with_reuse_ + right_edge_cost->memory_with_reuse_;
double new_commu_cost = elimi_op_cost->communication_cost_ + left_edge_cost->communication_cost_ +
left_node_cost->communication_cost_ + right_edge_cost->communication_cost_ +
right_op_cost->communication_cost_;
left_node_cost->communication_cost_ + right_edge_cost->communication_cost_;
double new_commu_without =
elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ +
left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_ +
right_op_cost->communication_without_parameter_;
left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_;

auto decision =
std::make_shared<TriangleEliminationDecision>(elimi_op_stra, elimi_op_cost, left_edge_cost, right_edge_cost,
left_op_stra, left_node_cost, right_op_stra, right_op_cost);
auto decision = std::make_shared<TriangleEliminationDecision>(elimi_op_stra, elimi_op_cost, left_edge_cost,
right_edge_cost, left_op_stra, left_node_cost);
auto new_cost = std::make_shared<Cost>(new_computation, new_commu_cost, decision);
new_cost->communication_without_parameter_ = new_commu_without;
new_cost->communication_with_partial_para_ =
new_commu_without + COST_MODEL_GAMMA * (new_commu_cost - new_commu_without);
new_cost->memory_with_reuse_ = new_memory;
left_node_clist_new->emplace_back(std::move(new_cost));
}
}
@@ -1085,14 +1111,22 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr& first_succ_n
succ_nodes_costs[0] = first_succ_node_cost;

double computation_cost = merged_node_cost->computation_cost_,
commu_cost = merged_node_cost->communication_cost_,
memory_cost = merged_node_cost->memory_with_reuse_, commu_cost = merged_node_cost->communication_cost_,
commu_without = merged_node_cost->communication_without_parameter_;
for (size_t i = 0; i < succ_nodes_stras.size(); ++i) {
MS_EXCEPTION_IF_NULL(succ_edges_costs[i]);
computation_cost += succ_edges_costs[i]->computation_cost_ + succ_nodes_costs[i]->computation_cost_;
commu_cost += succ_edges_costs[i]->communication_cost_ + succ_nodes_costs[i]->communication_cost_;
commu_without += succ_edges_costs[i]->communication_without_parameter_ +
succ_nodes_costs[i]->communication_without_parameter_;
if (i == 0) {
computation_cost += succ_edges_costs[i]->computation_cost_ + succ_nodes_costs[i]->computation_cost_;
memory_cost += succ_edges_costs[i]->memory_with_reuse_ + succ_nodes_costs[i]->memory_with_reuse_;
commu_cost += succ_edges_costs[i]->communication_cost_ + succ_nodes_costs[i]->communication_cost_;
commu_without += succ_edges_costs[i]->communication_without_parameter_ +
succ_nodes_costs[i]->communication_without_parameter_;
} else {
computation_cost += succ_edges_costs[i]->computation_cost_;
memory_cost += succ_edges_costs[i]->memory_with_reuse_;
commu_cost += succ_edges_costs[i]->communication_cost_;
commu_without += succ_edges_costs[i]->communication_without_parameter_;
}
}

auto decision = std::make_shared<StarEliminationDecision>(merged_op_stra, merged_node_cost, succ_edges_costs,
@@ -1100,6 +1134,7 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr& first_succ_n
auto new_cost = std::make_shared<Cost>(computation_cost, commu_cost, decision);
new_cost->communication_without_parameter_ = commu_without;
new_cost->communication_with_partial_para_ = commu_without + COST_MODEL_GAMMA * (commu_cost - commu_without);
new_cost->memory_with_reuse_ = memory_cost;
first_succ_node_clist_new->emplace_back(std::move(new_cost));
}
}
@@ -1259,5 +1294,35 @@ OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(std::string& p_name) c
}
return nullptr;
}
Status CostGraph::CorrectOpsMemoryCost() {
for (auto& one_op : ops_) {
if ((one_op->name().find(IDENTITY_INFO) != std::string::npos) && (one_op->is_output_parameter_involve() == 1)) {
if (one_op->GetAliveSuccEdges().size() > 1) {
// Filter out the case where the TmpIdentity is being used by multiple operators
std::map<size_t, int> output_count;
for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) {
auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index();
output_count[output_index]++;
}
for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) {
auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index();
if (output_count[output_index] <= 1) {
continue;
}
auto next_op = one_op->GetAliveSuccEdges()[i]->next_operator();
MS_EXCEPTION_IF_NULL(next_op);
auto input_index = one_op->GetAliveSuccEdges()[i]->next_op_input_index();
if (next_op->CorrectMemoryCost(input_index) != SUCCESS) {
MS_LOG(ERROR) << "The operator name: " << one_op->name() << ", the next operator name: " << next_op->name()
<< ", the output_index: " << output_index << ", the input_index: " << input_index << ".";
return FAILED;
}
output_count[output_index]--;
}
}
}
}
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore
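The memory-aware selection that SelectCostWithMinTrainingTime now performs can be summarized by a minimal standalone sketch (CandidateCost, alpha, and beta below are illustrative names standing in for Cost, costmodel_alpha_, and costmodel_beta_): first drop every candidate whose memory_with_reuse_ exceeds the device memory capacity, then pick the survivor with the smallest alpha * computation + beta * communication_with_partial_para.

#include <cfloat>
#include <cstddef>
#include <vector>

// Illustrative stand-in for the fields of Cost that the selection looks at.
struct CandidateCost {
  double computation = 0.0;
  double communication_with_partial_para = 0.0;
  double memory_with_reuse = 0.0;
};

// Returns the index of the chosen candidate, or -1 when nothing fits in memory
// (the "No available cost" error path above).
int SelectWithMemoryLimit(const std::vector<CandidateCost>& costs, double memory_capacity,
                          double alpha, double beta) {
  int best = -1;
  double best_value = DBL_MAX;
  for (size_t i = 0; i < costs.size(); ++i) {
    if (costs[i].memory_with_reuse > memory_capacity) {
      continue;  // corresponds to being dropped by after_mem_filter
    }
    double value = alpha * costs[i].computation + beta * costs[i].communication_with_partial_para;
    if (value < best_value) {
      best_value = value;
      best = static_cast<int>(i);
    }
  }
  return best;
}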
@@ -187,6 +187,9 @@ class CostGraph {
size_t GetNumPairs() const { return edges_.size(); }
Status InitSelectedStrategy();
OperatorInfoPtr FindTmpIdentityByParameterName(std::string&) const;
// When TmpIdentity is used by multiple operators, the corresponding parameter's memory cost should be calculated only
// once (instead of multiple times); this method is used to correct this.
Status CorrectOpsMemoryCost();
// Needed by rec_parser
void add_inputs_tensor_name(const std::vector<std::string>& inputs_tensor_name) {
inputs_tensor_name_list_.push_back(inputs_tensor_name);
@@ -17,6 +17,7 @@
#include "parallel/auto_parallel/operator_costmodel.h"

#include <random>
#include <algorithm>
#include "parallel/device_matrix.h"
#include "parallel/tensor_layout/tensor_redistribution.h"

@@ -24,12 +25,44 @@ namespace mindspore {
namespace parallel {
void OperatorCost::set_is_parameter(const std::vector<bool>& is_parameter) { is_parameter_ = is_parameter; }

void OperatorCost::set_is_parameter_involve(const std::vector<bool>& is_parameter_inv) {
is_parameter_involve_ = is_parameter_inv;
}

void OperatorCost::set_output_parameter_involve(int output_para) { output_parameter_involve_ = output_para; }

void OperatorCost::SetInputAndOutputTypeLength(const std::vector<size_t>& input_lengths,
const std::vector<size_t>& output_lengths) {
inputs_type_lengths_ = input_lengths;
outputs_type_lengths_ = output_lengths;
}

double OperatorCost::GetMemoryCost(const std::vector<TensorInfo>& inputs,
const std::vector<TensorInfo>& outputs) const {
double result = 0.0;
if (output_parameter_involve_ == 1) {
// When this operator has multiple outputs, they all contribute to the memory.
for (size_t i = 0; i < outputs.size(); ++i) {
result += ListProduct(outputs[i].slice_shape()) * static_cast<double>(outputs_type_lengths_[i]);
}
bool is_any_para_inv =
std::any_of(is_parameter_involve_.begin(), is_parameter_involve_.end(), [](bool value) { return value; });
if (is_any_para_inv) {
for (size_t i = 0; i < inputs.size(); ++i) {
if (is_parameter_[i]) {
result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
} else if (inputs_related_ && (!is_parameter_involve_[i])) {
// When the inputs of this operator are related, and they are not parameter-involved, then they are included
// in the memory cost.
result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
}
}
}
}

return result;
}

// return the per device communication cost in the forward phase.
double MatMulCost::GetForwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
const int32_t&) const {
@@ -72,11 +105,11 @@ double MatMulCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs, co
return result;
}

// Return the per device memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double MatMulCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs,
const std::vector<TensorInfo>& outputs, const int32_t&) const {
// In forward phase, the memory cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C))
// In forward phase, the computation cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C))
double result = 0.0;
TensorInfo output0 = outputs[0];
Shape input0_slice_shape = inputs[0].slice_shape();
@@ -91,11 +124,11 @@ double MatMulCost::GetForwardComputationCost(const std::vector<TensorInfo>& inpu
return result;
}

// Return the per device memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double MatMulCost::GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
const int32_t& stage_id) const {
// In backward phase, the memory cost = (0 or 1) allreduce(slice(B))
// In backward phase, the computation cost = (0 or 1) allreduce(slice(B))
double result = 0.0;
if (is_parameter_[1]) {
TensorInfo input1 = inputs[1]; // tensor B
@@ -145,7 +178,7 @@ double ActivationCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs
return result;
}

// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double ActivationCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
const int32_t&) const {
@@ -154,7 +187,7 @@ double ActivationCost::GetForwardComputationCost(const std::vector<TensorInfo>&
return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
}

// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double ActivationCost::GetBackwardComputationCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&,
const int32_t&) const {
@@ -189,17 +222,17 @@ double SoftmaxCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs, c
return result;
}

// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double SoftmaxCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
const int32_t&) const {
// In the forward phase, the memory cost = slice(A)
// In the forward phase, the computation cost = slice(A)
TensorInfo input0 = inputs[0];
Shape input0_slice_shape = input0.slice_shape();
return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
}

// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double SoftmaxCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>&,
const std::vector<mindspore::parallel::TensorInfo>&,
@@ -221,17 +254,15 @@ double TmpIdentityCost::GetBackwardCommCost(const std::vector<mindspore::paralle
return 0.0;
}

// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double TmpIdentityCost::GetForwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>& inputs,
double TmpIdentityCost::GetForwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>&,
const std::vector<mindspore::parallel::TensorInfo>&,
const int32_t&) const {
TensorInfo input0_info = inputs[0];
Shape input0_slice_shape = input0_info.slice_shape();
return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
return 0.0;
}

// Return the per memory cost in the backward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double TmpIdentityCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>&,
const std::vector<mindspore::parallel::TensorInfo>&,
@@ -239,6 +270,11 @@ double TmpIdentityCost::GetBackwardComputationCost(const std::vector<mindspore::
return 0.0;
}

// Return the per device PEAK memory cost contributed by this operator in a training iteration.
double TmpIdentityCost::GetMemoryCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&) const {
return 0.0;
}

double BatchParallelCost::GetForwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>& inputs,
const std::vector<mindspore::parallel::TensorInfo>&,
const int32_t&) const {
@@ -284,11 +320,11 @@ double PReLUCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs, con
return result;
}

// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double PReLUCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
const int32_t&) const {
// In forward phase, the memory cost = slice(A) + slice(B)
// In forward phase, the computation cost = slice(A) + slice(B)
Shape input0_slice_shape = inputs[0].slice_shape();
Shape input1_slice_shape = inputs[1].slice_shape();
double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
@@ -296,12 +332,12 @@ double PReLUCost::GetForwardComputationCost(const std::vector<TensorInfo>& input
return result;
}

// Return the per memory cost in the backward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double PReLUCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>& inputs,
const std::vector<mindspore::parallel::TensorInfo>&,
const int32_t& stage_id) const {
// In backward phase, the memory cost = (0 or 1) allreduce(slice(B))
// In backward phase, the computation cost = (0 or 1) allreduce(slice(B))
double result = 0.0;
if (is_parameter_[1]) {
TensorInfo input1 = inputs[1]; // tensor B
@@ -337,16 +373,16 @@ double OneHotCost::GetBackwardCommCost(const std::vector<TensorInfo>&, const std
return 0.0;
}

// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double OneHotCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
const int32_t&) const {
// In onehot's forward phase, the memory cost = slice(A)
// In onehot's forward phase, the computation cost = slice(A)
Shape input0_slice_shape = inputs[0].slice_shape();
return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
}

// Return the per memory cost in the backward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double OneHotCost::GetBackwardComputationCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&,
const int32_t&) const {
@@ -367,12 +403,12 @@ double SoftmaxCrossEntropyWithLogitsCost::GetBackwardCommCost(const std::vector<
return 0.0;
}

// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs,
const std::vector<TensorInfo>&,
const int32_t&) const {
// In forward phase, the memory cost = slice(A) + slice(B)
// In forward phase, the computation cost = slice(A) + slice(B)
Shape input0_slice_shape = inputs[0].slice_shape();
Shape input1_slice_shape = inputs[1].slice_shape();
double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
@@ -380,7 +416,7 @@ double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::v
return result;
}

// Return the per memory cost in the backward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double SoftmaxCrossEntropyWithLogitsCost::GetBackwardComputationCost(const std::vector<TensorInfo>&,
const std::vector<TensorInfo>&,
@@ -410,7 +446,7 @@ double ReshapeCost::GetBackwardCommCost(const std::vector<TensorInfo>&, const st
return 0.0;
}

// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double ReshapeCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs,
const std::vector<TensorInfo>& outputs, const int32_t& stage_id) const {
@@ -427,7 +463,7 @@ double ReshapeCost::GetForwardComputationCost(const std::vector<TensorInfo>& inp
return (inputs_type_lengths_[0] * tensor_redistribution.computation_cost());
}

// Return the per memory cost in the backward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double ReshapeCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>&,
const std::vector<mindspore::parallel::TensorInfo>&,
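The peak-memory rule implemented by OperatorCost::GetMemoryCost above can also be written as a small standalone sketch, with plain byte counts standing in for TensorInfo slices and ListProduct (all names below are illustrative): outputs are counted when the operator's output is parameter-involved, and inputs are added when they are parameters, or when the inputs are related and not themselves parameter-involved.

#include <algorithm>
#include <cstddef>
#include <vector>

// Sketch of the accumulation rule; byte sizes replace ListProduct(slice_shape) * type_length.
double PeakMemoryBytes(const std::vector<double>& output_slice_bytes,
                       const std::vector<double>& input_slice_bytes,
                       const std::vector<bool>& is_parameter,
                       const std::vector<bool>& is_parameter_involve,
                       bool output_parameter_involve, bool inputs_related) {
  double result = 0.0;
  if (!output_parameter_involve) {
    return result;  // outputs that are not parameter-involved need not stay resident
  }
  for (double bytes : output_slice_bytes) {
    result += bytes;  // every output contributes to the peak
  }
  bool any_involved = std::any_of(is_parameter_involve.begin(), is_parameter_involve.end(),
                                  [](bool v) { return v; });
  if (any_involved) {
    for (size_t i = 0; i < input_slice_bytes.size(); ++i) {
      if (is_parameter[i] || (inputs_related && !is_parameter_involve[i])) {
        result += input_slice_bytes[i];  // inputs kept in memory for the backward phase
      }
    }
  }
  return result;
}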
@@ -43,10 +43,20 @@ double ListProduct(std::vector<T> vec) {
// entries timing the length of each entry's data type
class OperatorCost {
public:
OperatorCost() {
explicit OperatorCost(bool is_inputs_related) : inputs_related_(is_inputs_related) {
// this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked
for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) {
is_parameter_.push_back(false);
is_parameter_involve_.push_back(false);
inputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
outputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
}
}
OperatorCost() : inputs_related_(false) {
// this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked
for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) {
is_parameter_.push_back(false);
is_parameter_involve_.push_back(false);
inputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
outputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
}
@@ -54,6 +64,8 @@ class OperatorCost {
virtual ~OperatorCost() = default;

void set_is_parameter(const std::vector<bool>& is_parameter);
void set_is_parameter_involve(const std::vector<bool>&);
void set_output_parameter_involve(int);
void SetInputAndOutputTypeLength(const std::vector<size_t>& input_lengths, const std::vector<size_t>& output_lengths);
std::vector<size_t> inputs_type_lengths() const { return inputs_type_lengths_; }
std::vector<size_t> outputs_type_lengths() const { return outputs_type_lengths_; }
@@ -72,8 +84,19 @@ class OperatorCost {
const std::vector<TensorInfo>& outputs, const int32_t& stage_id) const = 0;
virtual double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs,
const std::vector<TensorInfo>& outputs, const int32_t& stage_id) const = 0;
// per device PEAK memory cost in a training iteration
// Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-involved),
// plus necessary inputs.
virtual double GetMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs) const;

protected:
// For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or an output of
// a pre-operator that has parameters as input.
std::vector<bool> is_parameter_involve_;
int output_parameter_involve_ = -1; // -1: unset; 0: not parameter_involved; 1: parameter_involved
// Whether the inputs are related or not. For example, TensorAdd's two inputs are independent (not related), while
// Mul's two inputs are dependent (related).
bool inputs_related_;
// for each input in 'inputs_', there is a bool variable indicating whether the corresponding input is a parameter
std::vector<bool> is_parameter_;
// for each input and output, the followings record the number of bytes of each element
@@ -85,7 +108,8 @@ using OperatorCostPtr = std::shared_ptr<OperatorCost>;

class MatMulCost : public OperatorCost {
public:
MatMulCost() = default;
explicit MatMulCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
MatMulCost() : OperatorCost(true) {}
~MatMulCost() override = default;

// per device communication cost
@@ -108,12 +132,12 @@ class MatMulCost : public OperatorCost {
double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
const int32_t& stage_id) const override;
};

using MatMulCostPtr = std::shared_ptr<MatMulCost>;

class ActivationCost : public OperatorCost {
public:
ActivationCost() = default;
explicit ActivationCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
ActivationCost() : OperatorCost(false) {}
~ActivationCost() override = default;

double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
@@ -133,14 +157,14 @@ class ActivationCost : public OperatorCost {
double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
const int32_t& stage_id) const override;
};

using ActivationCostPtr = std::shared_ptr<ActivationCost>;
using TransposeCost = ActivationCost;
using TransposeCostPtr = std::shared_ptr<TransposeCost>;

class SoftmaxCost : public OperatorCost {
public:
SoftmaxCost() = default;
explicit SoftmaxCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
SoftmaxCost() : OperatorCost(false) {}
~SoftmaxCost() override = default;

double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
@@ -160,12 +184,12 @@ class SoftmaxCost : public OperatorCost {
double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
const int32_t&) const override;
};

using SoftmaxCostPtr = std::shared_ptr<SoftmaxCost>;

class TmpIdentityCost : public OperatorCost {
public:
TmpIdentityCost() = default;
explicit TmpIdentityCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
TmpIdentityCost() : OperatorCost(false) {}
~TmpIdentityCost() override = default;

double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
@@ -184,12 +208,15 @@ class TmpIdentityCost : public OperatorCost {
const int32_t& stage_id) const override;
double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
const int32_t& stage_id) const override;
// per device PEAK memory cost in a training iteration
double GetMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs) const override;
};
using TmpIdentityCostPtr = std::shared_ptr<TmpIdentityCost>;

class BatchParallelCost : public OperatorCost {
public:
BatchParallelCost() = default;
explicit BatchParallelCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
BatchParallelCost() : OperatorCost(false) {}
~BatchParallelCost() override = default;

double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
@@ -217,7 +244,8 @@ using BatchParallelCostPtr = std::shared_ptr<BatchParallelCost>;

class VirtualDatasetCost : public OperatorCost {
public:
VirtualDatasetCost() = default;
explicit VirtualDatasetCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
VirtualDatasetCost() : OperatorCost(false) {}
~VirtualDatasetCost() override = default;

double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
@@ -244,12 +272,17 @@ class VirtualDatasetCost : public OperatorCost {
const int32_t&) const override {
return 0.0;
}
// per device PEAK memory cost in a training iteration
double GetMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs) const override {
return 0.0;
}
};
using VirtualDatasetCostPtr = std::shared_ptr<VirtualDatasetCost>;

class GeneratorBaseCost : public OperatorCost {
public:
GeneratorBaseCost() = default;
explicit GeneratorBaseCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
GeneratorBaseCost() : OperatorCost(false) {}
~GeneratorBaseCost() override = default;

double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
@@ -283,7 +316,8 @@ using GeneratorBaseCostPtr = std::shared_ptr<GeneratorBaseCost>;

class PReLUCost : public OperatorCost {
public:
PReLUCost() = default;
explicit PReLUCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
PReLUCost() : OperatorCost(true) {}
~PReLUCost() override = default;

// per device communication cost
@@ -310,7 +344,8 @@ using PReLUCostPtr = std::shared_ptr<PReLUCost>;

class OneHotCost : public OperatorCost {
public:
OneHotCost() = default;
explicit OneHotCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
OneHotCost() : OperatorCost(true) {}
~OneHotCost() override = default;

// per device communication cost
@@ -337,7 +372,8 @@ using OneHotCostPtr = std::shared_ptr<OneHotCost>;

class SoftmaxCrossEntropyWithLogitsCost : public OperatorCost {
public:
SoftmaxCrossEntropyWithLogitsCost() = default;
explicit SoftmaxCrossEntropyWithLogitsCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
SoftmaxCrossEntropyWithLogitsCost() : OperatorCost(false) {}
~SoftmaxCrossEntropyWithLogitsCost() override = default;

// per device communication cost
@@ -364,7 +400,8 @@ using SoftmaxCrossEntropyWithLogitsCostPtr = std::shared_ptr<SoftmaxCrossEntropy

class ReshapeCost : public OperatorCost {
public:
ReshapeCost() = default;
explicit ReshapeCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
ReshapeCost() : OperatorCost(true) {}

~ReshapeCost() override = default;

@@ -396,7 +433,8 @@ using ReshapeCostPtr = std::shared_ptr<ReshapeCost>;

class ArithmeticCost : public OperatorCost {
public:
ArithmeticCost() = default;
explicit ArithmeticCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
ArithmeticCost() : OperatorCost(false) {}
~ArithmeticCost() override = default;

double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
@@ -425,7 +463,8 @@ using BiasAddCostPtr = std::shared_ptr<BiasAddCost>;

class ReduceMethodCost : public OperatorCost {
public:
ReduceMethodCost() = default;
explicit ReduceMethodCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
ReduceMethodCost() : OperatorCost(true) {}
~ReduceMethodCost() override = default;

double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
@@ -455,7 +494,8 @@ using ReduceMethodCostPtr = std::shared_ptr<ReduceMethodCost>;

class ReduceMeanCost : public ReduceMethodCost {
public:
ReduceMeanCost() = default;
explicit ReduceMeanCost(bool is_inputs_related) : ReduceMethodCost(is_inputs_related) {}
ReduceMeanCost() : ReduceMethodCost(true) {}
~ReduceMeanCost() override = default;

double GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
@@ -465,7 +505,8 @@ using ReduceMeanCostPtr = std::shared_ptr<ReduceMeanCost>;

class GetNextCost : public OperatorCost {
public:
GetNextCost() = default;
explicit GetNextCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
GetNextCost() : OperatorCost(false) {}
~GetNextCost() override = default;

double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
@@ -499,7 +540,8 @@ using GetNextCostPtr = std::shared_ptr<GetNextCost>;

class DropOutCost : public OperatorCost {
public:
DropOutCost() = default;
explicit DropOutCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
DropOutCost() : OperatorCost(true) {}
~DropOutCost() override = default;

double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
@@ -530,7 +572,8 @@ using DropOutCostPtr = std::shared_ptr<DropOutCost>;

class GatherV2Cost : public OperatorCost {
public:
GatherV2Cost() = default;
explicit GatherV2Cost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
GatherV2Cost() : OperatorCost(true) {}
~GatherV2Cost() override = default;

double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
@@ -51,7 +51,7 @@ class Activation : public ActivationBase {
public:
Activation(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationCost>()) {}
: ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationCost>(false)) {}
~Activation() override = default;
Status GenerateStrategies(int32_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
@@ -102,7 +102,7 @@ class Softmax : public ActivationBase {
public:
explicit Softmax(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCost>()) {}
: ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCost>(false)) {}
~Softmax() override = default;
Status GenerateStrategies(int32_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
@ -32,8 +32,8 @@ namespace parallel {
class ArithmeticBase : public OperatorInfo {
public:
ArithmeticBase(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>()) {}
const PrimitiveAttrs& attrs, OperatorCostPtr cost)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {}
~ArithmeticBase() override = default;
Status Init(const StrategyPtr& strategy) override;
Status InitForCostModel(const StrategyPtr& strategy) override;

@ -56,7 +56,7 @@ class ArithmeticBase : public OperatorInfo {
class SubInfo : public ArithmeticBase {
public:
SubInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~SubInfo() override = default;
};

@ -64,21 +64,21 @@ class TensorAddInfo : public ArithmeticBase {
public:
TensorAddInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~TensorAddInfo() override = default;
};

class MulInfo : public ArithmeticBase {
public:
MulInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~MulInfo() override = default;
};

class DivInfo : public ArithmeticBase {
public:
DivInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~DivInfo() override = default;
};

@ -86,7 +86,7 @@ class RealDivInfo : public ArithmeticBase {
public:
RealDivInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~RealDivInfo() override = default;
};

@ -94,14 +94,14 @@ class FloorDivInfo : public ArithmeticBase {
public:
FloorDivInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~FloorDivInfo() override = default;
};

class PowInfo : public ArithmeticBase {
public:
PowInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~PowInfo() override = default;
};

@ -109,7 +109,7 @@ class GreaterInfo : public ArithmeticBase {
public:
GreaterInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~GreaterInfo() override = default;
};

@ -117,7 +117,7 @@ class AssignSubInfo : public ArithmeticBase {
public:
AssignSubInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~AssignSubInfo() override = default;
};
} // namespace parallel
@ -29,9 +29,13 @@ namespace mindspore {
namespace parallel {
class BatchParallelInfo : public OperatorInfo {
public:
BatchParallelInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs, OperatorCostPtr cost)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost), dev_num_(1) {}
BatchParallelInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()), dev_num_(1) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>(false)),
dev_num_(1) {}

~BatchParallelInfo() override = default;
Status Init(const StrategyPtr& strategy) override;

@ -58,7 +62,7 @@ class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo {
public:
SparseSoftmaxCrossEntropyWithLogitsInfo(const std::string& name, const Shapes& inputs_shape,
const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
: BatchParallelInfo(name, inputs_shape, outputs_shape, attrs) {}
: BatchParallelInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>(true)) {}
~SparseSoftmaxCrossEntropyWithLogitsInfo() override = default;
void ReComputeBatchSplitFlagList() override;
};
@ -34,7 +34,7 @@ class BiasAddInfo : public OperatorInfo {
public:
BiasAddInfo(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BiasAddCost>()) {}
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BiasAddCost>(false)) {}
~BiasAddInfo() override = default;

Status Init(const StrategyPtr& strategy) override;
@ -18,6 +18,7 @@
#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_COMPARISON_FUNCTION_INFO_H_

#include <string>
#include <memory>
#include <unordered_map>
#include <vector>
#include "ir/value.h"

@ -31,7 +32,7 @@ class EqualInfo : public ArithmeticBase {
public:
EqualInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~EqualInfo() override = default;
};

@ -39,7 +40,7 @@ class NotEqualInfo : public ArithmeticBase {
public:
NotEqualInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~NotEqualInfo() override = default;
};

@ -47,7 +48,7 @@ class MaximumInfo : public ArithmeticBase {
public:
MaximumInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~MaximumInfo() override = default;
};

@ -55,7 +56,7 @@ class MinimumInfo : public ArithmeticBase {
public:
MinimumInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~MinimumInfo() override = default;
};
} // namespace parallel
@ -33,7 +33,7 @@ class DropoutDoMaskInfo : public OperatorInfo {
public:
DropoutDoMaskInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<DropOutCost>()) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<DropOutCost>(true)) {}
~DropoutDoMaskInfo() override = default;

Status Init(const StrategyPtr& strategy) override;
@ -32,7 +32,7 @@ class GetNextInfo : public OperatorInfo {
public:
GetNextInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GetNextCost>()) {}
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GetNextCost>(false)) {}
~GetNextInfo() override = default;

Status Init(const StrategyPtr &strategy) override;
@ -36,7 +36,8 @@ class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo {
public:
SoftmaxCrossEntropyWithLogitsInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCrossEntropyWithLogitsCost>()) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs,
std::make_shared<SoftmaxCrossEntropyWithLogitsCost>(false)) {}
~SoftmaxCrossEntropyWithLogitsInfo() override = default;
Status Init(const StrategyPtr& strategy) override;
Status InitForCostModel(const StrategyPtr& strategy) override;
@ -593,11 +593,11 @@ Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr&
// Here, we use the original outputs_, because we only use the slice size of the output tensor.
// It does not matter whether the output tensor is transposed or not.
double computation_cost =
cost()->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
double communication_cost = cost()->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
operator_cost()->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
double communication_cost = operator_cost()->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
result->communication_without_parameter_ =
cost()->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
operator_cost()->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
result->communication_with_partial_para_ =
result->communication_without_parameter_ +
COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
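A minimal numeric sketch of the blending formula used above: communication_with_partial_para interpolates between the non-parameter communication cost and the full communication cost. The value of COST_MODEL_GAMMA is configured elsewhere in the cost model, so the 0.1 below is only an assumed placeholder.

// Hedged sketch; the numbers are made up for illustration.
#include <iostream>

int main() {
  const double kGamma = 0.1;                       // assumed value of COST_MODEL_GAMMA
  double communication_cost = 1024.0;              // total communication cost of a strategy
  double communication_without_parameter = 768.0;  // traffic not attributable to parameters
  // gamma = 0 counts only non-parameter traffic, gamma = 1 counts everything.
  double communication_with_partial_para =
      communication_without_parameter + kGamma * (communication_cost - communication_without_parameter);
  std::cout << communication_with_partial_para << std::endl;  // prints 793.6
  return 0;
}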
@ -34,7 +34,7 @@ class MatMulBase : public OperatorInfo {
public:
MatMulBase(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatMulCost>()) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatMulCost>(true)) {}
~MatMulBase() override = default;

Status Init(const StrategyPtr& strategy) override;
@ -33,7 +33,7 @@ class OneHotInfo : public OperatorInfo {
public:
OneHotInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<OneHotCost>()) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<OneHotCost>(false)) {}
~OneHotInfo() override = default;
Status Init(const StrategyPtr& strategy) override;
Status InitForCostModel(const StrategyPtr& strategy) override;
@ -1035,11 +1035,12 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr& strategy) {
return FAILED;
}
int32_t stage_id = strategy->GetInputStage();
double computation_cost = cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
double communication_cost = cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
double computation_cost =
operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
result->communication_without_parameter_ =
cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
result->communication_with_partial_para_ =
result->communication_without_parameter_ +
COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
@ -1096,7 +1097,38 @@ Status OperatorInfo::set_is_parameter(const std::vector<bool>& is_parameter) {
return FAILED;
}
is_parameter_ = is_parameter;
cost()->set_is_parameter(is_parameter);
operator_cost()->set_is_parameter(is_parameter);
return SUCCESS;
}

Status OperatorInfo::CalculateMemoryCost() {
// First, pass 'is_parameter_involve_' and 'is_output_parameter_involve_' to OperatorCost; both are needed to
// calculate the memory cost.
if (is_parameter_involve_.size() != is_parameter_.size()) {
MS_LOG(ERROR) << "'is_parameter_' and 'is_parameter_involve_' do not have the same number of inputs.";
return FAILED;
}
operator_cost()->set_is_parameter_involve(is_parameter_involve_);
operator_cost()->set_output_parameter_involve(is_output_parameter_involve_);
// Set the memory cost in the 'strategy_cost_'
for (auto& swc : strategy_cost_) {
auto mem_cost = operator_cost()->GetMemoryCost(swc->inputs_ptr, swc->outputs_ptr);
swc->cost_list[0]->memory_with_reuse_ = mem_cost;
}
return SUCCESS;
}

Status OperatorInfo::CorrectMemoryCost(size_t input_index) {
for (auto& swc : strategy_cost_) {
double parameter_mem_cost = ListProduct(swc->inputs_ptr[input_index].slice_shape()) *
static_cast<double>(operator_cost()->inputs_type_lengths()[input_index]);
swc->cost_list[0]->memory_with_reuse_ -= parameter_mem_cost;
if (swc->cost_list[0]->memory_with_reuse_ < 0) {
MS_LOG(ERROR) << "The memory cost after correction is: " << swc->cost_list[0]->memory_with_reuse_
<< ", the parameter memory cost is: " << parameter_mem_cost;
return FAILED;
}
}
return SUCCESS;
}
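A standalone sketch of the arithmetic used by CalculateMemoryCost and CorrectMemoryCost above, under the assumption that a tensor's memory contribution is product(slice_shape) * bytes_per_element. ListProduct and the real strategy/cost structures are replaced by plain containers and a hypothetical SliceBytes helper.

// Hedged sketch; shapes and type lengths are illustrative only.
#include <cstddef>
#include <iostream>
#include <numeric>
#include <vector>

double SliceBytes(const std::vector<int>& slice_shape, size_t type_length) {
  double elements = std::accumulate(slice_shape.begin(), slice_shape.end(), 1.0,
                                    [](double acc, int dim) { return acc * dim; });
  return elements * static_cast<double>(type_length);
}

int main() {
  // Two inputs of one strategy: a 128x32 activation slice and a 32x32 parameter slice, both fp32.
  double memory_with_reuse = SliceBytes({128, 32}, 4) + SliceBytes({32, 32}, 4);
  // If the parameter is shared with other operators, its contribution is subtracted so it is not
  // double-counted (the role of CorrectMemoryCost).
  double parameter_mem_cost = SliceBytes({32, 32}, 4);
  memory_with_reuse -= parameter_mem_cost;
  std::cout << memory_with_reuse << " bytes" << std::endl;  // 16384 + 4096 - 4096 = 16384
  return 0;
}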
@ -1193,7 +1225,7 @@ Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector<size_t>& inpu
}
inputs_type_lengths_ = input_lengths;
outputs_type_lengths_ = output_lengths;
cost()->SetInputAndOutputTypeLength(input_lengths, output_lengths);
operator_cost()->SetInputAndOutputTypeLength(input_lengths, output_lengths);
return SUCCESS;
}

@ -1221,7 +1253,7 @@ void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr& stra
}

double OperatorInfo::GetForwardMemoryCostFromCNode() {
return cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0);
return operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0);
}

} // namespace parallel
@ -60,7 +60,7 @@ class OperatorInfo {
outputs_shape_(std::move(outputs_shape)),
attrs_(std::move(attrs)),
is_alive_(true),
cost_(cost),
operator_cost_(cost),
outputs_type_() {
std::vector<bool> not_parameteter(inputs_shape_.size(), false);
is_parameter_ = not_parameteter;

@ -83,8 +83,8 @@ class OperatorInfo {
// Given the stage_id (which indicates the number of devices),
// generate all strategies for this operator
virtual Status GenerateStrategies(int32_t stage_id) = 0;
const OperatorCostPtr& cost() const { return cost_; }
void set_cost(const OperatorCostPtr& cost) { cost_ = cost; }
const OperatorCostPtr& operator_cost() const { return operator_cost_; }
void set_cost(const OperatorCostPtr& cost) { operator_cost_ = cost; }
virtual Status SetCostUnderStrategy(const StrategyPtr& strategy) = 0;

virtual std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies();

@ -98,7 +98,7 @@ class OperatorInfo {
std::vector<std::shared_ptr<StrategyWithCost>> GetStrategyCost() { return strategy_cost_; }
// When an input of an operator is a WEIGHT, or is an output of another operator involving WEIGHT, that input
// should stay in memory until it is used in the backward phase, i.e., it is still held in memory at the end of the forward phase.
Status CalculateMemoryCost() const { return SUCCESS; }
Status CalculateMemoryCost();
int ComputeOpAndPrevEdgeParameterInvolved();

ForwardOp forward_op() const { return forward_op_; }

@ -125,7 +125,7 @@ class OperatorInfo {
void ReplaceSuccEdge(const std::shared_ptr<OperatorInfo>& op, const std::shared_ptr<Edge>& new_edge);
void ReplacePreEdges(const std::shared_ptr<OperatorInfo>& op, const std::shared_ptr<Edge>& new_edge);
void ReplaceSuccEdges(const std::shared_ptr<OperatorInfo>& op, const std::shared_ptr<Edge>& new_edge);
std::vector<size_t> GetOutputTypeLengths() const { return cost()->outputs_type_lengths(); }
std::vector<size_t> GetOutputTypeLengths() const { return operator_cost()->outputs_type_lengths(); }
void SetSelectedStrategyAndCost(const StrategyPtr& s_strategy, const CostPtr& cost) {
selected_strategy_ = s_strategy;
selected_cost_ = cost;

@ -142,6 +142,10 @@ class OperatorInfo {
void set_strategy(const StrategyPtr& strategy) { strategy_ = strategy; }
void set_refkey_parameter_name(std::string p_name) { refkey_parameter_name_ = std::move(p_name); }
const std::string& refkey_parameter_name() const { return refkey_parameter_name_; }
// When the output of a Parameter (require_grad) is used by multiple operators, the Parameter's cost is calculated
// multiple times. This method corrects that, so the cost is calculated only once.
Status CorrectMemoryCost(size_t input_index);
int is_output_parameter_involve() const { return is_output_parameter_involve_; }
int used_devices() const { return used_devices_; }
// needed by rec_parser
void set_type(const std::string& type) { type_ = type; }

@ -234,7 +238,7 @@ class OperatorInfo {
int32_t used_devices_ = -1;

private:
OperatorCostPtr cost_;
OperatorCostPtr operator_cost_;
std::vector<TypePtr> outputs_type_;
};
@ -35,7 +35,7 @@ class PReLUInfo : public OperatorInfo {
public:
PReLUInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<PReLUCost>()) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<PReLUCost>(true)) {}
~PReLUInfo() override = default;
Status Init(const StrategyPtr& strategy) override;
Status InitForCostModel(const StrategyPtr& strategy) override;
@ -109,7 +109,7 @@ Status ReduceMethod::GetAttrs() {
}
cross_batch_ = cross_batch_iter->second->cast<BoolImmPtr>()->value();
}
auto reducemethodcost = std::dynamic_pointer_cast<ReduceMethodCost>(cost());
auto reducemethodcost = std::dynamic_pointer_cast<ReduceMethodCost>(operator_cost());
if (reducemethodcost == nullptr) {
MS_LOG(ERROR) << "Cost cast to ReduceMethodCostPtr failed!";
return FAILED;
@ -34,7 +34,7 @@ class ReduceMethod : public OperatorInfo {
public:
ReduceMethod(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceMethodCost>()) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceMethodCost>(true)) {}
~ReduceMethod() override = default;

Status Init(const StrategyPtr &strategy) override;
@ -36,7 +36,7 @@ class ReshapeInfo : public OperatorInfo {
public:
ReshapeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReshapeCost>()),
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReshapeCost>(false)),
dev_num_(0),
input_layout_set_flag_(false),
output_layout_set_flag_(false) {}
@ -34,7 +34,7 @@ class TmpIdentityInfo : public OperatorInfo {
public:
TmpIdentityInfo(const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs,
const std::string& name = IDENTITY_INFO)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TmpIdentityCost>()) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TmpIdentityCost>(false)) {}
~TmpIdentityInfo() override = default;

Status Init(const StrategyPtr& strategy) override;
@ -35,7 +35,7 @@ class TransposeInfo : public OperatorInfo {
public:
TransposeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TransposeCost>()) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TransposeCost>(false)) {}
~TransposeInfo() override = default;
Status Init(const StrategyPtr& strategy) override;
Status InitForCostModel(const StrategyPtr& strategy) override;
@ -32,7 +32,7 @@ class VirtualDatasetInfo : public OperatorInfo {
public:
VirtualDatasetInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<VirtualDatasetCost>()) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<VirtualDatasetCost>(false)) {}
~VirtualDatasetInfo() override = default;
Status Init(const StrategyPtr& strategy) override;
Status InitForCostModel(const StrategyPtr& strategy) override;
@ -874,11 +874,15 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
if (entire_costgraph->ComputeOpsAndEdgesParameterInvolved() == SUCCESS) {
// Calculate operators' memory usage
if (entire_costgraph->CalculateOpsMemoryCost() != SUCCESS) {
MS_LOG(EXCEPTION) << "Correcting operators' cost for memory reuse failed.";
MS_LOG(EXCEPTION) << "Calculating operators' cost for memory cost failed.";
}
// Calculate edges' memory usage
if (entire_costgraph->CalculateEdgesMemoryCost() != SUCCESS) {
MS_LOG(EXCEPTION) << "Correcting edges' cost for memory reuse failed.";
MS_LOG(EXCEPTION) << "Calculating edges' cost for memory cost failed.";
}
// Correct memory usage caused by TmpIdentity
if (entire_costgraph->CorrectOpsMemoryCost() != SUCCESS) {
MS_LOG(EXCEPTION) << "Correcting operators' cost for memory cost failed.";
}
} else {
MS_LOG(EXCEPTION) << "Computing operators' parameter_involved failed.";
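An orchestration sketch of the three memory passes invoked above, showing only their order and fail-fast behaviour; the CostGraphSketch type and its methods are stand-ins, not the real CostGraph API.

// Hedged sketch; all names here are hypothetical.
#include <iostream>

enum class Status { SUCCESS, FAILED };

struct CostGraphSketch {
  Status CalculateOpsMemoryCost() { return Status::SUCCESS; }    // per-operator memory_with_reuse_
  Status CalculateEdgesMemoryCost() { return Status::SUCCESS; }  // redistribution memory on edges
  Status CorrectOpsMemoryCost() { return Status::SUCCESS; }      // undo double-counted parameters
};

Status RunMemoryPasses(CostGraphSketch* graph) {
  if (graph->CalculateOpsMemoryCost() != Status::SUCCESS) return Status::FAILED;
  if (graph->CalculateEdgesMemoryCost() != Status::SUCCESS) return Status::FAILED;
  return graph->CorrectOpsMemoryCost();
}

int main() {
  CostGraphSketch graph;
  std::cout << (RunMemoryPasses(&graph) == Status::SUCCESS ? "ok" : "failed") << std::endl;
  return 0;
}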
@ -159,6 +159,7 @@ Status TensorRedistribution::ComputeCost() {
backward_comm_cost_ += prod;
comm_cost_ += 2.0 * prod;
computation_cost_ += prod;
memory_cost_ += prod;
} else if (str == CONCAT_BY_AXIS) {
// communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
// computation cost = before_slice_shape

@ -175,20 +176,25 @@ Status TensorRedistribution::ComputeCost() {
if (concat_dim == 0) {
// computation cost = all_gather
computation_cost_ += prod;
memory_cost_ += prod * dev_num;
} else {
// computation cost = all_gather + split + concat
computation_cost_ += (prod + prod * dev_num + prod * dev_num);
memory_cost_ += (prod * dev_num + prod * dev_num + prod);
}
} else {
// There is only computation cost in SplitByAxis.
// computation cost = before_slice_shape
computation_cost_ += prod;
// This addition may be erroneous
memory_cost_ += prod;
}
}
if (reshape_flag()) {
Shape prev_slice_shape = from_.slice_shape().array();
double prev_prod = std::accumulate(prev_slice_shape.begin(), prev_slice_shape.end(), 1, std::multiplies<int>());
computation_cost_ += 2.0 * prev_prod;
memory_cost_ += 2.0 * prev_prod;
}
return Status::SUCCESS;
}
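A numeric sketch of the CONCAT_BY_AXIS branch above, assuming prod is the element count of the local slice before the operation and dev_num is the number of devices taking part in the all_gather; the values below are arbitrary.

// Hedged sketch reproducing the accumulation pattern in the diff.
#include <iostream>

int main() {
  double prod = 4096.0;  // elements in the local slice
  double dev_num = 8.0;  // devices in the gather group
  double computation_cost = 0.0, memory_cost = 0.0;

  int concat_dim = 1;    // non-zero axis: all_gather + split + concat
  if (concat_dim == 0) {
    computation_cost += prod;
    memory_cost += prod * dev_num;  // the gathered tensor stays in memory
  } else {
    computation_cost += prod + prod * dev_num + prod * dev_num;
    memory_cost += prod * dev_num + prod * dev_num + prod;  // gather output + split copies + concat result
  }
  std::cout << computation_cost << " " << memory_cost << std::endl;  // prints 69632 69632
  return 0;
}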
@ -42,6 +42,7 @@ class TensorRedistribution {
forward_comm_cost_(0.0),
backward_comm_cost_(0.0),
computation_cost_(0.0),
memory_cost_(0.0),
construct_op_flag_(construct_op_flag),
keep_reshape_(keep_reshape) {}
Status Init(const TensorLayout& from, const TensorLayout& to, const RankList& dev_list);

@ -54,6 +55,7 @@ class TensorRedistribution {
double computation_cost() const { return computation_cost_; }
double forward_comm_cost() const { return forward_comm_cost_; }
double backward_comm_cost() const { return backward_comm_cost_; }
double memory_cost() const { return memory_cost_; }

private:
Status InferReshape(const TensorLayout& from_layout, const TensorLayout& to_layout,

@ -72,7 +74,12 @@ class TensorRedistribution {
double forward_comm_cost_;
// backward communication cost
double backward_comm_cost_;
// computation_cost models the time spent on computation in this tensor redistribution, and is calculated from the
// inputs.
double computation_cost_;
// memory_cost models the PEAK memory cost contributed by this tensor redistribution in a training iteration, and is
// calculated from the outputs.
double memory_cost_;
bool construct_op_flag_;
bool keep_reshape_;
};
@ -84,9 +84,9 @@ TEST_F(TestActivation, test_activation_strategies) {
act_ptr_->InitForCostModel(sp);
std::vector<TensorInfo> inputs_info = act_ptr_->inputs_tensor_info();
std::vector<TensorInfo> outputs_info = act_ptr_->outputs_tensor_info();
ASSERT_DOUBLE_EQ(act_ptr_->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
ASSERT_DOUBLE_EQ(act_ptr_->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
cost.computation_cost_);
ASSERT_DOUBLE_EQ(act_ptr_->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
ASSERT_DOUBLE_EQ(act_ptr_->operator_cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
cost.communication_cost_);
}
}

@ -109,9 +109,9 @@ TEST_F(TestActivation, test_softmax_strategies) {
soft_ptr_->InitForCostModel(sp);
std::vector<TensorInfo> inputs_info = soft_ptr_->inputs_tensor_info();
std::vector<TensorInfo> outputs_info = soft_ptr_->outputs_tensor_info();
ASSERT_DOUBLE_EQ(soft_ptr_->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
ASSERT_DOUBLE_EQ(soft_ptr_->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
cost.computation_cost_);
ASSERT_DOUBLE_EQ(soft_ptr_->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
ASSERT_DOUBLE_EQ(soft_ptr_->operator_cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
cost.communication_cost_);
}
}
@ -569,7 +569,7 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies1) {
matmul1->InitForCostModel(sp);
std::vector<TensorInfo> inputs_info = matmul1->inputs_tensor_info();
std::vector<TensorInfo> outputs_info = matmul1->outputs_tensor_info();
ASSERT_DOUBLE_EQ(matmul1->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
ASSERT_DOUBLE_EQ(matmul1->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
cost.computation_cost_);
break;
}

@ -599,7 +599,7 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies2) {
TensorInfo replica_input1_info(tly, input1_shape, input1_slice_shape);
replica_inputs_info.push_back(replica_input1_info);

ASSERT_DOUBLE_EQ(matmul3->cost()->GetComputationCost(replica_inputs_info, outputs_info, sp->GetInputStage()),
ASSERT_DOUBLE_EQ(matmul3->operator_cost()->GetComputationCost(replica_inputs_info, outputs_info, sp->GetInputStage()),
cost.computation_cost_);
break;
}
@ -188,11 +188,11 @@ TEST_F(TestTensorAddInfo, GenerateStrategies) {
tensor_add->InitForCostModel(sp);
std::vector<TensorInfo> inputs_info = tensor_add->inputs_tensor_info();
std::vector<TensorInfo> outputs_info = tensor_add->outputs_tensor_info();
double memory_cost0 = tensor_add->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
double memory_cost0 = tensor_add->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
double memory_cost1 = cost.computation_cost_;
bool memory = memory_cost0 - memory_cost1 <= 1.0;

double comm_cost0 = tensor_add->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());
double comm_cost0 = tensor_add->operator_cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());
double comm_cost1 = cost.communication_cost_;
bool comm = comm_cost0 - comm_cost1 <= 1.0;

@ -210,11 +210,11 @@ TEST_F(TestTensorAddInfo, GenerateStrategies1) {
tensor_add1->InitForCostModel(sp);
std::vector<TensorInfo> inputs_info = tensor_add1->inputs_tensor_info();
std::vector<TensorInfo> outputs_info = tensor_add1->outputs_tensor_info();
double memory_cost0 = tensor_add1->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
double memory_cost0 = tensor_add1->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
double memory_cost1 = cost.computation_cost_;
bool memory = memory_cost0 - memory_cost1 <= 1.0;

double comm_cost0 = tensor_add1->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());
double comm_cost0 = tensor_add1->operator_cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());
double comm_cost1 = cost.communication_cost_;
bool comm = comm_cost0 - comm_cost1 <= 1.0;
@ -145,9 +145,9 @@ TEST_F(TestTmpIdentityInfo, test_generate_strategies) {
identity_ptr->Init(sp);
std::vector<TensorInfo> inputs_info = identity_ptr->inputs_tensor_info();
std::vector<TensorInfo> outputs_info = identity_ptr->outputs_tensor_info();
ASSERT_DOUBLE_EQ(identity_ptr->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
ASSERT_DOUBLE_EQ(identity_ptr->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
cost.computation_cost_);
ASSERT_DOUBLE_EQ(identity_ptr->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
ASSERT_DOUBLE_EQ(identity_ptr->operator_cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
cost.communication_cost_);
}
}