code style warnings fixing

This commit is contained in:
Xiaoda Zhang 2021-04-27 15:13:29 +08:00
parent b7d350e8c6
commit 5fecfe92a6
20 changed files with 462 additions and 461 deletions

View File

@ -23,7 +23,8 @@
namespace mindspore {
namespace parallel {
void Simplify(CostPtrList *clist_ptrs) {
if (RUN_PHASE == TRAINING_PHASE) {
const auto run_phase = CostModelContext::GetInstance()->run_phase();
if (run_phase == TRAINING_PHASE) {
// training phase
SimplifyForDecreasingCommunicationWithPartialPara(clist_ptrs);
} else {
@ -35,7 +36,8 @@ void SimplifyForDecreasingCommunicationForward(CostPtrList *clist_ptrs) {
// Sort the cost_list with the computation_cost_ increasing, and communication_forward decreasing order. This method
// excludes the cost with greater computation_cost_ and greater communication_forward.
// E.g. clist_ptrs = {<100, 20>, <200, 10>, <300, 50>}. After this method, clist_ptrs = {<200, 10>, <100, 20>}
if (!COST_MODEL_SIMPLIFY_CALCULATION) {
const auto simplify_cal = CostModelContext::GetInstance()->costmodel_simplify_cal();
if (!simplify_cal) {
return;
}
MS_EXCEPTION_IF_NULL(clist_ptrs);
@ -57,7 +59,8 @@ void SimplifyForDecreasingCommunicationForward(CostPtrList *clist_ptrs) {
void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) {
// Sort the cost_list with the computation_cost_ increasing, and communication_with_partial_para_cost decreasing
// order. This method excludes the cost with greater computation_cost_ and greater communication_without_para_cost.
if (!COST_MODEL_SIMPLIFY_CALCULATION) {
const auto simplify_cal = CostModelContext::GetInstance()->costmodel_simplify_cal();
if (!simplify_cal) {
return;
}
MS_EXCEPTION_IF_NULL(clist_ptrs);
@ -78,19 +81,23 @@ void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs)
void RefineForPracticalCost(const CostPtr &origin_cost, bool is_redistribution) {
MS_EXCEPTION_IF_NULL(origin_cost);
const auto comm_threshold = CostModelContext::GetInstance()->costmodel_communi_threshold();
const auto comm_const = CostModelContext::GetInstance()->costmodel_communi_const();
const auto comm_bias = CostModelContext::GetInstance()->costmodel_communi_bias();
const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
if (is_redistribution) {
// Redistribution cost
if ((origin_cost->communication_redis_forward_ > EPS) &&
(origin_cost->communication_redis_forward_ <= COST_MODEL_COMMUNI_THRESHOLD)) {
origin_cost->communication_redis_forward_ = COST_MODEL_COMMUNI_CONST;
} else if (origin_cost->communication_redis_forward_ > COST_MODEL_COMMUNI_THRESHOLD) {
origin_cost->communication_redis_forward_ += COST_MODEL_COMMUNI_BIAS;
(origin_cost->communication_redis_forward_ <= comm_threshold)) {
origin_cost->communication_redis_forward_ = comm_const;
} else if (origin_cost->communication_redis_forward_ > comm_threshold) {
origin_cost->communication_redis_forward_ += comm_bias;
}
if ((origin_cost->communication_redis_backward_ > EPS) &&
(origin_cost->communication_redis_backward_ <= COST_MODEL_COMMUNI_THRESHOLD)) {
origin_cost->communication_redis_backward_ = COST_MODEL_COMMUNI_CONST;
} else if (origin_cost->communication_redis_backward_ > COST_MODEL_COMMUNI_THRESHOLD) {
origin_cost->communication_redis_backward_ += COST_MODEL_COMMUNI_BIAS;
(origin_cost->communication_redis_backward_ <= comm_threshold)) {
origin_cost->communication_redis_backward_ = comm_const;
} else if (origin_cost->communication_redis_backward_ > comm_threshold) {
origin_cost->communication_redis_backward_ += comm_bias;
}
origin_cost->communication_cost_ =
origin_cost->communication_redis_forward_ + origin_cost->communication_redis_backward_;
@ -104,18 +111,17 @@ void RefineForPracticalCost(const CostPtr &origin_cost, bool is_redistribution)
}
// forward cost
if ((origin_cost->communication_without_parameter_ > EPS) &&
(origin_cost->communication_without_parameter_ <= COST_MODEL_COMMUNI_THRESHOLD)) {
origin_cost->communication_without_parameter_ = COST_MODEL_COMMUNI_CONST;
} else if (origin_cost->communication_without_parameter_ > COST_MODEL_COMMUNI_THRESHOLD) {
origin_cost->communication_without_parameter_ += COST_MODEL_COMMUNI_BIAS;
(origin_cost->communication_without_parameter_ <= comm_threshold)) {
origin_cost->communication_without_parameter_ = comm_const;
} else if (origin_cost->communication_without_parameter_ > comm_threshold) {
origin_cost->communication_without_parameter_ += comm_bias;
}
// total
if (origin_cost->communication_cost_ > EPS) {
origin_cost->communication_cost_ = origin_cost->communication_without_parameter_ + backward;
}
if (origin_cost->communication_with_partial_para_ > EPS) {
origin_cost->communication_with_partial_para_ =
origin_cost->communication_without_parameter_ + COST_MODEL_GAMMA * backward;
origin_cost->communication_with_partial_para_ = origin_cost->communication_without_parameter_ + gamma * backward;
}
}
}

View File

@ -24,6 +24,7 @@
#include <vector>
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/tensor_info.h"
#include "frontend/parallel/costmodel_context.h"
namespace mindspore {
namespace parallel {
@ -44,8 +45,8 @@ using RedistributionOpListPtr = std::shared_ptr<std::pair<OperatorVector, OutPut
struct Cost {
Cost();
Cost(double computation, double commuication, const std::shared_ptr<Decision> &decision_ = nullptr)
: computation_cost_(computation), communication_cost_(commuication), decision_ptr_(std::move(decision_)) {
Cost(double computation, double communication, const std::shared_ptr<Decision> &decision_ = nullptr)
: computation_cost_(computation), communication_cost_(communication), decision_ptr_(std::move(decision_)) {
memory_with_reuse_ = 0.0;
communication_without_parameter_ = 0.0;
communication_with_partial_para_ = 0.0;
@ -273,7 +274,7 @@ struct TriangleEliminationDecision : public Decision {
*
* v <--- u ---> w ==> v w In the original graph, u has 0 incoming edges, and multiple outgoing edges.
* In addition, v and w have other complicated connections, resulting in v and w can not be performed other
* eliminations. After the StarElimination, u is merged into v, and the resulting graph is splitted into multiple
* eliminations. After the StarElimination, u is merged into v, and the resulting graph is split into multiple
* connected components.
* NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied.
*/

View File

@ -132,18 +132,14 @@ Status GetStrategy(const CostGraphPtr &graph) {
Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
std::vector<EliminationPtr>::reverse_iterator rit;
const auto triangle_star_overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite();
for (rit = eliminations.rbegin(); rit != eliminations.rend(); ++rit) {
if ((*rit)->isa<OpElimination>()) {
auto elimination = (*rit)->cast<OpEliminationPtr>();
auto e = elimination->new_edge_;
auto w = elimination->op_;
MS_EXCEPTION_IF_NULL(e);
MS_EXCEPTION_IF_NULL(w);
auto left_edge = elimination->left_edge_;
auto right_edge = elimination->right_edge_;
MS_EXCEPTION_IF_NULL(left_edge);
MS_EXCEPTION_IF_NULL(right_edge);
auto decision = e->selected_cost()->decision_ptr_->cast<OpEliminationDecisionPtr>();
w->SetSelectedStrategyAndCost(decision->op_strategy_, decision->middle_cost_);
left_edge->set_selected_cost(decision->left_cost_);
@ -201,12 +197,12 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
right_edge->set_selected_cost(decision->right_edge_cost_);
// 'left_node' recovers the strategy.
left_node->SetSelectedStrategyAndCost(decision->left_node_strategy_, decision->left_node_cost_);
if (TRIANGLE_STAR_STRATEGY_OVERWRITE) {
if (triangle_star_overwrite) {
// 'right_node' recovers the strategy.
MS_LOG(INFO) << "Overwrite the right-node: " << right_node->name() << " in recovering triangle elimination.";
right_node->SetSelectedStrategyAndCost(decision->right_node_strategy_, decision->right_node_cost_);
} else {
// In this case, 'right_node' is not overwriten strategy, and it checks strategy consistency.
// In this case, 'right_node' is not overwritten strategy, and it checks strategy consistency.
right_node->CheckSelectedStrategy(decision->right_node_strategy_);
}
MS_LOG(INFO) << "Recover triangleElimination succeeded.";
@ -215,7 +211,7 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
auto merged_node = elimination->eliminated_node_;
auto succ_edges = elimination->succ_edges_;
auto succ_nodes = elimination->succ_ops_;
// decision is hided in succ_nodes[0]
// decision is hidden in succ_nodes[0]
auto decision = succ_nodes[0]->selected_cost()->decision_ptr_->cast<StarEliminationDecisionPtr>();
merged_node->SetSelectedStrategyAndCost(decision->eliminated_op_strategy_, decision->eliminated_op_cost_);
@ -228,7 +224,7 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
// Star is eliminated into 'succ_nodes[0]'
succ_nodes[0]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[0], decision->succ_ops_cost_list_[0]);
for (size_t k = 1; k < succ_nodes.size(); ++k) {
if (TRIANGLE_STAR_STRATEGY_OVERWRITE) {
if (triangle_star_overwrite) {
// 'succ_nodes[k]' is overwritten strategy and cost.
succ_nodes[k]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[k], decision->succ_ops_cost_list_[k]);
} else {

View File

@ -90,11 +90,13 @@ Status Edge::InitEdgeCost() {
}
}
if (!has_available_cost) {
if (FULLY_USE_DEVICES) {
const auto fully_use = CostModelContext::GetInstance()->fully_use_device();
const auto stra_follow = CostModelContext::GetInstance()->elementwise_stra_follow();
if (fully_use) {
MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_
<< " failed, it may be caused by setting 'fully_use_devices' true. Try to set "
"'fully_use_devices' false.";
} else if (ELEMENTWISE_OP_STRA_FOLLOW) {
} else if (stra_follow) {
MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_
<< " failed, it may be caused by setting 'elementwise_op_strategy_follow' true. "
"Try to set 'elementwise_op_strategy_follow' false.";
@ -130,6 +132,7 @@ Status Edge::GetRedistributionCost(const TensorLayout &prev_op_output_layout, co
double backward_comm_cost = tensor_redistribution.backward_comm_cost();
double computation_cost = tensor_redistribution.computation_cost();
double mem_cost = tensor_redistribution.memory_cost();
const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
// Now AllGather, ReduceScatter, AlltoAll don't support bool type
MS_EXCEPTION_IF_NULL(type);
@ -142,7 +145,7 @@ Status Edge::GetRedistributionCost(const TensorLayout &prev_op_output_layout, co
(*cost)->communication_without_parameter_ = type_length * comm_cost;
(*cost)->communication_with_partial_para_ =
(*cost)->communication_without_parameter_ +
COST_MODEL_GAMMA * ((*cost)->communication_cost_ - (*cost)->communication_without_parameter_);
gamma * ((*cost)->communication_cost_ - (*cost)->communication_without_parameter_);
(*cost)->communication_redis_forward_ = type_length * forward_comm_cost;
(*cost)->communication_redis_backward_ = type_length * backward_comm_cost;
(*cost)->memory_with_reuse_ = mem_cost;
@ -173,13 +176,14 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr
std::function<void(size_t, double, double, double, double, double)> recursive =
[&](size_t k, double computation, double memory, double communication, double communication_without_para,
double communication_forward) {
const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
if (k == edges.size()) {
auto decision = std::make_shared<EdgeEliminationDecision>(selected_cost_list);
CostPtr new_cost = std::make_shared<Cost>(computation, communication);
MS_EXCEPTION_IF_NULL(new_cost);
new_cost->communication_without_parameter_ = communication_without_para;
new_cost->communication_with_partial_para_ =
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
communication_without_para + gamma * (communication - communication_without_para);
new_cost->memory_with_reuse_ = memory;
new_cost->communication_forward_ = communication_forward;
new_cost->decision_ptr_ = decision;
@ -242,10 +246,11 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr
auto decision = std::make_shared<OpEliminationDecision>(op_strategy, left_cost, middle_cost, right_cost);
auto cost = std::make_shared<Cost>(computation, communication, decision);
const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
MS_EXCEPTION_IF_NULL(cost);
cost->communication_without_parameter_ = communication_without_para;
cost->communication_with_partial_para_ =
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
communication_without_para + gamma * (communication - communication_without_para);
cost->memory_with_reuse_ = memory_cost;
cost->communication_forward_ = communication_forward;
ret_cost_list->emplace_back(std::move(cost));

View File

@ -28,175 +28,6 @@ namespace mindspore {
namespace parallel {
CostGraphPtr entire_costgraph = nullptr;
size_t TOTAL_OPS = 0;
double COST_MODEL_GAMMA = DEFAULT_COST_MODEL_GAMMA;
bool COST_MODEL_SIMPLIFY_CALCULATION = DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION;
double DEVICE_MEMORY_CAPACITY = DEFAULT_DEVICE_MEMORY_CAPACITY;
double COST_MODEL_COMMUNI_THRESHOLD = DEFAULT_COST_MODEL_COMMUNI_THRESHOLD;
double COST_MODEL_COMMUNI_CONST = DEFAULT_COST_MODEL_COMMUNI_CONST;
double COST_MODEL_COMMUNI_BIAS = DEFAULT_COST_MODEL_COMMUNI_BIAS;
bool TENSOR_SLICE_ALIGNMENT_ENABLE = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE;
size_t TENSOR_SLICE_ALIGNMENT_SIZE = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE;
bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES;
bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS;
int64_t RUN_PHASE = DEFAULT_RUN_PHASE;
bool TRIANGLE_STAR_STRATEGY_OVERWRITE = DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE;
bool DP_ALGO_ENABLE_APPROX = DEFAULT_DP_ALGO_ENABLE_APPROX;
double DP_ALGO_APPROX_EPSILON = DEFAULT_DP_ALGO_APPROX_EPSILON;
bool DP_ALGO_SINGLE_LOOP = DEFAULT_DP_ALGO_SINGLE_LOOP;
void CostGraph::SetDeviceMemoryAndCostParameter() {
MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());
// DEVICE_MEMORY_CAPACITY
auto device_memory = CostModelContext::GetInstance()->device_memory_capacity();
if (device_memory <= 0) {
MS_LOG(EXCEPTION) << "'device_memory_capacity' must be positive.";
}
dev_memory_ = device_memory;
DEVICE_MEMORY_CAPACITY = device_memory;
MS_LOG(INFO) << "device_memory_capacity: " << DEVICE_MEMORY_CAPACITY << ".";
// COST_MODEL_ALPHA
auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
if (alpha <= 0) {
MS_LOG(EXCEPTION) << "'costmodel_alpha' must be positive.";
}
costmodel_alpha_ = alpha;
MS_LOG(INFO) << "costmodel_alpha: " << costmodel_alpha_ << ".";
// COST_MODEL_BETA
auto beta = CostModelContext::GetInstance()->costmodel_beta();
if (beta <= 0) {
MS_LOG(EXCEPTION) << "'costmodel_beta' must be positive.";
}
costmodel_beta_ = beta;
MS_LOG(INFO) << "costmodel_beta: " << costmodel_beta_ << ".";
// COST_MODEL_GAMMA
auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
if ((gamma < 0) || (gamma > 1)) {
MS_LOG(EXCEPTION) << "'costmodel_gamma' must in [0, 1].";
}
COST_MODEL_GAMMA = gamma;
MS_LOG(INFO) << "costmodel_gamma: " << COST_MODEL_GAMMA << ".";
// COST_MODEL_SIMPLIFY_CALCULATION
auto simplify = CostModelContext::GetInstance()->costmodel_simplify_cal();
COST_MODEL_SIMPLIFY_CALCULATION = simplify;
if (COST_MODEL_SIMPLIFY_CALCULATION) {
MS_LOG(INFO) << "costmodel_simplify_cal: true.";
} else {
MS_LOG(INFO) << "costmodel_simplify_cal: false.";
}
// COST_MODEL_COMMUNI_THRESHOLD
auto communi_threshold = CostModelContext::GetInstance()->costmodel_communi_threshold();
if (communi_threshold < 0) {
MS_LOG(EXCEPTION) << "'costmodel_communi_threshold' must be non-zero.";
}
COST_MODEL_COMMUNI_THRESHOLD = communi_threshold;
MS_LOG(INFO) << "costmodel_communi_threshold: " << COST_MODEL_COMMUNI_THRESHOLD << ".";
// COST_MODEL_COMMUNI_CONST
auto communi_const = CostModelContext::GetInstance()->costmodel_communi_const();
if (communi_const < 0) {
MS_LOG(EXCEPTION) << "'costmodel_communi_const' must be non-zero.";
}
COST_MODEL_COMMUNI_CONST = communi_const;
MS_LOG(INFO) << "costmodel_communi_const: " << COST_MODEL_COMMUNI_CONST << ".";
// COST_MODEL_COMMUNI_BIAS
auto communi_bias = CostModelContext::GetInstance()->costmodel_communi_bias();
if (communi_bias < 0) {
MS_LOG(EXCEPTION) << "'costmodel_communi_bias' must be non-zero.";
}
COST_MODEL_COMMUNI_BIAS = communi_bias;
MS_LOG(INFO) << "costmodel_communi_bias: " << COST_MODEL_COMMUNI_BIAS << ".";
// TENSOR_SLICE_ALIGNMENT_ENABLE
auto align_enable = CostModelContext::GetInstance()->tensor_slice_alignment_enable();
TENSOR_SLICE_ALIGNMENT_ENABLE = align_enable;
if (TENSOR_SLICE_ALIGNMENT_ENABLE) {
MS_LOG(INFO) << "tensor_slice_align_enable: true.";
} else {
MS_LOG(INFO) << "tensor_slice_align_enable: false.";
}
// TENSOR_SLICE_ALIGNMENT_SIZE
auto align_size = CostModelContext::GetInstance()->tensor_slice_alignment_size();
if (align_size == 0) {
MS_LOG(EXCEPTION) << "'tensor_slice_align_size' must be positive.";
}
TENSOR_SLICE_ALIGNMENT_SIZE = align_size;
MS_LOG(INFO) << "tensor_slice_align_size: " << TENSOR_SLICE_ALIGNMENT_SIZE << ".";
// FULLY_USE_DEVICES
auto fully_devices = CostModelContext::GetInstance()->fully_use_device();
FULLY_USE_DEVICES = fully_devices;
if (FULLY_USE_DEVICES) {
MS_LOG(INFO) << "fully_use_devices: true.";
} else {
MS_LOG(INFO) << "fully_use_devices: false.";
}
// ELEMENTWISE_OP_STRA_FOLLOW
auto is_ele_op_follow = CostModelContext::GetInstance()->elementwise_stra_follow();
ELEMENTWISE_OP_STRA_FOLLOW = is_ele_op_follow;
if (ELEMENTWISE_OP_STRA_FOLLOW) {
MS_LOG(INFO) << "elementwise_op_strategy_follow: true.";
} else {
MS_LOG(INFO) << "elementwise_op_strategy_follow: false.";
}
// MULTI_SUBGRAPHS
auto multi_subgraphs = CostModelContext::GetInstance()->is_multi_subgraphs();
MULTI_SUBGRAPHS = multi_subgraphs;
if (MULTI_SUBGRAPHS) {
MS_LOG(INFO) << "multi_subgraphs: true.";
} else {
MS_LOG(INFO) << "multi_subgraphs: false.";
}
auto overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite();
TRIANGLE_STAR_STRATEGY_OVERWRITE = overwrite;
if (TRIANGLE_STAR_STRATEGY_OVERWRITE) {
MS_LOG(INFO) << "triangle_star_strategy_overwrite: true.";
} else {
MS_LOG(INFO) << "triangle_star_strategy_overwrite: false.";
}
// RUN_PHASE
auto phase = CostModelContext::GetInstance()->run_phase();
if (phase != 0 && phase != 1) {
MS_LOG(EXCEPTION) << "'run_phase' must be in {0, 1}";
}
RUN_PHASE = phase;
MS_LOG(INFO) << "run_phase: " << RUN_PHASE << ".";
auto enable_approx = CostModelContext::GetInstance()->dp_algo_enable_approxi();
DP_ALGO_ENABLE_APPROX = enable_approx;
if (enable_approx) {
MS_LOG(INFO) << "dp_algo_enable_approx: true.";
} else {
MS_LOG(INFO) << "dp_algo_enable_approx: false.";
}
auto epsilon = CostModelContext::GetInstance()->dp_algo_approxi_epsilon();
if (epsilon <= 0 || epsilon > 1) {
MS_LOG(EXCEPTION) << "'epsilon' must be in (0, 1]";
}
DP_ALGO_APPROX_EPSILON = epsilon;
MS_LOG(INFO) << "epsilon: " << epsilon << ".";
auto single_loop = CostModelContext::GetInstance()->dp_algo_single_loop();
DP_ALGO_SINGLE_LOOP = single_loop;
if (single_loop) {
MS_LOG(INFO) << "dp_algo_single_loop: true.";
} else {
MS_LOG(INFO) << "dp_algo_single_loop: false.";
}
}
void CostGraph::Init() {
inputs_tensor_name_list_.clear();
@ -269,7 +100,6 @@ std::vector<std::shared_ptr<CostGraph>> CostGraph::ConstructConnectedComponents(
if ((!visited[op]) && op->is_alive()) {
std::shared_ptr<CostGraph> new_component = std::make_shared<CostGraph>();
MS_EXCEPTION_IF_NULL(new_component);
new_component->SetDeviceMemoryAndCostParameter();
DFS(op, &visited, new_component);
connected_compoents_.push_back(new_component);
}
@ -336,10 +166,11 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std::
auto decision =
std::make_shared<FinalDecision>(u_strategy->strategy_ptr, v_strategy->strategy_ptr, cost1, cost2, cost3);
auto cost = std::make_shared<Cost>(computation, communication, decision);
const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
MS_EXCEPTION_IF_NULL(cost);
cost->communication_without_parameter_ = communication_without_para;
cost->communication_with_partial_para_ =
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
communication_without_para + gamma * (communication - communication_without_para);
cost->memory_with_reuse_ = memory;
cost->communication_forward_ = communication_forward;
ret.push_back(cost);
@ -353,7 +184,7 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std::
return ret;
}
// Create final cost list for the graph containing a signle node: u
// Create final cost list for the graph containing a single node: u
CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr &u) {
MS_EXCEPTION_IF_NULL(u);
CostPtrList ret;
@ -365,11 +196,12 @@ CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr &u) {
MS_EXCEPTION_IF_NULL(cost1);
auto decision = std::make_shared<FinalSingleDecision>(u_strategy_ptr, cost1);
auto new_cost = std::make_shared<Cost>(cost1->computation_cost_, cost1->communication_cost_, decision);
const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
MS_EXCEPTION_IF_NULL(new_cost);
new_cost->communication_without_parameter_ = cost1->communication_without_parameter_;
new_cost->communication_with_partial_para_ =
cost1->communication_without_parameter_ +
COST_MODEL_GAMMA * (cost1->communication_cost_ - cost1->communication_without_parameter_);
gamma * (cost1->communication_cost_ - cost1->communication_without_parameter_);
new_cost->memory_with_reuse_ = cost1->memory_with_reuse_;
new_cost->communication_forward_ = cost1->communication_forward_;
ret.push_back(new_cost);
@ -404,8 +236,10 @@ CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list,
}
// Init the returned value with first cost.
CostPtr ret = after_mem_filter[0];
const auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
const auto beta = CostModelContext::GetInstance()->costmodel_beta();
double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_forward_;
double minimum = alpha * ret->computation_cost_ + beta * ret->communication_forward_;
MS_LOG(INFO) << "Cost 0: "
<< "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_
<< ", communication_forward_: " << ret->communication_forward_
@ -422,8 +256,7 @@ CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list,
<< ", communication_cost_: " << after_mem_filter[i]->communication_cost_
<< ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_
<< ".";
auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ +
costmodel_beta_ * after_mem_filter[i]->communication_forward_;
auto tmp = alpha * after_mem_filter[i]->computation_cost_ + beta * after_mem_filter[i]->communication_forward_;
MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp;
if (minimum > tmp) {
minimum = tmp;
@ -458,8 +291,10 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, d
}
// Init the returned value with first cost.
CostPtr ret = after_mem_filter[0];
const auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
const auto beta = CostModelContext::GetInstance()->costmodel_beta();
double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_with_partial_para_;
double minimum = alpha * ret->computation_cost_ + beta * ret->communication_with_partial_para_;
MS_LOG(INFO) << "Cost 0: "
<< "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_
<< ", communication_with_partial_para_: " << ret->communication_with_partial_para_
@ -474,8 +309,8 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, d
<< ", communication_cost_: " << after_mem_filter[i]->communication_cost_
<< ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_
<< ".";
auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ +
costmodel_beta_ * after_mem_filter[i]->communication_with_partial_para_;
auto tmp =
alpha * after_mem_filter[i]->computation_cost_ + beta * after_mem_filter[i]->communication_with_partial_para_;
MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp;
if (minimum > tmp) {
minimum = tmp;
@ -513,14 +348,16 @@ CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vect
}
std::function<void(size_t)> recursive = [&all_cost_list, &selected_cost_list, &minimum, &ret, &recursive,
&available_memory, this](size_t k) {
&available_memory](size_t k) {
const auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
const auto beta = CostModelContext::GetInstance()->costmodel_beta();
if (k == all_cost_list.size()) {
double tmp_memory = 0.0, tmp_minimum = 0.0;
for (size_t i = 0; i < selected_cost_list.size(); ++i) {
MS_EXCEPTION_IF_NULL(selected_cost_list[i]);
tmp_memory += selected_cost_list[i]->memory_with_reuse_;
tmp_minimum += costmodel_alpha_ * selected_cost_list[i]->computation_cost_ +
costmodel_beta_ * selected_cost_list[i]->communication_with_partial_para_;
tmp_minimum += alpha * selected_cost_list[i]->computation_cost_ +
beta * selected_cost_list[i]->communication_with_partial_para_;
}
MS_LOG(INFO) << "tmp_memory: " << tmp_memory << ", tmp_minimum: " << tmp_minimum << ", minimum: " << minimum
<< ".";
@ -582,12 +419,12 @@ Status CostGraph::SearchStrategyForMultiNodeFinalGraph(const std::vector<Operato
<< " operators in a component in the final graph.";
}
}
//
auto selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_);
const auto device_mem_capacity = CostModelContext::GetInstance()->device_memory_capacity();
auto selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, device_mem_capacity);
for (size_t k = 0; k < selected_cost_list.size(); ++k) {
auto selected_cost = selected_cost_list[k];
if (selected_cost == nullptr) {
MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity << ".";
return FAILED;
}
MS_EXCEPTION_IF_NULL(connected_components[k]);
@ -627,6 +464,99 @@ Status CostGraph::SearchStrategyForMultiNodeFinalGraph(const std::vector<Operato
return SUCCESS;
}
Status CostGraph::SearchStrategyForTwoNodeFinalGraph(const std::vector<OperatorInfoPtr> &alive_ops) {
// In this case, the final graph should contains exactly 2 nodes.
if (alive_ops.empty()) {
MS_LOG(INFO) << "0 Operator in the final graph.";
return SUCCESS;
}
OperatorInfoPtr u, v;
MS_EXCEPTION_IF_NULL(alive_ops[0]);
MS_EXCEPTION_IF_NULL(alive_ops[1]);
const auto phase = CostModelContext::GetInstance()->run_phase();
const auto device_mem_capacity = CostModelContext::GetInstance()->device_memory_capacity();
if (!alive_ops[0]->GetAliveSuccEdges().empty() &&
alive_ops[0]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[1].get()) {
u = alive_ops[0];
v = alive_ops[1];
} else if (!alive_ops[1]->GetAliveSuccEdges().empty() &&
alive_ops[1]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[0].get()) {
u = alive_ops[1];
v = alive_ops[0];
} else {
if (!alive_ops[0]->GetAliveSuccEdges().empty() || !alive_ops[1]->GetAliveSuccEdges().empty()) {
MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << alive_ops[0]->GetAliveSuccEdges().size()
<< ", " << alive_ops[1]->GetAliveSuccEdges().size() << ".";
} else {
// In this case, the final graph consists of two single nodes
MS_LOG(INFO) << "There are 2 single nodes in the final graph.";
std::vector<CostPtrList> all_list;
auto connected_components = ConstructConnectedComponents(alive_ops);
MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph.";
for (size_t i = 0; i < connected_components.size(); ++i) {
MS_LOG(INFO) << "There are 1 operator in a component in the final graph.";
auto one_component = connected_components[i];
MS_EXCEPTION_IF_NULL(one_component);
auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]);
all_list.push_back(cost_list);
}
CostPtrList selected_cost_list;
if (phase == TRAINING_PHASE) {
// training phase
selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, device_mem_capacity);
} else {
// inference phase
MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-separated-node final graph in the inference "
"phase is not supported.";
}
for (size_t k = 0; k < selected_cost_list.size(); ++k) {
auto selected_cost = selected_cost_list[k];
if (selected_cost == nullptr) {
MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity
<< ".";
return FAILED;
}
MS_EXCEPTION_IF_NULL(connected_components[k]);
auto one_operator = connected_components[k]->GetOperators()[0];
MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_);
auto decision = selected_cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
MS_EXCEPTION_IF_NULL(decision);
one_operator->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_);
MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended.";
}
return SUCCESS;
}
}
MS_LOG(INFO) << "There are 2 nodes in the final graph.";
// In this case, the finale graph is exactly of the form: u --> v
MS_EXCEPTION_IF_NULL(u);
MS_EXCEPTION_IF_NULL(v);
auto e = u->GetAliveSuccEdges()[0];
MS_EXCEPTION_IF_NULL(e);
auto cost_list = CreateFinalCostList(u, e, v);
CostPtr cost = nullptr;
if (phase == TRAINING_PHASE) {
// training phase
cost = SelectCostWithMinTrainingTime(cost_list, device_mem_capacity);
} else {
MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-connected-node final graph in the inference "
"phase is not supported.";
}
if (cost == nullptr) {
MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity << ".";
return FAILED;
}
MS_EXCEPTION_IF_NULL(cost->decision_ptr_);
auto decision = cost->decision_ptr_->cast<FinalDecisionPtr>();
MS_EXCEPTION_IF_NULL(decision);
u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->left_cost_);
v->SetSelectedStrategyAndCost(decision->v_strategy_, decision->right_cost_);
e->set_selected_cost(decision->middle_cost_);
MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended.";
return SUCCESS;
}
// searching the strategy for the final eliminated graph
Status CostGraph::SearchStrategy() {
MS_LOG(INFO) << "Searching the strategy for the eliminated final graph began.";
@ -637,9 +567,11 @@ Status CostGraph::SearchStrategy() {
alive_ops.push_back(op);
}
});
const auto phase = CostModelContext::GetInstance()->run_phase();
const auto device_mem_capacity = CostModelContext::GetInstance()->device_memory_capacity();
if (alive_ops.size() > 2) {
if (RUN_PHASE == TRAINING_PHASE) {
if (phase == TRAINING_PHASE) {
// training phase
return SearchStrategyForMultiNodeFinalGraph(alive_ops);
} else {
@ -652,15 +584,15 @@ Status CostGraph::SearchStrategy() {
OperatorInfoPtr u = alive_ops[0];
auto cost_list = CreateFinalSingleCostList(u);
CostPtr cost = nullptr;
if (RUN_PHASE == TRAINING_PHASE) {
if (phase == TRAINING_PHASE) {
// training phase
cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
cost = SelectCostWithMinTrainingTime(cost_list, device_mem_capacity);
} else {
// inference phase
cost = SelectCostWithMinInferenceTime(cost_list, dev_memory_);
cost = SelectCostWithMinInferenceTime(cost_list, device_mem_capacity);
}
if (cost == nullptr) {
MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity << ".";
return FAILED;
}
MS_EXCEPTION_IF_NULL(u);
@ -671,93 +603,7 @@ Status CostGraph::SearchStrategy() {
MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended.";
return SUCCESS;
} else {
// In this case, the final graph should contains exactly 2 nodes.
if (alive_ops.empty()) {
MS_LOG(INFO) << "0 Operator in the final graph.";
return SUCCESS;
}
OperatorInfoPtr u, v;
MS_EXCEPTION_IF_NULL(alive_ops[0]);
MS_EXCEPTION_IF_NULL(alive_ops[1]);
if (!alive_ops[0]->GetAliveSuccEdges().empty() &&
alive_ops[0]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[1].get()) {
u = alive_ops[0];
v = alive_ops[1];
} else if (!alive_ops[1]->GetAliveSuccEdges().empty() &&
alive_ops[1]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[0].get()) {
u = alive_ops[1];
v = alive_ops[0];
} else {
if (!alive_ops[0]->GetAliveSuccEdges().empty() || !alive_ops[1]->GetAliveSuccEdges().empty()) {
MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << alive_ops[0]->GetAliveSuccEdges().size()
<< ", " << alive_ops[1]->GetAliveSuccEdges().size() << ".";
} else {
// In this case, the final graph consists of two single nodes
MS_LOG(INFO) << "There are 2 single nodes in the final graph.";
std::vector<CostPtrList> all_list;
auto connected_components = ConstructConnectedComponents(alive_ops);
MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph.";
for (size_t i = 0; i < connected_components.size(); ++i) {
MS_LOG(INFO) << "There are 1 operator in a component in the final graph.";
auto one_component = connected_components[i];
MS_EXCEPTION_IF_NULL(one_component);
auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]);
all_list.push_back(cost_list);
}
CostPtrList selected_cost_list;
if (RUN_PHASE == TRAINING_PHASE) {
// training phase
selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_);
} else {
// inference phase
MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-separated-node final graph in the inference "
"phase is not supported.";
}
for (size_t k = 0; k < selected_cost_list.size(); ++k) {
auto selected_cost = selected_cost_list[k];
if (selected_cost == nullptr) {
MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
return FAILED;
}
MS_EXCEPTION_IF_NULL(connected_components[k]);
auto one_operator = connected_components[k]->GetOperators()[0];
MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_);
auto decision = selected_cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
MS_EXCEPTION_IF_NULL(decision);
one_operator->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_);
MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended.";
}
return SUCCESS;
}
}
MS_LOG(INFO) << "There are 2 nodes in the final graph.";
// In this case, the finale graph is exactly of the form: u --> v
MS_EXCEPTION_IF_NULL(u);
MS_EXCEPTION_IF_NULL(v);
auto e = u->GetAliveSuccEdges()[0];
MS_EXCEPTION_IF_NULL(e);
auto cost_list = CreateFinalCostList(u, e, v);
CostPtr cost = nullptr;
if (RUN_PHASE == TRAINING_PHASE) {
// training phase
cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
} else {
MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-connected-node final graph in the inference "
"phase is not supported.";
}
if (cost == nullptr) {
MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
return FAILED;
}
MS_EXCEPTION_IF_NULL(cost->decision_ptr_);
auto decision = cost->decision_ptr_->cast<FinalDecisionPtr>();
MS_EXCEPTION_IF_NULL(decision);
u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->left_cost_);
v->SetSelectedStrategyAndCost(decision->v_strategy_, decision->right_cost_);
e->set_selected_cost(decision->middle_cost_);
MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended.";
return SUCCESS;
return SearchStrategyForTwoNodeFinalGraph(alive_ops);
}
}
@ -867,10 +713,11 @@ void CostGraph::CreateSourceEliminationSubCostList(StrategyPtr op1_old_stra, con
op1_cost->communication_without_parameter_ + op2_cost->communication_without_parameter_;
auto decision = std::make_shared<SourceEliminationDecision>(op1_old_stra, op1_cost, op2_old_stra, op2_cost);
auto new_cost = std::make_shared<Cost>(computation, communication, decision);
const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
MS_EXCEPTION_IF_NULL(new_cost);
new_cost->communication_without_parameter_ = communication_without_para;
new_cost->communication_with_partial_para_ =
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
communication_without_para + gamma * (communication - communication_without_para);
new_cost->memory_with_reuse_ = memory;
new_cost->communication_forward_ = communication_forward;
MS_EXCEPTION_IF_NULL(op1_new_clist);
@ -879,6 +726,65 @@ void CostGraph::CreateSourceEliminationSubCostList(StrategyPtr op1_old_stra, con
}
}
std::pair<std::vector<EdgePtr>, std::vector<EdgePtr>> UpdateEdgesIncidentToNodes(
OperatorInfoPtr op1, std::vector<EdgePtr> *op1_old_succ_edges,
std::vector<std::map<CostPtrKey, CostPtrList>> *op1_new_edges_cost, std::vector<EdgePtr> *op1_new_succ_edges,
OperatorInfoPtr op2, std::vector<EdgePtr> *op2_old_succ_edges,
std::vector<std::map<CostPtrKey, CostPtrList>> *op2_new_edges_cost, std::vector<EdgePtr> *op2_new_succ_edges) {
for (size_t i = 0; i < op1_old_succ_edges->size(); ++i) {
auto &new_cost_map = op1_new_edges_cost->at(i);
auto ith_edge = op1_old_succ_edges->at(i);
std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + ith_edge->next_operator()->name();
std::shared_ptr<Edge> new_edge;
if (ith_edge->is_combined()) {
std::vector<size_t> output_indexs, input_indexs;
output_indexs = ith_edge->prev_op_output_indexs();
input_indexs = ith_edge->next_op_input_indexs();
new_edge =
std::make_shared<Edge>(new_edge_name, op1, ith_edge->next_operator(), output_indexs, input_indexs, true);
} else {
size_t output_index, input_index;
output_index = ith_edge->prev_op_output_index();
input_index = ith_edge->next_op_input_index();
new_edge =
std::make_shared<Edge>(new_edge_name, op1, ith_edge->next_operator(), output_index, input_index, false);
}
new_edge->SetCostMapAndInputOutput(new_cost_map);
// replace the old successive edges with the new ones.
op1->ReplaceSuccEdge(ith_edge->next_operator(), new_edge);
ith_edge->next_operator()->ReplacePreEdge(op1, new_edge);
op1_new_succ_edges->erase(op1_new_succ_edges->begin() + i);
op1_new_succ_edges->emplace(op1_new_succ_edges->begin() + i, new_edge);
}
for (size_t i = 0; i < op2_old_succ_edges->size(); ++i) {
auto &new_cost_map = op2_new_edges_cost->at(i);
auto ith_edge = op2_old_succ_edges->at(i);
const auto &destination = ith_edge->next_operator();
std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + destination->name();
std::shared_ptr<Edge> new_edge;
if (ith_edge->is_combined()) {
std::vector<size_t> output_indexs, input_indexs;
output_indexs = ith_edge->prev_op_output_indexs();
input_indexs = ith_edge->next_op_input_indexs();
new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_indexs, input_indexs, true);
} else {
size_t output_index, input_index;
output_index = ith_edge->prev_op_output_index();
input_index = ith_edge->next_op_input_index();
new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_index, input_index, false);
}
new_edge->SetCostMapAndInputOutput(new_cost_map);
// replace the old successive edges with the new ones.
destination->ReplacePreEdge(op2, new_edge);
op1->AddSuccEdge(new_edge);
op2_new_succ_edges->erase(op2_new_succ_edges->begin() + i);
op2_new_succ_edges->emplace(op2_new_succ_edges->begin() + i, new_edge);
}
return std::make_pair(*op1_new_succ_edges, *op2_new_succ_edges);
}
std::pair<std::vector<std::shared_ptr<Edge>>, std::vector<std::shared_ptr<Edge>>> CostGraph::EliminationSources(
OperatorInfoPtr op1, OperatorInfoPtr op2) {
MS_EXCEPTION_IF_NULL(op1);
@ -970,57 +876,9 @@ std::pair<std::vector<std::shared_ptr<Edge>>, std::vector<std::shared_ptr<Edge>>
op2->SetNotAlive();
// Update the edges incident to op1, and edges incident to op2
for (size_t i = 0; i < op1_old_succ_edges.size(); ++i) {
auto &new_cost_map = op1_new_edges_cost[i];
auto &ith_edge = op1_old_succ_edges[i];
std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + ith_edge->next_operator()->name();
std::shared_ptr<Edge> new_edge;
if (ith_edge->is_combined()) {
std::vector<size_t> output_indexs, input_indexs;
output_indexs = ith_edge->prev_op_output_indexs();
input_indexs = ith_edge->next_op_input_indexs();
new_edge =
std::make_shared<Edge>(new_edge_name, op1, ith_edge->next_operator(), output_indexs, input_indexs, true);
} else {
size_t output_index, input_index;
output_index = ith_edge->prev_op_output_index();
input_index = ith_edge->next_op_input_index();
new_edge =
std::make_shared<Edge>(new_edge_name, op1, ith_edge->next_operator(), output_index, input_index, false);
}
new_edge->SetCostMapAndInputOutput(new_cost_map);
// replace the old successive edges with the new ones.
op1->ReplaceSuccEdge(ith_edge->next_operator(), new_edge);
ith_edge->next_operator()->ReplacePreEdge(op1, new_edge);
op1_new_succ_edges[i] = new_edge;
}
for (size_t i = 0; i < op2_old_succ_edges.size(); ++i) {
auto &new_cost_map = op2_new_edges_cost[i];
auto &ith_edge = op2_old_succ_edges[i];
const auto &destination = ith_edge->next_operator();
std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + destination->name();
std::shared_ptr<Edge> new_edge;
if (ith_edge->is_combined()) {
std::vector<size_t> output_indexs, input_indexs;
output_indexs = ith_edge->prev_op_output_indexs();
input_indexs = ith_edge->next_op_input_indexs();
new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_indexs, input_indexs, true);
} else {
size_t output_index, input_index;
output_index = ith_edge->prev_op_output_index();
input_index = ith_edge->next_op_input_index();
new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_index, input_index, false);
}
new_edge->SetCostMapAndInputOutput(new_cost_map);
// replace the old successive edges with the new ones.
destination->ReplacePreEdge(op2, new_edge);
op1->AddSuccEdge(new_edge);
op2_new_succ_edges[i] = new_edge;
}
MS_LOG(INFO) << "Source eliminating node: " << op2->name() << " to node: " << op1->name() + " succeeded.";
return {op1_new_succ_edges, op2_new_succ_edges};
return UpdateEdgesIncidentToNodes(op1, &op1_old_succ_edges, &op1_new_edges_cost, &op1_new_succ_edges, op2,
&op2_old_succ_edges, &op2_new_edges_cost, &op2_new_succ_edges);
}
// Check the graph whether a TriangleElimination can be performed
@ -1179,10 +1037,11 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const
auto decision =
std::make_shared<MergeEliminationDecision>(op_strategy, op_cost, edge_cost, tar_op_strategy, tar_cost);
auto new_cost = std::make_shared<Cost>(computation, communication, decision);
const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
MS_EXCEPTION_IF_NULL(new_cost);
new_cost->communication_without_parameter_ = communication_without_para;
new_cost->communication_with_partial_para_ =
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
communication_without_para + gamma * (communication - communication_without_para);
new_cost->memory_with_reuse_ = memory;
new_cost->communication_forward_ = communication_forward;
MS_EXCEPTION_IF_NULL(tar_cost_list_new);
@ -1263,9 +1122,10 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str
auto decision = std::make_shared<ContractEliminationDecision>(contract_op_stra, contract_op_cost, edge_cost,
target_op_stra, tar_cost);
auto new_cost = std::make_shared<Cost>(computation, communication, decision);
auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
new_cost->communication_without_parameter_ = communication_without_para;
new_cost->communication_with_partial_para_ =
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
communication_without_para + gamma * (communication - communication_without_para);
new_cost->memory_with_reuse_ = memory;
new_cost->communication_forward_ = communication_forward;
tar_cost_list_new->emplace_back(std::move(new_cost));
@ -1338,8 +1198,10 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra,
double new_commu_without =
elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ +
left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_;
const auto triangle_star_stra_overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite();
const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
if (TRIANGLE_STAR_STRATEGY_OVERWRITE) {
if (triangle_star_stra_overwrite) {
new_computation += right_op_cost->computation_cost_;
new_memory += right_op_cost->memory_with_reuse_;
new_commu_cost += right_op_cost->communication_cost_;
@ -1352,8 +1214,7 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra,
left_op_stra, left_node_cost, right_op_stra, right_op_cost);
auto new_cost = std::make_shared<Cost>(new_computation, new_commu_cost, decision);
new_cost->communication_without_parameter_ = new_commu_without;
new_cost->communication_with_partial_para_ =
new_commu_without + COST_MODEL_GAMMA * (new_commu_cost - new_commu_without);
new_cost->communication_with_partial_para_ = new_commu_without + gamma * (new_commu_cost - new_commu_without);
new_cost->memory_with_reuse_ = new_memory;
new_cost->communication_forward_ = new_commu_forward;
left_node_clist_new->emplace_back(std::move(new_cost));
@ -1463,6 +1324,8 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n
memory_cost = merged_node_cost->memory_with_reuse_, commu_cost = merged_node_cost->communication_cost_,
commu_without = merged_node_cost->communication_without_parameter_,
commu_forward = merged_node_cost->communication_forward_;
const auto triangle_star_stra_overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite();
const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
for (size_t i = 0; i < succ_nodes_stras.size(); ++i) {
MS_EXCEPTION_IF_NULL(succ_edges_costs[i]);
if (i == 0) {
@ -1478,7 +1341,7 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n
commu_cost += succ_edges_costs[i]->communication_cost_;
commu_forward += succ_edges_costs[i]->communication_forward_;
commu_without += succ_edges_costs[i]->communication_without_parameter_;
if (TRIANGLE_STAR_STRATEGY_OVERWRITE) {
if (triangle_star_stra_overwrite) {
computation_cost += succ_nodes_costs[i]->computation_cost_;
memory_cost += succ_nodes_costs[i]->memory_with_reuse_;
commu_cost += succ_nodes_costs[i]->communication_cost_;
@ -1492,7 +1355,7 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n
succ_nodes_stras, succ_nodes_costs);
auto new_cost = std::make_shared<Cost>(computation_cost, commu_cost, decision);
new_cost->communication_without_parameter_ = commu_without;
new_cost->communication_with_partial_para_ = commu_without + COST_MODEL_GAMMA * (commu_cost - commu_without);
new_cost->communication_with_partial_para_ = commu_without + gamma * (commu_cost - commu_without);
new_cost->memory_with_reuse_ = memory_cost;
new_cost->communication_forward_ = commu_forward;
first_succ_node_clist_new->emplace_back(std::move(new_cost));
@ -1895,7 +1758,8 @@ Status CostGraph::CorrectOpsMemoryCost() {
}
Status CostGraph::CalculateMemoryCost() {
if (RUN_PHASE == TRAINING_PHASE) {
const auto phase = CostModelContext::GetInstance()->run_phase();
if (phase == TRAINING_PHASE) {
// training phase
if (ComputeOpsAndEdgesParameterInvolved() == SUCCESS) {
// Calculate operators' memory usage

View File

@ -34,32 +34,12 @@ class CostGraph;
using CostGraphPtr = std::shared_ptr<CostGraph>;
extern CostGraphPtr entire_costgraph;
extern size_t TOTAL_OPS;
extern double COST_MODEL_GAMMA;
extern bool COST_MODEL_SIMPLIFY_CALCULATION;
extern double DEVICE_MEMORY_CAPACITY;
extern double COST_MODEL_COMMUNI_THRESHOLD;
extern double COST_MODEL_COMMUNI_CONST;
extern double COST_MODEL_COMMUNI_BIAS;
extern bool TENSOR_SLICE_ALIGNMENT_ENABLE;
extern size_t TENSOR_SLICE_ALIGNMENT_SIZE;
extern bool FULLY_USE_DEVICES;
extern bool ELEMENTWISE_OP_STRA_FOLLOW;
extern bool MULTI_SUBGRAPHS;
extern bool DP_ALGO_ENABLE_APPROX;
extern double DP_ALGO_APPROX_EPSILON;
extern int64_t RUN_PHASE;
extern bool TRIANGLE_STAR_STRATEGY_OVERWRITE;
extern bool DP_ALGO_SINGLE_LOOP;
class CostGraph {
// 'CostGraph' consists of Operators and edges between them. An edge is created between two Operators if they have
// output-input dependency relationship.
public:
CostGraph() {
dev_memory_ = DEFAULT_DEVICE_MEMORY_CAPACITY;
costmodel_alpha_ = DEFAULT_COST_MODEL_ALPHA;
costmodel_beta_ = DEFAULT_COST_MODEL_BETA_ASCEND;
}
CostGraph() {}
~CostGraph() = default;
void Init();
void AddOperator(const OperatorInfoPtr &op) { ops_.push_back(op); }
@ -79,8 +59,6 @@ class CostGraph {
// An edge is uniquely identified by its name, and its output index and input index.
bool IsEdgeInCostGraph(const std::string &, size_t, size_t);
void SetDeviceMemoryAndCostParameter();
std::vector<std::shared_ptr<CostGraph>> ConstructConnectedComponents(std::vector<OperatorInfoPtr>);
void DFS(const OperatorInfoPtr &current_op, std::map<OperatorInfoPtr, bool> *visited,
const std::shared_ptr<CostGraph> &component);
@ -91,10 +69,10 @@ class CostGraph {
CostPtr SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory);
CostPtrList SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList> &all_costlist, double memory);
Status SearchStrategyForMultiNodeFinalGraph(const std::vector<OperatorInfoPtr> &);
Status SearchStrategyForTwoNodeFinalGraph(const std::vector<OperatorInfoPtr> &);
std::vector<std::shared_ptr<Edge>> GetOriginalEdgeBetweenOperators(OperatorInfoPtr u_node, OperatorInfoPtr v_node) {
return edges_[{u_node, v_node}];
}
double GetDeviceMemory() const { return dev_memory_; }
// Search the cost_list in the final graph, and determine the optimal one
Status SearchStrategy();
@ -194,7 +172,7 @@ class CostGraph {
Status InitReshapeStrategy();
Status InitSelectedStrategy();
OperatorInfoPtr FindTmpIdentityByParameterName(std::string &) const;
// When TmpIdentity is used by mulitple operators, the corresponding parameter's memory cost should be calculated only
// When TmpIdentity is used by multiple operators, the corresponding parameter's memory cost should be calculated only
// once (instead of multiple times), this method is used to correct this.
Status CorrectOpsMemoryCost();
// When APPROXIMATION is enabled in the DP algorithm, some edges may have no valid strategies.
@ -224,9 +202,6 @@ class CostGraph {
// Needed by rec_parser
std::vector<std::vector<std::string>> inputs_tensor_name_list_;
std::map<std::string, std::string> tuple_getitem_list_;
double dev_memory_;
double costmodel_alpha_;
double costmodel_beta_;
std::vector<OperatorInfoPtr> ops_;
std::map<std::pair<OperatorInfoPtr, OperatorInfoPtr>, std::vector<EdgePtr>> edges_;
std::vector<std::shared_ptr<CostGraph>> connected_compoents_;

View File

@ -47,7 +47,7 @@ void CostModelContext::ResetCostModel() {
costmodel_communi_const_ = DEFAULT_COST_MODEL_COMMUNI_CONST;
costmodel_communi_bias_ = DEFAULT_COST_MODEL_COMMUNI_BIAS;
is_multi_subgraphs_ = DEFAULT_IS_MULTI_SUBGRAPHS;
run_phase_ = DEFAULT_RUN_PHASE;
run_phase_ = TRAINING_PHASE;
costmodel_allreduce_fusion_algorithm_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM;
costmodel_allreduce_fusion_times_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES;
costmodel_allreduce_fusion_tail_percent_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT;
@ -70,37 +70,115 @@ void CostModelContext::ResetAlgoParameters() {
dp_algo_approxi_epsilon_ = DEFAULT_DP_ALGO_APPROX_EPSILON;
}
void CostModelContext::PrintCostModel() {
MS_LOG(INFO) << "device_memory_capacity: " << device_memory_capacity_ << ".";
MS_LOG(INFO) << "costmodel_alpha: " << costmodel_alpha_ << ".";
MS_LOG(INFO) << "costmodel_beta: " << costmodel_beta_ << ".";
MS_LOG(INFO) << "costmodel_gamma: " << costmodel_gamma_ << ".";
MS_LOG(INFO) << "costmodel_simplify_cal: " << costmodel_simplify_cal_ << ".";
MS_LOG(INFO) << "costmodel_communi_threshold: " << costmodel_communi_threshold_ << ".";
MS_LOG(INFO) << "costmodel_communi_const: " << costmodel_communi_const_ << ".";
MS_LOG(INFO) << "costmodel_communi_bias: " << costmodel_communi_bias_ << ".";
MS_LOG(INFO) << "is_multi_subgraphs: " << is_multi_subgraphs_ << ".";
MS_LOG(INFO) << "triangle_star_strategy_overwrite: " << triangle_star_strategy_overwrite_ << ".";
MS_LOG(INFO) << "dp_algo_enable_approxi: " << dp_algo_enable_approxi_ << ".";
MS_LOG(INFO) << "dp_algo_approxi_epsilon: " << dp_algo_approxi_epsilon_ << ".";
MS_LOG(INFO) << "dp_algo_single_loop: " << dp_algo_single_loop_ << ".";
MS_LOG(INFO) << "run_phase: " << run_phase_ << ".";
MS_LOG(INFO) << "tensor_slice_alignment_enable: " << tensor_slice_alignment_enable_ << ".";
MS_LOG(INFO) << "tensor_slice_align_size: " << tensor_slice_alignment_size_ << ".";
MS_LOG(INFO) << "fully_use_device: " << fully_use_device_ << ".";
MS_LOG(INFO) << "elementwise_stra_follow: " << elementwise_stra_follow_ << ".";
}
void CostModelContext::set_costmodel_context_for_device(const std::string &device_target) {
if (device_target == kGPUDevice) {
costmodel_beta_ = DEFAULT_COST_MODEL_BETA_GPU;
}
}
void CostModelContext::set_dp_algo_approxi_epsilon(double epsilon) { dp_algo_approxi_epsilon_ = epsilon; }
void CostModelContext::set_dp_algo_approxi_epsilon(double epsilon) {
if (epsilon <= 0 || epsilon > 1) {
MS_LOG(EXCEPTION) << "'epsilon' must be in (0, 1]";
}
dp_algo_approxi_epsilon_ = epsilon;
}
void CostModelContext::set_dp_algo_enable_approxi(bool approxi) { dp_algo_enable_approxi_ = approxi; }
void CostModelContext::set_dp_algo_enable_approxi(bool approxi) {
if (approxi) {
MS_LOG(INFO) << "dp_algo_enable_approx: true.";
} else {
MS_LOG(INFO) << "dp_algo_enable_approx: false.";
}
dp_algo_enable_approxi_ = approxi;
}
void CostModelContext::set_device_memory_capacity(double dm_capacity) { device_memory_capacity_ = dm_capacity; }
void CostModelContext::set_device_memory_capacity(double dm_capacity) {
if (dm_capacity <= 0) {
MS_LOG(EXCEPTION) << "'device_memory_capacity' must be positive.";
}
device_memory_capacity_ = dm_capacity;
}
void CostModelContext::set_costmodel_alpha(double cm_alpha) { costmodel_alpha_ = cm_alpha; }
void CostModelContext::set_costmodel_alpha(double cm_alpha) {
if (cm_alpha <= 0) {
MS_LOG(EXCEPTION) << "'costmodel_alpha' must be positive.";
}
costmodel_alpha_ = cm_alpha;
}
void CostModelContext::set_costmodel_beta(double cm_beta) { costmodel_beta_ = cm_beta; }
void CostModelContext::set_costmodel_beta(double cm_beta) {
if (cm_beta <= 0) {
MS_LOG(EXCEPTION) << "'costmodel_beta' must be positive.";
}
costmodel_beta_ = cm_beta;
}
void CostModelContext::set_costmodel_gamma(double cm_gamma) { costmodel_gamma_ = cm_gamma; }
void CostModelContext::set_costmodel_gamma(double cm_gamma) {
if ((cm_gamma < 0) || (cm_gamma > 1)) {
MS_LOG(EXCEPTION) << "'costmodel_gamma' must in [0, 1].";
}
costmodel_gamma_ = cm_gamma;
}
void CostModelContext::set_costmodel_simplify_cal(bool cm_simplify) { costmodel_simplify_cal_ = cm_simplify; }
void CostModelContext::set_costmodel_simplify_cal(bool cm_simplify) {
if (cm_simplify) {
MS_LOG(INFO) << "costmodel_simplify_cal: true.";
} else {
MS_LOG(INFO) << "costmodel_simplify_cal: false.";
}
costmodel_simplify_cal_ = cm_simplify;
}
void CostModelContext::set_costmodel_communi_threshold(double cm_communi_th) {
if (cm_communi_th < 0) {
MS_LOG(EXCEPTION) << "'costmodel_communi_threshold' must be non-zero.";
}
costmodel_communi_threshold_ = cm_communi_th;
}
void CostModelContext::set_costmodel_communi_const(double cm_communi_const) {
if (cm_communi_const < 0) {
MS_LOG(EXCEPTION) << "'costmodel_communi_const' must be non-zero.";
}
costmodel_communi_const_ = cm_communi_const;
}
void CostModelContext::set_costmodel_communi_bias(double cm_communi_bias) { costmodel_communi_bias_ = cm_communi_bias; }
void CostModelContext::set_costmodel_communi_bias(double cm_communi_bias) {
if (cm_communi_bias < 0) {
MS_LOG(EXCEPTION) << "'costmodel_communi_bias' must be non-zero.";
}
costmodel_communi_bias_ = cm_communi_bias;
}
void CostModelContext::set_multi_subgraphs(bool multi_graphs) { is_multi_subgraphs_ = multi_graphs; }
void CostModelContext::set_multi_subgraphs(bool multi_graphs) {
if (multi_graphs) {
MS_LOG(INFO) << "multi_subgraphs: true.";
} else {
MS_LOG(INFO) << "multi_subgraphs: false.";
}
is_multi_subgraphs_ = multi_graphs;
}
void CostModelContext::set_costmodel_allreduce_fusion_algorithm(int64_t algorithm) {
costmodel_allreduce_fusion_algorithm_ = algorithm;
}
@ -129,25 +207,64 @@ void CostModelContext::set_costmodel_allreduce_fusion_computation_time_parameter
costmodel_allreduce_fusion_computation_time_parameter_ = computation_time_parameter;
}
void CostModelContext::set_tensor_slice_alignment_enable(bool ts_align) { tensor_slice_alignment_enable_ = ts_align; }
void CostModelContext::set_tensor_slice_alignment_enable(bool ts_align) {
if (ts_align) {
MS_LOG(INFO) << "tensor_slice_align_enable: true.";
} else {
MS_LOG(INFO) << "tensor_slice_align_enable: false.";
}
tensor_slice_alignment_enable_ = ts_align;
}
void CostModelContext::set_tensor_slice_alignment_size(size_t ts_align_size) {
if (ts_align_size == 0) {
MS_LOG(EXCEPTION) << "'tensor_slice_align_size' must be positive.";
}
tensor_slice_alignment_size_ = ts_align_size;
}
void CostModelContext::set_fully_use_device(bool fully_use) { fully_use_device_ = fully_use; }
void CostModelContext::set_fully_use_device(bool fully_use) {
if (fully_use) {
MS_LOG(INFO) << "fully_use_devices: true.";
} else {
MS_LOG(INFO) << "fully_use_devices: false.";
}
fully_use_device_ = fully_use;
}
void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) {
if (elementwise_follow) {
MS_LOG(INFO) << "elementwise_op_strategy_follow: true.";
} else {
MS_LOG(INFO) << "elementwise_op_strategy_follow: false.";
}
elementwise_stra_follow_ = elementwise_follow;
}
void CostModelContext::set_triangle_star_strategy_overwrite(bool overwrite) {
if (overwrite) {
MS_LOG(INFO) << "triangle_star_strategy_overwrite: true.";
} else {
MS_LOG(INFO) << "triangle_star_strategy_overwrite: false.";
}
triangle_star_strategy_overwrite_ = overwrite;
}
void CostModelContext::set_run_phase(int64_t phase) { run_phase_ = phase; }
void CostModelContext::set_run_phase(int64_t phase) {
if (phase != 0 && phase != 1) {
MS_LOG(EXCEPTION) << "'run_phase' must be in {0, 1}";
}
run_phase_ = phase;
}
void CostModelContext::set_dp_algo_single_loop(bool single_loop) { dp_algo_single_loop_ = single_loop; }
void CostModelContext::set_dp_algo_single_loop(bool single_loop) {
if (single_loop) {
MS_LOG(INFO) << "dp_algo_single_loop: true.";
} else {
MS_LOG(INFO) << "dp_algo_single_loop: false.";
}
dp_algo_single_loop_ = single_loop;
}
struct CostRegister {
CostRegister() {

View File

@ -41,7 +41,6 @@ namespace parallel {
#define DEFAULT_FULLY_USE_DEVICES true
#define DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW false
#define DEFAULT_IS_MULTI_SUBGRAPHS false
#define DEFAULT_RUN_PHASE 0
#define TRAINING_PHASE 0
#define INFERENCE_PHASE 1
#define DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE true;
@ -58,6 +57,7 @@ class CostModelContext {
void ResetAlgoParameters();
static std::shared_ptr<CostModelContext> GetInstance();
void PrintCostModel();
void set_costmodel_context_for_device(const std::string &);
// DEVICE_MEMORY_CAPACITY

View File

@ -475,7 +475,8 @@ Status MatMulBase::PrepareStrategy(int64_t stage_id, size_t dev_num,
size_t input1_shape_size, mindspore::parallel::StrategyPtr *const sp) {
int64_t product =
std::accumulate(combined_partitions.begin(), combined_partitions.end(), 1, std::multiplies<int64_t>());
if (!FULLY_USE_DEVICES) {
const auto fully_use_device = CostModelContext::GetInstance()->fully_use_device();
if (!fully_use_device) {
if (LongToSize(product) > dev_num) {
return FAILED;
}
@ -564,7 +565,9 @@ void MatMulBase::InitTensorInfoForCost(std::vector<TensorInfo> *relica_inputs_te
}
Status MatMulBase::CheckForTensorSliceValid() const {
if (!TENSOR_SLICE_ALIGNMENT_ENABLE) {
const auto align_enable = CostModelContext::GetInstance()->tensor_slice_alignment_enable();
const auto align_size = CostModelContext::GetInstance()->tensor_slice_alignment_size();
if (!align_enable) {
return SUCCESS;
}
if (inputs_tensor_info_.empty()) {
@ -572,8 +575,8 @@ Status MatMulBase::CheckForTensorSliceValid() const {
}
for (auto &one_input_tensor : inputs_tensor_info_) {
auto slice_shape = one_input_tensor.slice_shape();
if ((LongToSize(slice_shape[LAST_INDEX(slice_shape.size())]) % TENSOR_SLICE_ALIGNMENT_SIZE != 0) ||
(LongToSize(slice_shape[SECOND_FROM_END(slice_shape.size())]) % TENSOR_SLICE_ALIGNMENT_SIZE != 0)) {
if ((LongToSize(slice_shape[LAST_INDEX(slice_shape.size())]) % align_size != 0) ||
(LongToSize(slice_shape[SECOND_FROM_END(slice_shape.size())]) % align_size != 0)) {
return FAILED;
}
}
@ -608,12 +611,12 @@ Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &
double computation_cost =
operator_cost()->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
double communication_cost = operator_cost()->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
result->communication_without_parameter_ =
operator_cost()->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
result->communication_with_partial_para_ =
result->communication_without_parameter_ +
COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
result->communication_without_parameter_ + gamma * (communication_cost - result->communication_without_parameter_);
// Breaking ties for preferring data parallelization
BreakingTiesForPerferringDataParallel(strategy, result);

View File

@ -839,7 +839,8 @@ Status PrepareStrategyBase(int64_t stage_id, size_t dev_num, const Shapes &input
for (auto &input_partition : inputs_partitions) {
product *= std::accumulate(input_partition.begin(), input_partition.end(), 1, std::multiplies<int64_t>());
}
if (!FULLY_USE_DEVICES) {
const auto fully_use_device = CostModelContext::GetInstance()->fully_use_device();
if (!fully_use_device) {
if (LongToSize(product) > dev_num) {
return FAILED;
}
@ -1202,12 +1203,12 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) {
double computation_cost =
operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
result->communication_without_parameter_ =
operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
result->communication_with_partial_para_ =
result->communication_without_parameter_ +
COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
result->communication_without_parameter_ + gamma * (communication_cost - result->communication_without_parameter_);
// Breaking ties for preferring data parallelization
BreakingTiesForPerferringDataParallel(strategy, result);

View File

@ -396,12 +396,12 @@ void ReshapeInfo::SetCostForReshape(const mindspore::parallel::StrategyPtr &stra
double computation_cost =
operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
result->communication_without_parameter_ =
operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
result->communication_with_partial_para_ =
result->communication_without_parameter_ +
COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
result->communication_without_parameter_ + gamma * (communication_cost - result->communication_without_parameter_);
// Breaking ties for preferring data parallelization
BreakingTiesForPerferringDataParallel(strategy, result);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/parallel/parallel_stub/executor_manager_stub.h"
namespace mindspore {
namespace parallel {
@ -26,6 +27,5 @@ std::shared_ptr<Executor> ExecutorManager::GetExecutor(const std::string &device
executors_[device_key] = executor;
return executor;
}
} // namespace parallel
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PARALLEL_EXECUTOR_MANAGER_STUB_H_
#define MINDSPORE_CCSRC_PARALLEL_EXECUTOR_MANAGER_STUB_H_
#include <set>

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PARALLEL_EXECUTOR_STUB_H
#define MINDSPORE_CCSRC_PARALLEL_EXECUTOR_STUB_H

View File

@ -52,13 +52,18 @@ static bool IsInWhiteList(const CNodePtr &cnode) {
return false;
}
static void SetGradTag(const AnfNodePtr &node, const FuncGraphManagerPtr &manager) {
static void SetGradTag(const AnfNodePtr &node, const FuncGraphManagerPtr &manager, size_t curr_depth) {
if (curr_depth > MAX_RECURSIVE_DEPTH) {
MS_LOG(WARNING) << "When setting the tags for Grad nodes, exceeded the maximum recursion depth: "
<< MAX_RECURSIVE_DEPTH;
return;
}
const auto &node_users = manager->node_users()[node];
for (auto &user_pair : node_users) {
auto user_node = user_pair.first;
if (!user_node->grad()) {
user_node->set_grad(true);
SetGradTag(user_node, manager);
SetGradTag(user_node, manager, ++curr_depth);
}
}
}
@ -69,7 +74,7 @@ void PipelineTransformer::LabelRequiredGradCNode() {
if (!ParameterRequireGrad(parameter)) {
continue;
}
SetGradTag(parameter, manager_);
SetGradTag(parameter, manager_, 0);
}
}

View File

@ -234,7 +234,8 @@ void InitCostGraph() {
if (entire_costgraph == nullptr) {
entire_costgraph = std::make_shared<CostGraph>();
}
entire_costgraph->SetDeviceMemoryAndCostParameter();
MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());
CostModelContext::GetInstance()->PrintCostModel();
entire_costgraph->Init();
}
@ -252,10 +253,11 @@ void SetStrategyToOperator(const OperatorInfoPtr &operator_info, const Primitive
if (prim->name() == RESHAPE) {
MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!";
}
const auto fully_use_devices = CostModelContext::GetInstance()->fully_use_device();
// Set cost for this configured strategy
if (operator_info->SetCostUnderStrategy(strategyPtr) != SUCCESS) {
MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " SetCostUnderStrategy failed";
} else if (FULLY_USE_DEVICES) {
} else if (fully_use_devices) {
// If configured to fully use devices, then checking for the user-specified strategy
int64_t used_devices = operator_info->used_devices();
MS_EXCEPTION_IF_NULL(g_device_manager);
@ -323,7 +325,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
operator_info->set_cnode(cnode);
// key of strategy map
std::string strategy_key_name = "";
auto param_names = NodeParameterName(cnode);
auto param_names = NodeParameterName(cnode, -1, 0);
if (!param_names.empty()) {
strategy_key_name = prim->name() + "_" + param_names[0].first;
}
@ -394,7 +396,8 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
if (search_cnode == from_cnode_to_info.end()) {
size_t loop_index = 0;
bool is_in_loop = GetLoopIndexFromCNode(cnode, &loop_index);
if (DP_ALGO_SINGLE_LOOP && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size())) {
const auto single_loop = CostModelContext::GetInstance()->dp_algo_single_loop();
if (single_loop && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size())) {
const auto &current_op_ptr = operators_in_forloop[loop_to_ops[loop_index]];
bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) &&
(current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) &&
@ -430,7 +433,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
<< ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
<< " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
(void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueId(), operator_info));
if (DP_ALGO_SINGLE_LOOP && is_in_loop) {
if (single_loop && is_in_loop) {
operators_in_forloop.push_back(operator_info);
ops_in_a_loop_.insert(operator_info->name());
loop_to_ops[loop_index]++;
@ -511,7 +514,8 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
if (search_cnode == from_cnode_to_info.end()) {
size_t loop_index = 0;
bool is_in_loop = GetLoopIndexFromCNode(cnode, &loop_index);
bool is_op_created = DP_ALGO_SINGLE_LOOP && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size());
const auto single_loop = CostModelContext::GetInstance()->dp_algo_single_loop();
bool is_op_created = single_loop && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size());
if (is_op_created) {
const auto &current_op_ptr = operators_in_forloop[loop_to_ops[loop_index]];
bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) &&
@ -548,7 +552,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
<< ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
<< " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
(void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info));
if (DP_ALGO_SINGLE_LOOP && is_in_loop) {
if (single_loop && is_in_loop) {
operators_in_forloop.push_back(operator_info);
ops_in_a_loop_.insert(operator_info->name());
loop_to_ops[loop_index]++;
@ -581,8 +585,9 @@ void CreateEdgeBetweenTwoOps(const OperatorInfoPtr &prev_op_info, const Operator
MS_LOG(INFO) << "The two operators in two separate for-loops, thus skip the edge.";
return;
}
const auto stra_follow = CostModelContext::GetInstance()->elementwise_stra_follow();
bool follow_strategy = (prim->name() == RESHAPE) || (prev_prim->name() == RESHAPE) ||
(ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name()));
(stra_follow && IsElementWiseOperator(prev_prim->name()));
if (follow_strategy) {
// Redistribution in not allowed on the edge.
// Elementwise operators have the same strategy as their previous operators.
@ -1031,7 +1036,7 @@ Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const
graph = EliminateGraph(graph, eli_list, index_list);
size_t num_device = g_device_manager->DeviceNum();
double device_memory = entire_costgraph->GetDeviceMemory();
const auto device_memory = CostModelContext::GetInstance()->device_memory_capacity();
if (PartitionForAllDevices(num_device, device_memory, graph) == SUCCESS) {
MS_LOG(INFO) << "Partition Success With " << num_device << " devices.";
} else {

View File

@ -1914,7 +1914,11 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
}
// find previous parallel care node's next node.
bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids, vector<size_t> *indexes) {
bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids, vector<size_t> *indexes, size_t curr_depth) {
if (curr_depth > MAX_RECURSIVE_DEPTH) {
MS_LOG(WARNING) << "When find the previous node, exceeded the maximum recursion depth: " << MAX_RECURSIVE_DEPTH;
return false;
}
MS_EXCEPTION_IF_NULL(unique_ids);
MS_EXCEPTION_IF_NULL(indexes);
if (!node->isa<CNode>()) {
@ -1942,7 +1946,7 @@ bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids, vecto
find = true;
continue;
}
if (FindPreNodes(cnode, unique_ids, indexes)) {
if (FindPreNodes(cnode, unique_ids, indexes, ++curr_depth)) {
find = true;
continue;
}
@ -1954,7 +1958,7 @@ void FindLastNodesUniqueId(const FuncGraphPtr &root, std::vector<std::string> *u
std::vector<size_t> *indexes) {
MS_EXCEPTION_IF_NULL(unique_ids);
CNodePtr cnode = root->get_return();
if (!FindPreNodes(cnode, unique_ids, indexes)) {
if (!FindPreNodes(cnode, unique_ids, indexes, 0)) {
MS_LOG(WARNING) << "cannot find the last parallel care node in eval graph";
}
}
@ -2044,7 +2048,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini
// load strategy checkpoint
// key of strategy map
std::string strategy_key_name = "";
auto param_names = NodeParameterName(cnode);
auto param_names = NodeParameterName(cnode, -1, 0);
if (!param_names.empty()) {
strategy_key_name = prim->name() + "_" + param_names[0].first;
}
@ -2151,13 +2155,18 @@ std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &n
return nullptr;
}
std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node) {
std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node, size_t curr_depth) {
if (curr_depth > MAX_RECURSIVE_DEPTH) {
MS_LOG(WARNING) << "When finding the next tensor layout for the parameter, exceeded the maximum recursion depth: "
<< MAX_RECURSIVE_DEPTH;
return nullptr;
}
FuncGraphManagerPtr manager = node->func_graph()->manager();
MS_EXCEPTION_IF_NULL(manager);
AnfNodeIndexSet node_set = manager->node_users()[node];
for (auto &node_pair : node_set) {
if (IsPrimitiveCNode(node_pair.first, prim::kPrimLoad)) {
auto layout_param = FindParameterNextLayout(node_pair.first);
auto layout_param = FindParameterNextLayout(node_pair.first, ++curr_depth);
if (!layout_param) {
continue;
}
@ -2184,7 +2193,7 @@ std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node) {
std::shared_ptr<TensorLayout> CreateParameterLayout(const AnfNodePtr &node) {
// Create DataParallel tensor layout for parameter(support WideDeep).
auto next_layout = FindParameterNextLayout(node);
auto next_layout = FindParameterNextLayout(node, 0);
if (next_layout != nullptr) {
return next_layout;
}
@ -2329,14 +2338,19 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) {
}
}
CNodePtr HandleDependLoss(const CNodePtr &cnode) {
CNodePtr HandleDependLoss(const CNodePtr &cnode, size_t curr_depth) {
if (curr_depth > MAX_RECURSIVE_DEPTH) {
MS_LOG(WARNING) << "When handling the loss node of Depend, exceeded the max recursive depth: "
<< MAX_RECURSIVE_DEPTH;
return nullptr;
}
// Handle return->depend->loss
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_EXCEPTION_IF_NULL(prim);
if (prim->name() == DEPEND) {
auto depend_before = cnode->input(1)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(depend_before);
return HandleDependLoss(depend_before);
return HandleDependLoss(depend_before, ++curr_depth);
}
return cnode;
}
@ -2370,7 +2384,7 @@ LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph) {
pre_cnode = pre_cnode->input(1)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(pre_cnode);
}
pre_cnode = HandleDependLoss(pre_cnode);
pre_cnode = HandleDependLoss(pre_cnode, 0);
auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
// notice: the GetNext op has not input
@ -2792,7 +2806,12 @@ bool IsCohesiveNode(const CNodePtr &cnode) {
IsPrimitiveCNode(cnode, prim::kPrimAllGather) || IsPrimitiveCNode(cnode, prim::kPrimMiniStepAllGather);
}
std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &node, int64_t index) {
std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth) {
if (curr_depth > MAX_RECURSIVE_DEPTH) {
MS_LOG(WARNING) << "When finding the parameters' name of a operator, exceeded the maximum depth: "
<< MAX_RECURSIVE_DEPTH;
return {};
}
std::vector<AnfNodePtr> node_inputs{node->inputs()};
std::vector<std::pair<std::string, int64_t>> param_names;
for (int64_t i = 0; i < UlongToLong(node_inputs.size()); ++i) {
@ -2809,7 +2828,7 @@ std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &n
continue;
}
if (IsCohesiveNode(cnode) && cnode->inputs().size() >= 1) {
auto input_param_names = NodeParameterName(cnode, idx);
auto input_param_names = NodeParameterName(cnode, idx, 0);
param_names.insert(param_names.end(), input_param_names.begin(), input_param_names.end());
}
}
@ -2827,7 +2846,7 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
continue;
}
auto param_names = NodeParameterName(cnode);
auto param_names = NodeParameterName(cnode, -1, 0);
if (param_names.empty()) {
continue;
}
@ -2950,13 +2969,17 @@ void InsertShapeOp(const CNodePtr &node, const AnfNodePtr &pre_node, const FuncG
InsertNode(op, node, 2, pre_node, root, "shape");
}
static AnfNodePtr FindGrad(const CNodePtr &cnode) {
static AnfNodePtr FindGrad(const CNodePtr &cnode, size_t curr_depth) {
if (curr_depth > MAX_RECURSIVE_DEPTH) {
MS_LOG(WARNING) << "When finding Grad nodes, exceeded the maximum recursion depth: " << MAX_RECURSIVE_DEPTH;
return nullptr;
}
for (auto &node : cnode->inputs()) {
if (!node->isa<CNode>()) {
continue;
}
if (!IsPrimitiveCNode(node, prim::kPrimEnvGetItem)) {
return FindGrad(node->cast<CNodePtr>());
return FindGrad(node->cast<CNodePtr>(), ++curr_depth);
} else {
return node;
}
@ -2995,7 +3018,7 @@ void HandleRootReshapeAndSaveStrategy(const std::vector<AnfNodePtr> &all_nodes)
continue;
}
auto root = node->func_graph();
auto grad_node = FindGrad(cnode);
auto grad_node = FindGrad(cnode, 0);
if (grad_node) {
InsertShapeOp(cnode, grad_node, root);
}

View File

@ -139,7 +139,7 @@ bool IsLastStage();
void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
const FuncGraphManagerPtr &manager);
std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &node, int64_t index = -1);
std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth);
void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes);

View File

@ -153,7 +153,6 @@ class TestDPAlgo : public UT::Common {
void TestDPAlgo::SetUp() {
cost_graph = std::make_shared<CostGraph>();
cost_graph->SetDeviceMemoryAndCostParameter();
RankList dev_list;
for (int32_t i = 0; i < 10; i++) {

View File

@ -52,7 +52,6 @@ class TestCostGraph : public UT::Common {
};
void TestCostGraph::SetUp() {
cost_graph.SetDeviceMemoryAndCostParameter();
RankList dev_list;
for (int32_t i = 0; i < 10; i++) {
@ -305,7 +304,6 @@ TEST_F(TestCostGraph, test_ConstructConnectedComponents) {
TEST_F(TestCostGraph, test_SelectCostListWithMinTrainingTimeMultiple) {
CostGraph entire_cost_graph;
entire_cost_graph.SetDeviceMemoryAndCostParameter();
double memory = 1024.0;
CostPtrList clist_1, clist_2;
std::vector<CostPtrList> all_list;
@ -371,7 +369,8 @@ TEST_F(TestCostGraph, test_CreateFinalCostList_AND_Select) {
ASSERT_EQ(edge_m1_m2->InitEdgeCost(), SUCCESS);
cost_graph.AddEdge(matmul1, matmul2, edge_m1_m2);
auto cost_list = cost_graph.CreateFinalCostList(matmul1, edge_m1_m2, matmul2);
cost_graph.SelectCostWithMinInferenceTime(cost_list, cost_graph.GetDeviceMemory());
const auto device_mem_capacity = CostModelContext::GetInstance()->device_memory_capacity();
cost_graph.SelectCostWithMinInferenceTime(cost_list, device_mem_capacity);
}
TEST_F(TestCostGraph, test_EliminationOp) {