forked from mindspore-Ecosystem/mindspore
change_star_elimination: make the non-identity triangle_eliminatin exact
This commit is contained in:
parent
29ab2c1093
commit
9a717aa1f7
|
@ -948,10 +948,12 @@ OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr& op) {
|
|||
return target_op;
|
||||
}
|
||||
|
||||
void CostGraph::CreateTriangleEliminationSubCostListForIdentity(
|
||||
StrategyPtr elimi_op_stra, StrategyPtr left_op_stra, StrategyPtr right_op_stra, const CostPtr& right_op_cost,
|
||||
const CostPtrList& elimi_op_clist, const CostPtrList& left_edge_clist, const CostPtr& right_edge_cost,
|
||||
const CostPtrList& left_node_clist_origin, CostPtrList* left_node_clist_new) {
|
||||
void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, StrategyPtr left_op_stra,
|
||||
StrategyPtr right_op_stra, const CostPtr& right_op_cost,
|
||||
const CostPtrList& elimi_op_clist,
|
||||
const CostPtrList& left_edge_clist, const CostPtr& right_edge_cost,
|
||||
const CostPtrList& left_node_clist_origin,
|
||||
CostPtrList* left_node_clist_new) {
|
||||
MS_EXCEPTION_IF_NULL(right_edge_cost);
|
||||
MS_EXCEPTION_IF_NULL(right_op_cost);
|
||||
MS_EXCEPTION_IF_NULL(left_node_clist_new);
|
||||
|
@ -985,93 +987,20 @@ void CostGraph::CreateTriangleEliminationSubCostListForIdentity(
|
|||
}
|
||||
}
|
||||
|
||||
void CostGraph::CreateTriangleEliminationSubCostListForOthers(
|
||||
StrategyPtr elimi_op_stra, StrategyPtr left_node_stra, StrategyPtr right_node_stra, const CostPtr& right_op_cost,
|
||||
const CostPtrList& elimi_op_clist, const CostPtrList& left_edge_clist, const CostPtr& right_edge_cost,
|
||||
const CostPtrList& left_node_clist_origin, CostPtrList* left_node_clist_new) {
|
||||
CostPtr elimi_op_determined = nullptr, left_edge_determined = nullptr, init_ele = nullptr;
|
||||
std::function<CostPtr(CostPtr, const CostPtr&)> LocalCompare = [&](CostPtr init, const CostPtr& cost_x) {
|
||||
MS_EXCEPTION_IF_NULL(cost_x);
|
||||
if ((init == nullptr) || (cost_x->memory_cost_ < DEVICE_MEMORY_CAPACITY)) {
|
||||
init = cost_x;
|
||||
}
|
||||
return init;
|
||||
};
|
||||
|
||||
// Find a feasible elimi_op_clist
|
||||
elimi_op_determined = std::accumulate(elimi_op_clist.begin(), elimi_op_clist.end(), init_ele, LocalCompare);
|
||||
init_ele = nullptr;
|
||||
// Find a feasible left_edge_cost
|
||||
left_edge_determined = std::accumulate(left_edge_clist.begin(), left_edge_clist.end(), init_ele, LocalCompare);
|
||||
if ((elimi_op_determined == nullptr) || (left_edge_determined == nullptr)) {
|
||||
return;
|
||||
}
|
||||
if ((elimi_op_determined->memory_cost_ >= DEVICE_MEMORY_CAPACITY) ||
|
||||
(left_edge_determined->memory_cost_ >= DEVICE_MEMORY_CAPACITY)) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (auto& left_node_cost : left_node_clist_origin) {
|
||||
MS_EXCEPTION_IF_NULL(left_node_cost);
|
||||
MS_EXCEPTION_IF_NULL(right_op_cost);
|
||||
double new_memory_cost = left_node_cost->memory_cost_ + elimi_op_determined->memory_cost_ +
|
||||
left_edge_determined->memory_cost_ + right_edge_cost->memory_cost_ +
|
||||
right_op_cost->memory_cost_;
|
||||
double commu_cost = left_node_cost->communication_cost_ + elimi_op_determined->communication_cost_ +
|
||||
left_edge_determined->communication_cost_ + right_edge_cost->communication_cost_ +
|
||||
right_op_cost->communication_cost_;
|
||||
double commu_without =
|
||||
left_node_cost->communication_without_parameter_ + elimi_op_determined->communication_without_parameter_ +
|
||||
left_edge_determined->communication_without_parameter_ + right_edge_cost->communication_without_parameter_ +
|
||||
right_op_cost->communication_without_parameter_;
|
||||
auto decision = std::make_shared<TriangleEliminationDecision>(elimi_op_stra, elimi_op_determined,
|
||||
left_edge_determined, right_edge_cost, left_node_stra,
|
||||
left_node_cost, right_node_stra, right_op_cost);
|
||||
|
||||
auto new_cost = std::make_shared<Cost>(new_memory_cost, commu_cost, decision);
|
||||
new_cost->communication_without_parameter_ = commu_without;
|
||||
new_cost->communication_with_partial_para_ = commu_without + COST_MODEL_GAMMA * (commu_cost - commu_without);
|
||||
left_node_clist_new->emplace_back(std::move(new_cost));
|
||||
}
|
||||
}
|
||||
|
||||
void CostGraph::CreateTriangleEliminationCostList(const OperatorInfoPtr& elimi_op, const CostPtrList& right_node_clist,
|
||||
const CostPtrList& right_edge_clist, const StrategyPtr& elimi_op_stra,
|
||||
const StrategyPtr& left_node_stra, const StrategyPtr& right_node_stra,
|
||||
const CostPtrList& elimi_op_clist, const CostPtrList& left_edge_clist,
|
||||
const CostPtrList& left_node_clist_origin,
|
||||
CostPtrList* left_node_clist_new) {
|
||||
// The reason for separately dealing with when the 'elimi_op' is 'TMPIDENTITY_INFO' or others is that
|
||||
// when 'elimi_op' is TMPIDENTITY_INFO, the computation is limited, while 'elimi_op' is others, the computation
|
||||
// may be huge
|
||||
MS_EXCEPTION_IF_NULL(elimi_op);
|
||||
if (elimi_op->name().find(TMPIDENTITY_INFO_NAME) != std::string::npos) {
|
||||
for (auto& right_node_cost : right_node_clist) {
|
||||
MS_EXCEPTION_IF_NULL(right_node_cost);
|
||||
for (auto& right_edge_cost : right_edge_clist) {
|
||||
MS_EXCEPTION_IF_NULL(right_edge_cost);
|
||||
if ((right_node_cost->memory_cost_ < DEVICE_MEMORY_CAPACITY) &&
|
||||
(right_edge_cost->memory_cost_ < DEVICE_MEMORY_CAPACITY)) {
|
||||
// Exact computation for TMPIDENTITY_INFO_NAME case
|
||||
CreateTriangleEliminationSubCostListForIdentity(elimi_op_stra, left_node_stra, right_node_stra,
|
||||
right_node_cost, elimi_op_clist, left_edge_clist,
|
||||
right_edge_cost, left_node_clist_origin, left_node_clist_new);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (auto& right_node_cost : right_node_clist) {
|
||||
MS_EXCEPTION_IF_NULL(right_node_cost);
|
||||
for (auto& right_edge_cost : right_edge_clist) {
|
||||
MS_EXCEPTION_IF_NULL(right_edge_cost);
|
||||
if ((right_node_cost->memory_cost_ < DEVICE_MEMORY_CAPACITY) &&
|
||||
(right_edge_cost->memory_cost_ < DEVICE_MEMORY_CAPACITY)) {
|
||||
// Approximate computation for other case
|
||||
CreateTriangleEliminationSubCostListForOthers(elimi_op_stra, left_node_stra, right_node_stra, right_node_cost,
|
||||
elimi_op_clist, left_edge_clist, right_edge_cost,
|
||||
left_node_clist_origin, left_node_clist_new);
|
||||
}
|
||||
}
|
||||
for (auto& right_node_cost : right_node_clist) {
|
||||
MS_EXCEPTION_IF_NULL(right_node_cost);
|
||||
for (auto& right_edge_cost : right_edge_clist) {
|
||||
MS_EXCEPTION_IF_NULL(right_edge_cost);
|
||||
CreateTriangleEliminationSubCostList(elimi_op_stra, left_node_stra, right_node_stra, right_node_cost,
|
||||
elimi_op_clist, left_edge_clist, right_edge_cost, left_node_clist_origin,
|
||||
left_node_clist_new);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -163,14 +163,9 @@ class CostGraph {
|
|||
void CreateTriangleEliminationCostList(const OperatorInfoPtr&, const CostPtrList&, const CostPtrList&,
|
||||
const StrategyPtr&, const StrategyPtr&, const StrategyPtr&, const CostPtrList&,
|
||||
const CostPtrList&, const CostPtrList&, CostPtrList*);
|
||||
// Given the relevant costlist, create the TriangleElimination cost for eliminating TmpIdentityInfo
|
||||
void CreateTriangleEliminationSubCostListForIdentity(StrategyPtr, StrategyPtr, StrategyPtr, const CostPtr&,
|
||||
const CostPtrList&, const CostPtrList&, const CostPtr&,
|
||||
const CostPtrList&, CostPtrList*);
|
||||
// Given the relevant costlist, create the TriangleElimination cost for eliminating other operators
|
||||
void CreateTriangleEliminationSubCostListForOthers(StrategyPtr, StrategyPtr, StrategyPtr, const CostPtr&,
|
||||
const CostPtrList&, const CostPtrList&, const CostPtr&,
|
||||
const CostPtrList&, CostPtrList*);
|
||||
// Given the relevant costlist, create the TriangleElimination cost
|
||||
void CreateTriangleEliminationSubCostList(StrategyPtr, StrategyPtr, StrategyPtr, const CostPtr&, const CostPtrList&,
|
||||
const CostPtrList&, const CostPtr&, const CostPtrList&, CostPtrList*);
|
||||
|
||||
// Applying the Star Elimination in DP algorithm. Return the successive edges of this merged_op
|
||||
// NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied.
|
||||
|
|
Loading…
Reference in New Issue