forked from mindspore-Ecosystem/mindspore
!15777 [Auto parallel] Fix some codestyle warnings
From: @xiaoda_zh Reviewed-by: @yangzhenzhang,@stsuteng Signed-off-by: @stsuteng
This commit is contained in:
commit
6dd37c5857
|
@ -23,7 +23,8 @@
|
|||
namespace mindspore {
|
||||
namespace parallel {
|
||||
void Simplify(CostPtrList *clist_ptrs) {
|
||||
if (RUN_PHASE == TRAINING_PHASE) {
|
||||
const auto run_phase = CostModelContext::GetInstance()->run_phase();
|
||||
if (run_phase == TRAINING_PHASE) {
|
||||
// training phase
|
||||
SimplifyForDecreasingCommunicationWithPartialPara(clist_ptrs);
|
||||
} else {
|
||||
|
@ -35,7 +36,8 @@ void SimplifyForDecreasingCommunicationForward(CostPtrList *clist_ptrs) {
|
|||
// Sort the cost_list with the computation_cost_ increasing, and communication_forward decreasing order. This method
|
||||
// excludes the cost with greater computation_cost_ and greater communication_forward.
|
||||
// E.g. clist_ptrs = {<100, 20>, <200, 10>, <300, 50>}. After this method, clist_ptrs = {<200, 10>, <100, 20>}
|
||||
if (!COST_MODEL_SIMPLIFY_CALCULATION) {
|
||||
const auto simplify_cal = CostModelContext::GetInstance()->costmodel_simplify_cal();
|
||||
if (!simplify_cal) {
|
||||
return;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(clist_ptrs);
|
||||
|
@ -57,7 +59,8 @@ void SimplifyForDecreasingCommunicationForward(CostPtrList *clist_ptrs) {
|
|||
void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) {
|
||||
// Sort the cost_list with the computation_cost_ increasing, and communication_with_partial_para_cost decreasing
|
||||
// order. This method excludes the cost with greater computation_cost_ and greater communication_without_para_cost.
|
||||
if (!COST_MODEL_SIMPLIFY_CALCULATION) {
|
||||
const auto simplify_cal = CostModelContext::GetInstance()->costmodel_simplify_cal();
|
||||
if (!simplify_cal) {
|
||||
return;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(clist_ptrs);
|
||||
|
@ -78,19 +81,23 @@ void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs)
|
|||
|
||||
void RefineForPracticalCost(const CostPtr &origin_cost, bool is_redistribution) {
|
||||
MS_EXCEPTION_IF_NULL(origin_cost);
|
||||
const auto comm_threshold = CostModelContext::GetInstance()->costmodel_communi_threshold();
|
||||
const auto comm_const = CostModelContext::GetInstance()->costmodel_communi_const();
|
||||
const auto comm_bias = CostModelContext::GetInstance()->costmodel_communi_bias();
|
||||
const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
|
||||
if (is_redistribution) {
|
||||
// Redistribution cost
|
||||
if ((origin_cost->communication_redis_forward_ > EPS) &&
|
||||
(origin_cost->communication_redis_forward_ <= COST_MODEL_COMMUNI_THRESHOLD)) {
|
||||
origin_cost->communication_redis_forward_ = COST_MODEL_COMMUNI_CONST;
|
||||
} else if (origin_cost->communication_redis_forward_ > COST_MODEL_COMMUNI_THRESHOLD) {
|
||||
origin_cost->communication_redis_forward_ += COST_MODEL_COMMUNI_BIAS;
|
||||
(origin_cost->communication_redis_forward_ <= comm_threshold)) {
|
||||
origin_cost->communication_redis_forward_ = comm_const;
|
||||
} else if (origin_cost->communication_redis_forward_ > comm_threshold) {
|
||||
origin_cost->communication_redis_forward_ += comm_bias;
|
||||
}
|
||||
if ((origin_cost->communication_redis_backward_ > EPS) &&
|
||||
(origin_cost->communication_redis_backward_ <= COST_MODEL_COMMUNI_THRESHOLD)) {
|
||||
origin_cost->communication_redis_backward_ = COST_MODEL_COMMUNI_CONST;
|
||||
} else if (origin_cost->communication_redis_backward_ > COST_MODEL_COMMUNI_THRESHOLD) {
|
||||
origin_cost->communication_redis_backward_ += COST_MODEL_COMMUNI_BIAS;
|
||||
(origin_cost->communication_redis_backward_ <= comm_threshold)) {
|
||||
origin_cost->communication_redis_backward_ = comm_const;
|
||||
} else if (origin_cost->communication_redis_backward_ > comm_threshold) {
|
||||
origin_cost->communication_redis_backward_ += comm_bias;
|
||||
}
|
||||
origin_cost->communication_cost_ =
|
||||
origin_cost->communication_redis_forward_ + origin_cost->communication_redis_backward_;
|
||||
|
@ -104,18 +111,17 @@ void RefineForPracticalCost(const CostPtr &origin_cost, bool is_redistribution)
|
|||
}
|
||||
// forward cost
|
||||
if ((origin_cost->communication_without_parameter_ > EPS) &&
|
||||
(origin_cost->communication_without_parameter_ <= COST_MODEL_COMMUNI_THRESHOLD)) {
|
||||
origin_cost->communication_without_parameter_ = COST_MODEL_COMMUNI_CONST;
|
||||
} else if (origin_cost->communication_without_parameter_ > COST_MODEL_COMMUNI_THRESHOLD) {
|
||||
origin_cost->communication_without_parameter_ += COST_MODEL_COMMUNI_BIAS;
|
||||
(origin_cost->communication_without_parameter_ <= comm_threshold)) {
|
||||
origin_cost->communication_without_parameter_ = comm_const;
|
||||
} else if (origin_cost->communication_without_parameter_ > comm_threshold) {
|
||||
origin_cost->communication_without_parameter_ += comm_bias;
|
||||
}
|
||||
// total
|
||||
if (origin_cost->communication_cost_ > EPS) {
|
||||
origin_cost->communication_cost_ = origin_cost->communication_without_parameter_ + backward;
|
||||
}
|
||||
if (origin_cost->communication_with_partial_para_ > EPS) {
|
||||
origin_cost->communication_with_partial_para_ =
|
||||
origin_cost->communication_without_parameter_ + COST_MODEL_GAMMA * backward;
|
||||
origin_cost->communication_with_partial_para_ = origin_cost->communication_without_parameter_ + gamma * backward;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include <vector>
|
||||
#include "frontend/parallel/strategy.h"
|
||||
#include "frontend/parallel/tensor_layout/tensor_info.h"
|
||||
#include "frontend/parallel/costmodel_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
|
@ -44,8 +45,8 @@ using RedistributionOpListPtr = std::shared_ptr<std::pair<OperatorVector, OutPut
|
|||
|
||||
struct Cost {
|
||||
Cost();
|
||||
Cost(double computation, double commuication, const std::shared_ptr<Decision> &decision_ = nullptr)
|
||||
: computation_cost_(computation), communication_cost_(commuication), decision_ptr_(std::move(decision_)) {
|
||||
Cost(double computation, double communication, const std::shared_ptr<Decision> &decision_ = nullptr)
|
||||
: computation_cost_(computation), communication_cost_(communication), decision_ptr_(std::move(decision_)) {
|
||||
memory_with_reuse_ = 0.0;
|
||||
communication_without_parameter_ = 0.0;
|
||||
communication_with_partial_para_ = 0.0;
|
||||
|
@ -273,7 +274,7 @@ struct TriangleEliminationDecision : public Decision {
|
|||
*
|
||||
* v <--- u ---> w ==> v w In the original graph, u has 0 incoming edges, and multiple outgoing edges.
|
||||
* In addition, v and w have other complicated connections, resulting in v and w can not be performed other
|
||||
* eliminations. After the StarElimination, u is merged into v, and the resulting graph is splitted into multiple
|
||||
* eliminations. After the StarElimination, u is merged into v, and the resulting graph is split into multiple
|
||||
* connected components.
|
||||
* NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied.
|
||||
*/
|
||||
|
|
|
@ -132,18 +132,14 @@ Status GetStrategy(const CostGraphPtr &graph) {
|
|||
|
||||
Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
|
||||
std::vector<EliminationPtr>::reverse_iterator rit;
|
||||
|
||||
const auto triangle_star_overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite();
|
||||
for (rit = eliminations.rbegin(); rit != eliminations.rend(); ++rit) {
|
||||
if ((*rit)->isa<OpElimination>()) {
|
||||
auto elimination = (*rit)->cast<OpEliminationPtr>();
|
||||
auto e = elimination->new_edge_;
|
||||
auto w = elimination->op_;
|
||||
MS_EXCEPTION_IF_NULL(e);
|
||||
MS_EXCEPTION_IF_NULL(w);
|
||||
auto left_edge = elimination->left_edge_;
|
||||
auto right_edge = elimination->right_edge_;
|
||||
MS_EXCEPTION_IF_NULL(left_edge);
|
||||
MS_EXCEPTION_IF_NULL(right_edge);
|
||||
auto decision = e->selected_cost()->decision_ptr_->cast<OpEliminationDecisionPtr>();
|
||||
w->SetSelectedStrategyAndCost(decision->op_strategy_, decision->middle_cost_);
|
||||
left_edge->set_selected_cost(decision->left_cost_);
|
||||
|
@ -201,12 +197,12 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
|
|||
right_edge->set_selected_cost(decision->right_edge_cost_);
|
||||
// 'left_node' recovers the strategy.
|
||||
left_node->SetSelectedStrategyAndCost(decision->left_node_strategy_, decision->left_node_cost_);
|
||||
if (TRIANGLE_STAR_STRATEGY_OVERWRITE) {
|
||||
if (triangle_star_overwrite) {
|
||||
// 'right_node' recovers the strategy.
|
||||
MS_LOG(INFO) << "Overwrite the right-node: " << right_node->name() << " in recovering triangle elimination.";
|
||||
right_node->SetSelectedStrategyAndCost(decision->right_node_strategy_, decision->right_node_cost_);
|
||||
} else {
|
||||
// In this case, 'right_node' is not overwriten strategy, and it checks strategy consistency.
|
||||
// In this case, 'right_node' is not overwritten strategy, and it checks strategy consistency.
|
||||
right_node->CheckSelectedStrategy(decision->right_node_strategy_);
|
||||
}
|
||||
MS_LOG(INFO) << "Recover triangleElimination succeeded.";
|
||||
|
@ -215,7 +211,7 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
|
|||
auto merged_node = elimination->eliminated_node_;
|
||||
auto succ_edges = elimination->succ_edges_;
|
||||
auto succ_nodes = elimination->succ_ops_;
|
||||
// decision is hided in succ_nodes[0]
|
||||
// decision is hidden in succ_nodes[0]
|
||||
auto decision = succ_nodes[0]->selected_cost()->decision_ptr_->cast<StarEliminationDecisionPtr>();
|
||||
|
||||
merged_node->SetSelectedStrategyAndCost(decision->eliminated_op_strategy_, decision->eliminated_op_cost_);
|
||||
|
@ -228,7 +224,7 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
|
|||
// Star is eliminated into 'succ_nodes[0]'
|
||||
succ_nodes[0]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[0], decision->succ_ops_cost_list_[0]);
|
||||
for (size_t k = 1; k < succ_nodes.size(); ++k) {
|
||||
if (TRIANGLE_STAR_STRATEGY_OVERWRITE) {
|
||||
if (triangle_star_overwrite) {
|
||||
// 'succ_nodes[k]' is overwritten strategy and cost.
|
||||
succ_nodes[k]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[k], decision->succ_ops_cost_list_[k]);
|
||||
} else {
|
||||
|
|
|
@ -90,11 +90,13 @@ Status Edge::InitEdgeCost() {
|
|||
}
|
||||
}
|
||||
if (!has_available_cost) {
|
||||
if (FULLY_USE_DEVICES) {
|
||||
const auto fully_use = CostModelContext::GetInstance()->fully_use_device();
|
||||
const auto stra_follow = CostModelContext::GetInstance()->elementwise_stra_follow();
|
||||
if (fully_use) {
|
||||
MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_
|
||||
<< " failed, it may be caused by setting 'fully_use_devices' true. Try to set "
|
||||
"'fully_use_devices' false.";
|
||||
} else if (ELEMENTWISE_OP_STRA_FOLLOW) {
|
||||
} else if (stra_follow) {
|
||||
MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_
|
||||
<< " failed, it may be caused by setting 'elementwise_op_strategy_follow' true. "
|
||||
"Try to set 'elementwise_op_strategy_follow' false.";
|
||||
|
@ -130,6 +132,7 @@ Status Edge::GetRedistributionCost(const TensorLayout &prev_op_output_layout, co
|
|||
double backward_comm_cost = tensor_redistribution.backward_comm_cost();
|
||||
double computation_cost = tensor_redistribution.computation_cost();
|
||||
double mem_cost = tensor_redistribution.memory_cost();
|
||||
const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
|
||||
|
||||
// Now AllGather, ReduceScatter, AlltoAll don't support bool type
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
|
@ -142,7 +145,7 @@ Status Edge::GetRedistributionCost(const TensorLayout &prev_op_output_layout, co
|
|||
(*cost)->communication_without_parameter_ = type_length * comm_cost;
|
||||
(*cost)->communication_with_partial_para_ =
|
||||
(*cost)->communication_without_parameter_ +
|
||||
COST_MODEL_GAMMA * ((*cost)->communication_cost_ - (*cost)->communication_without_parameter_);
|
||||
gamma * ((*cost)->communication_cost_ - (*cost)->communication_without_parameter_);
|
||||
(*cost)->communication_redis_forward_ = type_length * forward_comm_cost;
|
||||
(*cost)->communication_redis_backward_ = type_length * backward_comm_cost;
|
||||
(*cost)->memory_with_reuse_ = mem_cost;
|
||||
|
@ -173,13 +176,14 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr
|
|||
std::function<void(size_t, double, double, double, double, double)> recursive =
|
||||
[&](size_t k, double computation, double memory, double communication, double communication_without_para,
|
||||
double communication_forward) {
|
||||
const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
|
||||
if (k == edges.size()) {
|
||||
auto decision = std::make_shared<EdgeEliminationDecision>(selected_cost_list);
|
||||
CostPtr new_cost = std::make_shared<Cost>(computation, communication);
|
||||
MS_EXCEPTION_IF_NULL(new_cost);
|
||||
new_cost->communication_without_parameter_ = communication_without_para;
|
||||
new_cost->communication_with_partial_para_ =
|
||||
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
|
||||
communication_without_para + gamma * (communication - communication_without_para);
|
||||
new_cost->memory_with_reuse_ = memory;
|
||||
new_cost->communication_forward_ = communication_forward;
|
||||
new_cost->decision_ptr_ = decision;
|
||||
|
@ -242,10 +246,11 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr
|
|||
|
||||
auto decision = std::make_shared<OpEliminationDecision>(op_strategy, left_cost, middle_cost, right_cost);
|
||||
auto cost = std::make_shared<Cost>(computation, communication, decision);
|
||||
const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
|
||||
MS_EXCEPTION_IF_NULL(cost);
|
||||
cost->communication_without_parameter_ = communication_without_para;
|
||||
cost->communication_with_partial_para_ =
|
||||
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
|
||||
communication_without_para + gamma * (communication - communication_without_para);
|
||||
cost->memory_with_reuse_ = memory_cost;
|
||||
cost->communication_forward_ = communication_forward;
|
||||
ret_cost_list->emplace_back(std::move(cost));
|
||||
|
|
|
@ -28,175 +28,6 @@ namespace mindspore {
|
|||
namespace parallel {
|
||||
CostGraphPtr entire_costgraph = nullptr;
|
||||
size_t TOTAL_OPS = 0;
|
||||
double COST_MODEL_GAMMA = DEFAULT_COST_MODEL_GAMMA;
|
||||
bool COST_MODEL_SIMPLIFY_CALCULATION = DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION;
|
||||
double DEVICE_MEMORY_CAPACITY = DEFAULT_DEVICE_MEMORY_CAPACITY;
|
||||
double COST_MODEL_COMMUNI_THRESHOLD = DEFAULT_COST_MODEL_COMMUNI_THRESHOLD;
|
||||
double COST_MODEL_COMMUNI_CONST = DEFAULT_COST_MODEL_COMMUNI_CONST;
|
||||
double COST_MODEL_COMMUNI_BIAS = DEFAULT_COST_MODEL_COMMUNI_BIAS;
|
||||
bool TENSOR_SLICE_ALIGNMENT_ENABLE = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE;
|
||||
size_t TENSOR_SLICE_ALIGNMENT_SIZE = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE;
|
||||
bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES;
|
||||
bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
|
||||
bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS;
|
||||
int64_t RUN_PHASE = DEFAULT_RUN_PHASE;
|
||||
bool TRIANGLE_STAR_STRATEGY_OVERWRITE = DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE;
|
||||
bool DP_ALGO_ENABLE_APPROX = DEFAULT_DP_ALGO_ENABLE_APPROX;
|
||||
double DP_ALGO_APPROX_EPSILON = DEFAULT_DP_ALGO_APPROX_EPSILON;
|
||||
bool DP_ALGO_SINGLE_LOOP = DEFAULT_DP_ALGO_SINGLE_LOOP;
|
||||
|
||||
void CostGraph::SetDeviceMemoryAndCostParameter() {
|
||||
MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());
|
||||
|
||||
// DEVICE_MEMORY_CAPACITY
|
||||
auto device_memory = CostModelContext::GetInstance()->device_memory_capacity();
|
||||
if (device_memory <= 0) {
|
||||
MS_LOG(EXCEPTION) << "'device_memory_capacity' must be positive.";
|
||||
}
|
||||
dev_memory_ = device_memory;
|
||||
DEVICE_MEMORY_CAPACITY = device_memory;
|
||||
MS_LOG(INFO) << "device_memory_capacity: " << DEVICE_MEMORY_CAPACITY << ".";
|
||||
|
||||
// COST_MODEL_ALPHA
|
||||
auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
|
||||
if (alpha <= 0) {
|
||||
MS_LOG(EXCEPTION) << "'costmodel_alpha' must be positive.";
|
||||
}
|
||||
costmodel_alpha_ = alpha;
|
||||
MS_LOG(INFO) << "costmodel_alpha: " << costmodel_alpha_ << ".";
|
||||
|
||||
// COST_MODEL_BETA
|
||||
auto beta = CostModelContext::GetInstance()->costmodel_beta();
|
||||
if (beta <= 0) {
|
||||
MS_LOG(EXCEPTION) << "'costmodel_beta' must be positive.";
|
||||
}
|
||||
costmodel_beta_ = beta;
|
||||
MS_LOG(INFO) << "costmodel_beta: " << costmodel_beta_ << ".";
|
||||
|
||||
// COST_MODEL_GAMMA
|
||||
auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
|
||||
if ((gamma < 0) || (gamma > 1)) {
|
||||
MS_LOG(EXCEPTION) << "'costmodel_gamma' must in [0, 1].";
|
||||
}
|
||||
COST_MODEL_GAMMA = gamma;
|
||||
MS_LOG(INFO) << "costmodel_gamma: " << COST_MODEL_GAMMA << ".";
|
||||
|
||||
// COST_MODEL_SIMPLIFY_CALCULATION
|
||||
auto simplify = CostModelContext::GetInstance()->costmodel_simplify_cal();
|
||||
COST_MODEL_SIMPLIFY_CALCULATION = simplify;
|
||||
if (COST_MODEL_SIMPLIFY_CALCULATION) {
|
||||
MS_LOG(INFO) << "costmodel_simplify_cal: true.";
|
||||
} else {
|
||||
MS_LOG(INFO) << "costmodel_simplify_cal: false.";
|
||||
}
|
||||
|
||||
// COST_MODEL_COMMUNI_THRESHOLD
|
||||
auto communi_threshold = CostModelContext::GetInstance()->costmodel_communi_threshold();
|
||||
if (communi_threshold < 0) {
|
||||
MS_LOG(EXCEPTION) << "'costmodel_communi_threshold' must be non-zero.";
|
||||
}
|
||||
COST_MODEL_COMMUNI_THRESHOLD = communi_threshold;
|
||||
MS_LOG(INFO) << "costmodel_communi_threshold: " << COST_MODEL_COMMUNI_THRESHOLD << ".";
|
||||
|
||||
// COST_MODEL_COMMUNI_CONST
|
||||
auto communi_const = CostModelContext::GetInstance()->costmodel_communi_const();
|
||||
if (communi_const < 0) {
|
||||
MS_LOG(EXCEPTION) << "'costmodel_communi_const' must be non-zero.";
|
||||
}
|
||||
COST_MODEL_COMMUNI_CONST = communi_const;
|
||||
MS_LOG(INFO) << "costmodel_communi_const: " << COST_MODEL_COMMUNI_CONST << ".";
|
||||
|
||||
// COST_MODEL_COMMUNI_BIAS
|
||||
auto communi_bias = CostModelContext::GetInstance()->costmodel_communi_bias();
|
||||
if (communi_bias < 0) {
|
||||
MS_LOG(EXCEPTION) << "'costmodel_communi_bias' must be non-zero.";
|
||||
}
|
||||
COST_MODEL_COMMUNI_BIAS = communi_bias;
|
||||
MS_LOG(INFO) << "costmodel_communi_bias: " << COST_MODEL_COMMUNI_BIAS << ".";
|
||||
|
||||
// TENSOR_SLICE_ALIGNMENT_ENABLE
|
||||
auto align_enable = CostModelContext::GetInstance()->tensor_slice_alignment_enable();
|
||||
TENSOR_SLICE_ALIGNMENT_ENABLE = align_enable;
|
||||
if (TENSOR_SLICE_ALIGNMENT_ENABLE) {
|
||||
MS_LOG(INFO) << "tensor_slice_align_enable: true.";
|
||||
} else {
|
||||
MS_LOG(INFO) << "tensor_slice_align_enable: false.";
|
||||
}
|
||||
|
||||
// TENSOR_SLICE_ALIGNMENT_SIZE
|
||||
auto align_size = CostModelContext::GetInstance()->tensor_slice_alignment_size();
|
||||
if (align_size == 0) {
|
||||
MS_LOG(EXCEPTION) << "'tensor_slice_align_size' must be positive.";
|
||||
}
|
||||
TENSOR_SLICE_ALIGNMENT_SIZE = align_size;
|
||||
MS_LOG(INFO) << "tensor_slice_align_size: " << TENSOR_SLICE_ALIGNMENT_SIZE << ".";
|
||||
|
||||
// FULLY_USE_DEVICES
|
||||
auto fully_devices = CostModelContext::GetInstance()->fully_use_device();
|
||||
FULLY_USE_DEVICES = fully_devices;
|
||||
if (FULLY_USE_DEVICES) {
|
||||
MS_LOG(INFO) << "fully_use_devices: true.";
|
||||
} else {
|
||||
MS_LOG(INFO) << "fully_use_devices: false.";
|
||||
}
|
||||
|
||||
// ELEMENTWISE_OP_STRA_FOLLOW
|
||||
auto is_ele_op_follow = CostModelContext::GetInstance()->elementwise_stra_follow();
|
||||
ELEMENTWISE_OP_STRA_FOLLOW = is_ele_op_follow;
|
||||
if (ELEMENTWISE_OP_STRA_FOLLOW) {
|
||||
MS_LOG(INFO) << "elementwise_op_strategy_follow: true.";
|
||||
} else {
|
||||
MS_LOG(INFO) << "elementwise_op_strategy_follow: false.";
|
||||
}
|
||||
|
||||
// MULTI_SUBGRAPHS
|
||||
auto multi_subgraphs = CostModelContext::GetInstance()->is_multi_subgraphs();
|
||||
MULTI_SUBGRAPHS = multi_subgraphs;
|
||||
if (MULTI_SUBGRAPHS) {
|
||||
MS_LOG(INFO) << "multi_subgraphs: true.";
|
||||
} else {
|
||||
MS_LOG(INFO) << "multi_subgraphs: false.";
|
||||
}
|
||||
|
||||
auto overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite();
|
||||
TRIANGLE_STAR_STRATEGY_OVERWRITE = overwrite;
|
||||
if (TRIANGLE_STAR_STRATEGY_OVERWRITE) {
|
||||
MS_LOG(INFO) << "triangle_star_strategy_overwrite: true.";
|
||||
} else {
|
||||
MS_LOG(INFO) << "triangle_star_strategy_overwrite: false.";
|
||||
}
|
||||
|
||||
// RUN_PHASE
|
||||
auto phase = CostModelContext::GetInstance()->run_phase();
|
||||
if (phase != 0 && phase != 1) {
|
||||
MS_LOG(EXCEPTION) << "'run_phase' must be in {0, 1}";
|
||||
}
|
||||
RUN_PHASE = phase;
|
||||
MS_LOG(INFO) << "run_phase: " << RUN_PHASE << ".";
|
||||
|
||||
auto enable_approx = CostModelContext::GetInstance()->dp_algo_enable_approxi();
|
||||
DP_ALGO_ENABLE_APPROX = enable_approx;
|
||||
if (enable_approx) {
|
||||
MS_LOG(INFO) << "dp_algo_enable_approx: true.";
|
||||
} else {
|
||||
MS_LOG(INFO) << "dp_algo_enable_approx: false.";
|
||||
}
|
||||
|
||||
auto epsilon = CostModelContext::GetInstance()->dp_algo_approxi_epsilon();
|
||||
if (epsilon <= 0 || epsilon > 1) {
|
||||
MS_LOG(EXCEPTION) << "'epsilon' must be in (0, 1]";
|
||||
}
|
||||
DP_ALGO_APPROX_EPSILON = epsilon;
|
||||
MS_LOG(INFO) << "epsilon: " << epsilon << ".";
|
||||
|
||||
auto single_loop = CostModelContext::GetInstance()->dp_algo_single_loop();
|
||||
DP_ALGO_SINGLE_LOOP = single_loop;
|
||||
if (single_loop) {
|
||||
MS_LOG(INFO) << "dp_algo_single_loop: true.";
|
||||
} else {
|
||||
MS_LOG(INFO) << "dp_algo_single_loop: false.";
|
||||
}
|
||||
}
|
||||
|
||||
void CostGraph::Init() {
|
||||
inputs_tensor_name_list_.clear();
|
||||
|
@ -269,7 +100,6 @@ std::vector<std::shared_ptr<CostGraph>> CostGraph::ConstructConnectedComponents(
|
|||
if ((!visited[op]) && op->is_alive()) {
|
||||
std::shared_ptr<CostGraph> new_component = std::make_shared<CostGraph>();
|
||||
MS_EXCEPTION_IF_NULL(new_component);
|
||||
new_component->SetDeviceMemoryAndCostParameter();
|
||||
DFS(op, &visited, new_component);
|
||||
connected_compoents_.push_back(new_component);
|
||||
}
|
||||
|
@ -336,10 +166,11 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std::
|
|||
auto decision =
|
||||
std::make_shared<FinalDecision>(u_strategy->strategy_ptr, v_strategy->strategy_ptr, cost1, cost2, cost3);
|
||||
auto cost = std::make_shared<Cost>(computation, communication, decision);
|
||||
const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
|
||||
MS_EXCEPTION_IF_NULL(cost);
|
||||
cost->communication_without_parameter_ = communication_without_para;
|
||||
cost->communication_with_partial_para_ =
|
||||
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
|
||||
communication_without_para + gamma * (communication - communication_without_para);
|
||||
cost->memory_with_reuse_ = memory;
|
||||
cost->communication_forward_ = communication_forward;
|
||||
ret.push_back(cost);
|
||||
|
@ -353,7 +184,7 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std::
|
|||
return ret;
|
||||
}
|
||||
|
||||
// Create final cost list for the graph containing a signle node: u
|
||||
// Create final cost list for the graph containing a single node: u
|
||||
CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr &u) {
|
||||
MS_EXCEPTION_IF_NULL(u);
|
||||
CostPtrList ret;
|
||||
|
@ -365,11 +196,12 @@ CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr &u) {
|
|||
MS_EXCEPTION_IF_NULL(cost1);
|
||||
auto decision = std::make_shared<FinalSingleDecision>(u_strategy_ptr, cost1);
|
||||
auto new_cost = std::make_shared<Cost>(cost1->computation_cost_, cost1->communication_cost_, decision);
|
||||
const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
|
||||
MS_EXCEPTION_IF_NULL(new_cost);
|
||||
new_cost->communication_without_parameter_ = cost1->communication_without_parameter_;
|
||||
new_cost->communication_with_partial_para_ =
|
||||
cost1->communication_without_parameter_ +
|
||||
COST_MODEL_GAMMA * (cost1->communication_cost_ - cost1->communication_without_parameter_);
|
||||
gamma * (cost1->communication_cost_ - cost1->communication_without_parameter_);
|
||||
new_cost->memory_with_reuse_ = cost1->memory_with_reuse_;
|
||||
new_cost->communication_forward_ = cost1->communication_forward_;
|
||||
ret.push_back(new_cost);
|
||||
|
@ -404,8 +236,10 @@ CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list,
|
|||
}
|
||||
// Init the returned value with first cost.
|
||||
CostPtr ret = after_mem_filter[0];
|
||||
const auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
|
||||
const auto beta = CostModelContext::GetInstance()->costmodel_beta();
|
||||
|
||||
double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_forward_;
|
||||
double minimum = alpha * ret->computation_cost_ + beta * ret->communication_forward_;
|
||||
MS_LOG(INFO) << "Cost 0: "
|
||||
<< "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_
|
||||
<< ", communication_forward_: " << ret->communication_forward_
|
||||
|
@ -422,8 +256,7 @@ CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list,
|
|||
<< ", communication_cost_: " << after_mem_filter[i]->communication_cost_
|
||||
<< ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_
|
||||
<< ".";
|
||||
auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ +
|
||||
costmodel_beta_ * after_mem_filter[i]->communication_forward_;
|
||||
auto tmp = alpha * after_mem_filter[i]->computation_cost_ + beta * after_mem_filter[i]->communication_forward_;
|
||||
MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp;
|
||||
if (minimum > tmp) {
|
||||
minimum = tmp;
|
||||
|
@ -458,8 +291,10 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, d
|
|||
}
|
||||
// Init the returned value with first cost.
|
||||
CostPtr ret = after_mem_filter[0];
|
||||
const auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
|
||||
const auto beta = CostModelContext::GetInstance()->costmodel_beta();
|
||||
|
||||
double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_with_partial_para_;
|
||||
double minimum = alpha * ret->computation_cost_ + beta * ret->communication_with_partial_para_;
|
||||
MS_LOG(INFO) << "Cost 0: "
|
||||
<< "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_
|
||||
<< ", communication_with_partial_para_: " << ret->communication_with_partial_para_
|
||||
|
@ -474,8 +309,8 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, d
|
|||
<< ", communication_cost_: " << after_mem_filter[i]->communication_cost_
|
||||
<< ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_
|
||||
<< ".";
|
||||
auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ +
|
||||
costmodel_beta_ * after_mem_filter[i]->communication_with_partial_para_;
|
||||
auto tmp =
|
||||
alpha * after_mem_filter[i]->computation_cost_ + beta * after_mem_filter[i]->communication_with_partial_para_;
|
||||
MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp;
|
||||
if (minimum > tmp) {
|
||||
minimum = tmp;
|
||||
|
@ -513,14 +348,16 @@ CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vect
|
|||
}
|
||||
|
||||
std::function<void(size_t)> recursive = [&all_cost_list, &selected_cost_list, &minimum, &ret, &recursive,
|
||||
&available_memory, this](size_t k) {
|
||||
&available_memory](size_t k) {
|
||||
const auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
|
||||
const auto beta = CostModelContext::GetInstance()->costmodel_beta();
|
||||
if (k == all_cost_list.size()) {
|
||||
double tmp_memory = 0.0, tmp_minimum = 0.0;
|
||||
for (size_t i = 0; i < selected_cost_list.size(); ++i) {
|
||||
MS_EXCEPTION_IF_NULL(selected_cost_list[i]);
|
||||
tmp_memory += selected_cost_list[i]->memory_with_reuse_;
|
||||
tmp_minimum += costmodel_alpha_ * selected_cost_list[i]->computation_cost_ +
|
||||
costmodel_beta_ * selected_cost_list[i]->communication_with_partial_para_;
|
||||
tmp_minimum += alpha * selected_cost_list[i]->computation_cost_ +
|
||||
beta * selected_cost_list[i]->communication_with_partial_para_;
|
||||
}
|
||||
MS_LOG(INFO) << "tmp_memory: " << tmp_memory << ", tmp_minimum: " << tmp_minimum << ", minimum: " << minimum
|
||||
<< ".";
|
||||
|
@ -582,12 +419,12 @@ Status CostGraph::SearchStrategyForMultiNodeFinalGraph(const std::vector<Operato
|
|||
<< " operators in a component in the final graph.";
|
||||
}
|
||||
}
|
||||
//
|
||||
auto selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_);
|
||||
const auto device_mem_capacity = CostModelContext::GetInstance()->device_memory_capacity();
|
||||
auto selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, device_mem_capacity);
|
||||
for (size_t k = 0; k < selected_cost_list.size(); ++k) {
|
||||
auto selected_cost = selected_cost_list[k];
|
||||
if (selected_cost == nullptr) {
|
||||
MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
|
||||
MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity << ".";
|
||||
return FAILED;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(connected_components[k]);
|
||||
|
@ -627,6 +464,99 @@ Status CostGraph::SearchStrategyForMultiNodeFinalGraph(const std::vector<Operato
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status CostGraph::SearchStrategyForTwoNodeFinalGraph(const std::vector<OperatorInfoPtr> &alive_ops) {
|
||||
// In this case, the final graph should contains exactly 2 nodes.
|
||||
if (alive_ops.empty()) {
|
||||
MS_LOG(INFO) << "0 Operator in the final graph.";
|
||||
return SUCCESS;
|
||||
}
|
||||
OperatorInfoPtr u, v;
|
||||
MS_EXCEPTION_IF_NULL(alive_ops[0]);
|
||||
MS_EXCEPTION_IF_NULL(alive_ops[1]);
|
||||
const auto phase = CostModelContext::GetInstance()->run_phase();
|
||||
const auto device_mem_capacity = CostModelContext::GetInstance()->device_memory_capacity();
|
||||
if (!alive_ops[0]->GetAliveSuccEdges().empty() &&
|
||||
alive_ops[0]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[1].get()) {
|
||||
u = alive_ops[0];
|
||||
v = alive_ops[1];
|
||||
} else if (!alive_ops[1]->GetAliveSuccEdges().empty() &&
|
||||
alive_ops[1]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[0].get()) {
|
||||
u = alive_ops[1];
|
||||
v = alive_ops[0];
|
||||
} else {
|
||||
if (!alive_ops[0]->GetAliveSuccEdges().empty() || !alive_ops[1]->GetAliveSuccEdges().empty()) {
|
||||
MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << alive_ops[0]->GetAliveSuccEdges().size()
|
||||
<< ", " << alive_ops[1]->GetAliveSuccEdges().size() << ".";
|
||||
} else {
|
||||
// In this case, the final graph consists of two single nodes
|
||||
MS_LOG(INFO) << "There are 2 single nodes in the final graph.";
|
||||
std::vector<CostPtrList> all_list;
|
||||
auto connected_components = ConstructConnectedComponents(alive_ops);
|
||||
MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph.";
|
||||
for (size_t i = 0; i < connected_components.size(); ++i) {
|
||||
MS_LOG(INFO) << "There are 1 operator in a component in the final graph.";
|
||||
auto one_component = connected_components[i];
|
||||
MS_EXCEPTION_IF_NULL(one_component);
|
||||
auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]);
|
||||
all_list.push_back(cost_list);
|
||||
}
|
||||
CostPtrList selected_cost_list;
|
||||
if (phase == TRAINING_PHASE) {
|
||||
// training phase
|
||||
selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, device_mem_capacity);
|
||||
} else {
|
||||
// inference phase
|
||||
MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-separated-node final graph in the inference "
|
||||
"phase is not supported.";
|
||||
}
|
||||
for (size_t k = 0; k < selected_cost_list.size(); ++k) {
|
||||
auto selected_cost = selected_cost_list[k];
|
||||
if (selected_cost == nullptr) {
|
||||
MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity
|
||||
<< ".";
|
||||
return FAILED;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(connected_components[k]);
|
||||
auto one_operator = connected_components[k]->GetOperators()[0];
|
||||
MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_);
|
||||
auto decision = selected_cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
|
||||
MS_EXCEPTION_IF_NULL(decision);
|
||||
one_operator->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_);
|
||||
MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended.";
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "There are 2 nodes in the final graph.";
|
||||
// In this case, the finale graph is exactly of the form: u --> v
|
||||
MS_EXCEPTION_IF_NULL(u);
|
||||
MS_EXCEPTION_IF_NULL(v);
|
||||
auto e = u->GetAliveSuccEdges()[0];
|
||||
MS_EXCEPTION_IF_NULL(e);
|
||||
auto cost_list = CreateFinalCostList(u, e, v);
|
||||
CostPtr cost = nullptr;
|
||||
if (phase == TRAINING_PHASE) {
|
||||
// training phase
|
||||
cost = SelectCostWithMinTrainingTime(cost_list, device_mem_capacity);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-connected-node final graph in the inference "
|
||||
"phase is not supported.";
|
||||
}
|
||||
if (cost == nullptr) {
|
||||
MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity << ".";
|
||||
return FAILED;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(cost->decision_ptr_);
|
||||
auto decision = cost->decision_ptr_->cast<FinalDecisionPtr>();
|
||||
MS_EXCEPTION_IF_NULL(decision);
|
||||
u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->left_cost_);
|
||||
v->SetSelectedStrategyAndCost(decision->v_strategy_, decision->right_cost_);
|
||||
e->set_selected_cost(decision->middle_cost_);
|
||||
MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended.";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
// searching the strategy for the final eliminated graph
|
||||
Status CostGraph::SearchStrategy() {
|
||||
MS_LOG(INFO) << "Searching the strategy for the eliminated final graph began.";
|
||||
|
@ -637,9 +567,11 @@ Status CostGraph::SearchStrategy() {
|
|||
alive_ops.push_back(op);
|
||||
}
|
||||
});
|
||||
const auto phase = CostModelContext::GetInstance()->run_phase();
|
||||
const auto device_mem_capacity = CostModelContext::GetInstance()->device_memory_capacity();
|
||||
|
||||
if (alive_ops.size() > 2) {
|
||||
if (RUN_PHASE == TRAINING_PHASE) {
|
||||
if (phase == TRAINING_PHASE) {
|
||||
// training phase
|
||||
return SearchStrategyForMultiNodeFinalGraph(alive_ops);
|
||||
} else {
|
||||
|
@ -652,15 +584,15 @@ Status CostGraph::SearchStrategy() {
|
|||
OperatorInfoPtr u = alive_ops[0];
|
||||
auto cost_list = CreateFinalSingleCostList(u);
|
||||
CostPtr cost = nullptr;
|
||||
if (RUN_PHASE == TRAINING_PHASE) {
|
||||
if (phase == TRAINING_PHASE) {
|
||||
// training phase
|
||||
cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
|
||||
cost = SelectCostWithMinTrainingTime(cost_list, device_mem_capacity);
|
||||
} else {
|
||||
// inference phase
|
||||
cost = SelectCostWithMinInferenceTime(cost_list, dev_memory_);
|
||||
cost = SelectCostWithMinInferenceTime(cost_list, device_mem_capacity);
|
||||
}
|
||||
if (cost == nullptr) {
|
||||
MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
|
||||
MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity << ".";
|
||||
return FAILED;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(u);
|
||||
|
@ -671,93 +603,7 @@ Status CostGraph::SearchStrategy() {
|
|||
MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended.";
|
||||
return SUCCESS;
|
||||
} else {
|
||||
// In this case, the final graph should contains exactly 2 nodes.
|
||||
if (alive_ops.empty()) {
|
||||
MS_LOG(INFO) << "0 Operator in the final graph.";
|
||||
return SUCCESS;
|
||||
}
|
||||
OperatorInfoPtr u, v;
|
||||
MS_EXCEPTION_IF_NULL(alive_ops[0]);
|
||||
MS_EXCEPTION_IF_NULL(alive_ops[1]);
|
||||
if (!alive_ops[0]->GetAliveSuccEdges().empty() &&
|
||||
alive_ops[0]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[1].get()) {
|
||||
u = alive_ops[0];
|
||||
v = alive_ops[1];
|
||||
} else if (!alive_ops[1]->GetAliveSuccEdges().empty() &&
|
||||
alive_ops[1]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[0].get()) {
|
||||
u = alive_ops[1];
|
||||
v = alive_ops[0];
|
||||
} else {
|
||||
if (!alive_ops[0]->GetAliveSuccEdges().empty() || !alive_ops[1]->GetAliveSuccEdges().empty()) {
|
||||
MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << alive_ops[0]->GetAliveSuccEdges().size()
|
||||
<< ", " << alive_ops[1]->GetAliveSuccEdges().size() << ".";
|
||||
} else {
|
||||
// In this case, the final graph consists of two single nodes
|
||||
MS_LOG(INFO) << "There are 2 single nodes in the final graph.";
|
||||
std::vector<CostPtrList> all_list;
|
||||
auto connected_components = ConstructConnectedComponents(alive_ops);
|
||||
MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph.";
|
||||
for (size_t i = 0; i < connected_components.size(); ++i) {
|
||||
MS_LOG(INFO) << "There are 1 operator in a component in the final graph.";
|
||||
auto one_component = connected_components[i];
|
||||
MS_EXCEPTION_IF_NULL(one_component);
|
||||
auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]);
|
||||
all_list.push_back(cost_list);
|
||||
}
|
||||
CostPtrList selected_cost_list;
|
||||
if (RUN_PHASE == TRAINING_PHASE) {
|
||||
// training phase
|
||||
selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_);
|
||||
} else {
|
||||
// inference phase
|
||||
MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-separated-node final graph in the inference "
|
||||
"phase is not supported.";
|
||||
}
|
||||
for (size_t k = 0; k < selected_cost_list.size(); ++k) {
|
||||
auto selected_cost = selected_cost_list[k];
|
||||
if (selected_cost == nullptr) {
|
||||
MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
|
||||
return FAILED;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(connected_components[k]);
|
||||
auto one_operator = connected_components[k]->GetOperators()[0];
|
||||
MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_);
|
||||
auto decision = selected_cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
|
||||
MS_EXCEPTION_IF_NULL(decision);
|
||||
one_operator->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_);
|
||||
MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended.";
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "There are 2 nodes in the final graph.";
|
||||
// In this case, the finale graph is exactly of the form: u --> v
|
||||
MS_EXCEPTION_IF_NULL(u);
|
||||
MS_EXCEPTION_IF_NULL(v);
|
||||
auto e = u->GetAliveSuccEdges()[0];
|
||||
MS_EXCEPTION_IF_NULL(e);
|
||||
auto cost_list = CreateFinalCostList(u, e, v);
|
||||
CostPtr cost = nullptr;
|
||||
if (RUN_PHASE == TRAINING_PHASE) {
|
||||
// training phase
|
||||
cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-connected-node final graph in the inference "
|
||||
"phase is not supported.";
|
||||
}
|
||||
if (cost == nullptr) {
|
||||
MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
|
||||
return FAILED;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(cost->decision_ptr_);
|
||||
auto decision = cost->decision_ptr_->cast<FinalDecisionPtr>();
|
||||
MS_EXCEPTION_IF_NULL(decision);
|
||||
u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->left_cost_);
|
||||
v->SetSelectedStrategyAndCost(decision->v_strategy_, decision->right_cost_);
|
||||
e->set_selected_cost(decision->middle_cost_);
|
||||
MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended.";
|
||||
return SUCCESS;
|
||||
return SearchStrategyForTwoNodeFinalGraph(alive_ops);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -867,10 +713,11 @@ void CostGraph::CreateSourceEliminationSubCostList(StrategyPtr op1_old_stra, con
|
|||
op1_cost->communication_without_parameter_ + op2_cost->communication_without_parameter_;
|
||||
auto decision = std::make_shared<SourceEliminationDecision>(op1_old_stra, op1_cost, op2_old_stra, op2_cost);
|
||||
auto new_cost = std::make_shared<Cost>(computation, communication, decision);
|
||||
const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
|
||||
MS_EXCEPTION_IF_NULL(new_cost);
|
||||
new_cost->communication_without_parameter_ = communication_without_para;
|
||||
new_cost->communication_with_partial_para_ =
|
||||
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
|
||||
communication_without_para + gamma * (communication - communication_without_para);
|
||||
new_cost->memory_with_reuse_ = memory;
|
||||
new_cost->communication_forward_ = communication_forward;
|
||||
MS_EXCEPTION_IF_NULL(op1_new_clist);
|
||||
|
@ -879,6 +726,65 @@ void CostGraph::CreateSourceEliminationSubCostList(StrategyPtr op1_old_stra, con
|
|||
}
|
||||
}
|
||||
|
||||
std::pair<std::vector<EdgePtr>, std::vector<EdgePtr>> UpdateEdgesIncidentToNodes(
|
||||
OperatorInfoPtr op1, std::vector<EdgePtr> *op1_old_succ_edges,
|
||||
std::vector<std::map<CostPtrKey, CostPtrList>> *op1_new_edges_cost, std::vector<EdgePtr> *op1_new_succ_edges,
|
||||
OperatorInfoPtr op2, std::vector<EdgePtr> *op2_old_succ_edges,
|
||||
std::vector<std::map<CostPtrKey, CostPtrList>> *op2_new_edges_cost, std::vector<EdgePtr> *op2_new_succ_edges) {
|
||||
for (size_t i = 0; i < op1_old_succ_edges->size(); ++i) {
|
||||
auto &new_cost_map = op1_new_edges_cost->at(i);
|
||||
auto ith_edge = op1_old_succ_edges->at(i);
|
||||
|
||||
std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + ith_edge->next_operator()->name();
|
||||
std::shared_ptr<Edge> new_edge;
|
||||
if (ith_edge->is_combined()) {
|
||||
std::vector<size_t> output_indexs, input_indexs;
|
||||
output_indexs = ith_edge->prev_op_output_indexs();
|
||||
input_indexs = ith_edge->next_op_input_indexs();
|
||||
new_edge =
|
||||
std::make_shared<Edge>(new_edge_name, op1, ith_edge->next_operator(), output_indexs, input_indexs, true);
|
||||
} else {
|
||||
size_t output_index, input_index;
|
||||
output_index = ith_edge->prev_op_output_index();
|
||||
input_index = ith_edge->next_op_input_index();
|
||||
new_edge =
|
||||
std::make_shared<Edge>(new_edge_name, op1, ith_edge->next_operator(), output_index, input_index, false);
|
||||
}
|
||||
new_edge->SetCostMapAndInputOutput(new_cost_map);
|
||||
// replace the old successive edges with the new ones.
|
||||
op1->ReplaceSuccEdge(ith_edge->next_operator(), new_edge);
|
||||
ith_edge->next_operator()->ReplacePreEdge(op1, new_edge);
|
||||
op1_new_succ_edges->erase(op1_new_succ_edges->begin() + i);
|
||||
op1_new_succ_edges->emplace(op1_new_succ_edges->begin() + i, new_edge);
|
||||
}
|
||||
for (size_t i = 0; i < op2_old_succ_edges->size(); ++i) {
|
||||
auto &new_cost_map = op2_new_edges_cost->at(i);
|
||||
auto ith_edge = op2_old_succ_edges->at(i);
|
||||
const auto &destination = ith_edge->next_operator();
|
||||
|
||||
std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + destination->name();
|
||||
std::shared_ptr<Edge> new_edge;
|
||||
if (ith_edge->is_combined()) {
|
||||
std::vector<size_t> output_indexs, input_indexs;
|
||||
output_indexs = ith_edge->prev_op_output_indexs();
|
||||
input_indexs = ith_edge->next_op_input_indexs();
|
||||
new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_indexs, input_indexs, true);
|
||||
} else {
|
||||
size_t output_index, input_index;
|
||||
output_index = ith_edge->prev_op_output_index();
|
||||
input_index = ith_edge->next_op_input_index();
|
||||
new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_index, input_index, false);
|
||||
}
|
||||
new_edge->SetCostMapAndInputOutput(new_cost_map);
|
||||
// replace the old successive edges with the new ones.
|
||||
destination->ReplacePreEdge(op2, new_edge);
|
||||
op1->AddSuccEdge(new_edge);
|
||||
op2_new_succ_edges->erase(op2_new_succ_edges->begin() + i);
|
||||
op2_new_succ_edges->emplace(op2_new_succ_edges->begin() + i, new_edge);
|
||||
}
|
||||
return std::make_pair(*op1_new_succ_edges, *op2_new_succ_edges);
|
||||
}
|
||||
|
||||
std::pair<std::vector<std::shared_ptr<Edge>>, std::vector<std::shared_ptr<Edge>>> CostGraph::EliminationSources(
|
||||
OperatorInfoPtr op1, OperatorInfoPtr op2) {
|
||||
MS_EXCEPTION_IF_NULL(op1);
|
||||
|
@ -970,57 +876,9 @@ std::pair<std::vector<std::shared_ptr<Edge>>, std::vector<std::shared_ptr<Edge>>
|
|||
op2->SetNotAlive();
|
||||
|
||||
// Update the edges incident to op1, and edges incident to op2
|
||||
for (size_t i = 0; i < op1_old_succ_edges.size(); ++i) {
|
||||
auto &new_cost_map = op1_new_edges_cost[i];
|
||||
auto &ith_edge = op1_old_succ_edges[i];
|
||||
|
||||
std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + ith_edge->next_operator()->name();
|
||||
std::shared_ptr<Edge> new_edge;
|
||||
if (ith_edge->is_combined()) {
|
||||
std::vector<size_t> output_indexs, input_indexs;
|
||||
output_indexs = ith_edge->prev_op_output_indexs();
|
||||
input_indexs = ith_edge->next_op_input_indexs();
|
||||
new_edge =
|
||||
std::make_shared<Edge>(new_edge_name, op1, ith_edge->next_operator(), output_indexs, input_indexs, true);
|
||||
} else {
|
||||
size_t output_index, input_index;
|
||||
output_index = ith_edge->prev_op_output_index();
|
||||
input_index = ith_edge->next_op_input_index();
|
||||
new_edge =
|
||||
std::make_shared<Edge>(new_edge_name, op1, ith_edge->next_operator(), output_index, input_index, false);
|
||||
}
|
||||
new_edge->SetCostMapAndInputOutput(new_cost_map);
|
||||
// replace the old successive edges with the new ones.
|
||||
op1->ReplaceSuccEdge(ith_edge->next_operator(), new_edge);
|
||||
ith_edge->next_operator()->ReplacePreEdge(op1, new_edge);
|
||||
op1_new_succ_edges[i] = new_edge;
|
||||
}
|
||||
for (size_t i = 0; i < op2_old_succ_edges.size(); ++i) {
|
||||
auto &new_cost_map = op2_new_edges_cost[i];
|
||||
auto &ith_edge = op2_old_succ_edges[i];
|
||||
const auto &destination = ith_edge->next_operator();
|
||||
|
||||
std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + destination->name();
|
||||
std::shared_ptr<Edge> new_edge;
|
||||
if (ith_edge->is_combined()) {
|
||||
std::vector<size_t> output_indexs, input_indexs;
|
||||
output_indexs = ith_edge->prev_op_output_indexs();
|
||||
input_indexs = ith_edge->next_op_input_indexs();
|
||||
new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_indexs, input_indexs, true);
|
||||
} else {
|
||||
size_t output_index, input_index;
|
||||
output_index = ith_edge->prev_op_output_index();
|
||||
input_index = ith_edge->next_op_input_index();
|
||||
new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_index, input_index, false);
|
||||
}
|
||||
new_edge->SetCostMapAndInputOutput(new_cost_map);
|
||||
// replace the old successive edges with the new ones.
|
||||
destination->ReplacePreEdge(op2, new_edge);
|
||||
op1->AddSuccEdge(new_edge);
|
||||
op2_new_succ_edges[i] = new_edge;
|
||||
}
|
||||
MS_LOG(INFO) << "Source eliminating node: " << op2->name() << " to node: " << op1->name() + " succeeded.";
|
||||
return {op1_new_succ_edges, op2_new_succ_edges};
|
||||
return UpdateEdgesIncidentToNodes(op1, &op1_old_succ_edges, &op1_new_edges_cost, &op1_new_succ_edges, op2,
|
||||
&op2_old_succ_edges, &op2_new_edges_cost, &op2_new_succ_edges);
|
||||
}
|
||||
|
||||
// Check the graph whether a TriangleElimination can be performed
|
||||
|
@ -1179,10 +1037,11 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const
|
|||
auto decision =
|
||||
std::make_shared<MergeEliminationDecision>(op_strategy, op_cost, edge_cost, tar_op_strategy, tar_cost);
|
||||
auto new_cost = std::make_shared<Cost>(computation, communication, decision);
|
||||
const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
|
||||
MS_EXCEPTION_IF_NULL(new_cost);
|
||||
new_cost->communication_without_parameter_ = communication_without_para;
|
||||
new_cost->communication_with_partial_para_ =
|
||||
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
|
||||
communication_without_para + gamma * (communication - communication_without_para);
|
||||
new_cost->memory_with_reuse_ = memory;
|
||||
new_cost->communication_forward_ = communication_forward;
|
||||
MS_EXCEPTION_IF_NULL(tar_cost_list_new);
|
||||
|
@ -1263,9 +1122,10 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str
|
|||
auto decision = std::make_shared<ContractEliminationDecision>(contract_op_stra, contract_op_cost, edge_cost,
|
||||
target_op_stra, tar_cost);
|
||||
auto new_cost = std::make_shared<Cost>(computation, communication, decision);
|
||||
auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
|
||||
new_cost->communication_without_parameter_ = communication_without_para;
|
||||
new_cost->communication_with_partial_para_ =
|
||||
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
|
||||
communication_without_para + gamma * (communication - communication_without_para);
|
||||
new_cost->memory_with_reuse_ = memory;
|
||||
new_cost->communication_forward_ = communication_forward;
|
||||
tar_cost_list_new->emplace_back(std::move(new_cost));
|
||||
|
@ -1338,8 +1198,10 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra,
|
|||
double new_commu_without =
|
||||
elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ +
|
||||
left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_;
|
||||
const auto triangle_star_stra_overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite();
|
||||
const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
|
||||
|
||||
if (TRIANGLE_STAR_STRATEGY_OVERWRITE) {
|
||||
if (triangle_star_stra_overwrite) {
|
||||
new_computation += right_op_cost->computation_cost_;
|
||||
new_memory += right_op_cost->memory_with_reuse_;
|
||||
new_commu_cost += right_op_cost->communication_cost_;
|
||||
|
@ -1352,8 +1214,7 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra,
|
|||
left_op_stra, left_node_cost, right_op_stra, right_op_cost);
|
||||
auto new_cost = std::make_shared<Cost>(new_computation, new_commu_cost, decision);
|
||||
new_cost->communication_without_parameter_ = new_commu_without;
|
||||
new_cost->communication_with_partial_para_ =
|
||||
new_commu_without + COST_MODEL_GAMMA * (new_commu_cost - new_commu_without);
|
||||
new_cost->communication_with_partial_para_ = new_commu_without + gamma * (new_commu_cost - new_commu_without);
|
||||
new_cost->memory_with_reuse_ = new_memory;
|
||||
new_cost->communication_forward_ = new_commu_forward;
|
||||
left_node_clist_new->emplace_back(std::move(new_cost));
|
||||
|
@ -1463,6 +1324,8 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n
|
|||
memory_cost = merged_node_cost->memory_with_reuse_, commu_cost = merged_node_cost->communication_cost_,
|
||||
commu_without = merged_node_cost->communication_without_parameter_,
|
||||
commu_forward = merged_node_cost->communication_forward_;
|
||||
const auto triangle_star_stra_overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite();
|
||||
const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
|
||||
for (size_t i = 0; i < succ_nodes_stras.size(); ++i) {
|
||||
MS_EXCEPTION_IF_NULL(succ_edges_costs[i]);
|
||||
if (i == 0) {
|
||||
|
@ -1478,7 +1341,7 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n
|
|||
commu_cost += succ_edges_costs[i]->communication_cost_;
|
||||
commu_forward += succ_edges_costs[i]->communication_forward_;
|
||||
commu_without += succ_edges_costs[i]->communication_without_parameter_;
|
||||
if (TRIANGLE_STAR_STRATEGY_OVERWRITE) {
|
||||
if (triangle_star_stra_overwrite) {
|
||||
computation_cost += succ_nodes_costs[i]->computation_cost_;
|
||||
memory_cost += succ_nodes_costs[i]->memory_with_reuse_;
|
||||
commu_cost += succ_nodes_costs[i]->communication_cost_;
|
||||
|
@ -1492,7 +1355,7 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n
|
|||
succ_nodes_stras, succ_nodes_costs);
|
||||
auto new_cost = std::make_shared<Cost>(computation_cost, commu_cost, decision);
|
||||
new_cost->communication_without_parameter_ = commu_without;
|
||||
new_cost->communication_with_partial_para_ = commu_without + COST_MODEL_GAMMA * (commu_cost - commu_without);
|
||||
new_cost->communication_with_partial_para_ = commu_without + gamma * (commu_cost - commu_without);
|
||||
new_cost->memory_with_reuse_ = memory_cost;
|
||||
new_cost->communication_forward_ = commu_forward;
|
||||
first_succ_node_clist_new->emplace_back(std::move(new_cost));
|
||||
|
@ -1895,7 +1758,8 @@ Status CostGraph::CorrectOpsMemoryCost() {
|
|||
}
|
||||
|
||||
Status CostGraph::CalculateMemoryCost() {
|
||||
if (RUN_PHASE == TRAINING_PHASE) {
|
||||
const auto phase = CostModelContext::GetInstance()->run_phase();
|
||||
if (phase == TRAINING_PHASE) {
|
||||
// training phase
|
||||
if (ComputeOpsAndEdgesParameterInvolved() == SUCCESS) {
|
||||
// Calculate operators' memory usage
|
||||
|
|
|
@ -34,32 +34,12 @@ class CostGraph;
|
|||
using CostGraphPtr = std::shared_ptr<CostGraph>;
|
||||
extern CostGraphPtr entire_costgraph;
|
||||
extern size_t TOTAL_OPS;
|
||||
extern double COST_MODEL_GAMMA;
|
||||
extern bool COST_MODEL_SIMPLIFY_CALCULATION;
|
||||
extern double DEVICE_MEMORY_CAPACITY;
|
||||
extern double COST_MODEL_COMMUNI_THRESHOLD;
|
||||
extern double COST_MODEL_COMMUNI_CONST;
|
||||
extern double COST_MODEL_COMMUNI_BIAS;
|
||||
extern bool TENSOR_SLICE_ALIGNMENT_ENABLE;
|
||||
extern size_t TENSOR_SLICE_ALIGNMENT_SIZE;
|
||||
extern bool FULLY_USE_DEVICES;
|
||||
extern bool ELEMENTWISE_OP_STRA_FOLLOW;
|
||||
extern bool MULTI_SUBGRAPHS;
|
||||
extern bool DP_ALGO_ENABLE_APPROX;
|
||||
extern double DP_ALGO_APPROX_EPSILON;
|
||||
extern int64_t RUN_PHASE;
|
||||
extern bool TRIANGLE_STAR_STRATEGY_OVERWRITE;
|
||||
extern bool DP_ALGO_SINGLE_LOOP;
|
||||
|
||||
class CostGraph {
|
||||
// 'CostGraph' consists of Operators and edges between them. An edge is created between two Operators if they have
|
||||
// output-input dependency relationship.
|
||||
public:
|
||||
CostGraph() {
|
||||
dev_memory_ = DEFAULT_DEVICE_MEMORY_CAPACITY;
|
||||
costmodel_alpha_ = DEFAULT_COST_MODEL_ALPHA;
|
||||
costmodel_beta_ = DEFAULT_COST_MODEL_BETA_ASCEND;
|
||||
}
|
||||
CostGraph() {}
|
||||
~CostGraph() = default;
|
||||
void Init();
|
||||
void AddOperator(const OperatorInfoPtr &op) { ops_.push_back(op); }
|
||||
|
@ -79,8 +59,6 @@ class CostGraph {
|
|||
// An edge is uniquely identified by its name, and its output index and input index.
|
||||
bool IsEdgeInCostGraph(const std::string &, size_t, size_t);
|
||||
|
||||
void SetDeviceMemoryAndCostParameter();
|
||||
|
||||
std::vector<std::shared_ptr<CostGraph>> ConstructConnectedComponents(std::vector<OperatorInfoPtr>);
|
||||
void DFS(const OperatorInfoPtr ¤t_op, std::map<OperatorInfoPtr, bool> *visited,
|
||||
const std::shared_ptr<CostGraph> &component);
|
||||
|
@ -91,10 +69,10 @@ class CostGraph {
|
|||
CostPtr SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory);
|
||||
CostPtrList SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList> &all_costlist, double memory);
|
||||
Status SearchStrategyForMultiNodeFinalGraph(const std::vector<OperatorInfoPtr> &);
|
||||
Status SearchStrategyForTwoNodeFinalGraph(const std::vector<OperatorInfoPtr> &);
|
||||
std::vector<std::shared_ptr<Edge>> GetOriginalEdgeBetweenOperators(OperatorInfoPtr u_node, OperatorInfoPtr v_node) {
|
||||
return edges_[{u_node, v_node}];
|
||||
}
|
||||
double GetDeviceMemory() const { return dev_memory_; }
|
||||
|
||||
// Search the cost_list in the final graph, and determine the optimal one
|
||||
Status SearchStrategy();
|
||||
|
@ -194,7 +172,7 @@ class CostGraph {
|
|||
Status InitReshapeStrategy();
|
||||
Status InitSelectedStrategy();
|
||||
OperatorInfoPtr FindTmpIdentityByParameterName(std::string &) const;
|
||||
// When TmpIdentity is used by mulitple operators, the corresponding parameter's memory cost should be calculated only
|
||||
// When TmpIdentity is used by multiple operators, the corresponding parameter's memory cost should be calculated only
|
||||
// once (instead of multiple times), this method is used to correct this.
|
||||
Status CorrectOpsMemoryCost();
|
||||
// When APPROXIMATION is enabled in the DP algorithm, some edges may have no valid strategies.
|
||||
|
@ -224,9 +202,6 @@ class CostGraph {
|
|||
// Needed by rec_parser
|
||||
std::vector<std::vector<std::string>> inputs_tensor_name_list_;
|
||||
std::map<std::string, std::string> tuple_getitem_list_;
|
||||
double dev_memory_;
|
||||
double costmodel_alpha_;
|
||||
double costmodel_beta_;
|
||||
std::vector<OperatorInfoPtr> ops_;
|
||||
std::map<std::pair<OperatorInfoPtr, OperatorInfoPtr>, std::vector<EdgePtr>> edges_;
|
||||
std::vector<std::shared_ptr<CostGraph>> connected_compoents_;
|
||||
|
|
|
@ -47,7 +47,7 @@ void CostModelContext::ResetCostModel() {
|
|||
costmodel_communi_const_ = DEFAULT_COST_MODEL_COMMUNI_CONST;
|
||||
costmodel_communi_bias_ = DEFAULT_COST_MODEL_COMMUNI_BIAS;
|
||||
is_multi_subgraphs_ = DEFAULT_IS_MULTI_SUBGRAPHS;
|
||||
run_phase_ = DEFAULT_RUN_PHASE;
|
||||
run_phase_ = TRAINING_PHASE;
|
||||
costmodel_allreduce_fusion_algorithm_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM;
|
||||
costmodel_allreduce_fusion_times_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES;
|
||||
costmodel_allreduce_fusion_tail_percent_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT;
|
||||
|
@ -70,37 +70,115 @@ void CostModelContext::ResetAlgoParameters() {
|
|||
dp_algo_approxi_epsilon_ = DEFAULT_DP_ALGO_APPROX_EPSILON;
|
||||
}
|
||||
|
||||
void CostModelContext::PrintCostModel() {
|
||||
MS_LOG(INFO) << "device_memory_capacity: " << device_memory_capacity_ << ".";
|
||||
MS_LOG(INFO) << "costmodel_alpha: " << costmodel_alpha_ << ".";
|
||||
MS_LOG(INFO) << "costmodel_beta: " << costmodel_beta_ << ".";
|
||||
MS_LOG(INFO) << "costmodel_gamma: " << costmodel_gamma_ << ".";
|
||||
MS_LOG(INFO) << "costmodel_simplify_cal: " << costmodel_simplify_cal_ << ".";
|
||||
MS_LOG(INFO) << "costmodel_communi_threshold: " << costmodel_communi_threshold_ << ".";
|
||||
MS_LOG(INFO) << "costmodel_communi_const: " << costmodel_communi_const_ << ".";
|
||||
MS_LOG(INFO) << "costmodel_communi_bias: " << costmodel_communi_bias_ << ".";
|
||||
MS_LOG(INFO) << "is_multi_subgraphs: " << is_multi_subgraphs_ << ".";
|
||||
MS_LOG(INFO) << "triangle_star_strategy_overwrite: " << triangle_star_strategy_overwrite_ << ".";
|
||||
MS_LOG(INFO) << "dp_algo_enable_approxi: " << dp_algo_enable_approxi_ << ".";
|
||||
MS_LOG(INFO) << "dp_algo_approxi_epsilon: " << dp_algo_approxi_epsilon_ << ".";
|
||||
MS_LOG(INFO) << "dp_algo_single_loop: " << dp_algo_single_loop_ << ".";
|
||||
MS_LOG(INFO) << "run_phase: " << run_phase_ << ".";
|
||||
MS_LOG(INFO) << "tensor_slice_alignment_enable: " << tensor_slice_alignment_enable_ << ".";
|
||||
MS_LOG(INFO) << "tensor_slice_align_size: " << tensor_slice_alignment_size_ << ".";
|
||||
MS_LOG(INFO) << "fully_use_device: " << fully_use_device_ << ".";
|
||||
MS_LOG(INFO) << "elementwise_stra_follow: " << elementwise_stra_follow_ << ".";
|
||||
}
|
||||
|
||||
void CostModelContext::set_costmodel_context_for_device(const std::string &device_target) {
|
||||
if (device_target == kGPUDevice) {
|
||||
costmodel_beta_ = DEFAULT_COST_MODEL_BETA_GPU;
|
||||
}
|
||||
}
|
||||
|
||||
void CostModelContext::set_dp_algo_approxi_epsilon(double epsilon) { dp_algo_approxi_epsilon_ = epsilon; }
|
||||
void CostModelContext::set_dp_algo_approxi_epsilon(double epsilon) {
|
||||
if (epsilon <= 0 || epsilon > 1) {
|
||||
MS_LOG(EXCEPTION) << "'epsilon' must be in (0, 1]";
|
||||
}
|
||||
dp_algo_approxi_epsilon_ = epsilon;
|
||||
}
|
||||
|
||||
void CostModelContext::set_dp_algo_enable_approxi(bool approxi) { dp_algo_enable_approxi_ = approxi; }
|
||||
void CostModelContext::set_dp_algo_enable_approxi(bool approxi) {
|
||||
if (approxi) {
|
||||
MS_LOG(INFO) << "dp_algo_enable_approx: true.";
|
||||
} else {
|
||||
MS_LOG(INFO) << "dp_algo_enable_approx: false.";
|
||||
}
|
||||
dp_algo_enable_approxi_ = approxi;
|
||||
}
|
||||
|
||||
void CostModelContext::set_device_memory_capacity(double dm_capacity) { device_memory_capacity_ = dm_capacity; }
|
||||
void CostModelContext::set_device_memory_capacity(double dm_capacity) {
|
||||
if (dm_capacity <= 0) {
|
||||
MS_LOG(EXCEPTION) << "'device_memory_capacity' must be positive.";
|
||||
}
|
||||
device_memory_capacity_ = dm_capacity;
|
||||
}
|
||||
|
||||
void CostModelContext::set_costmodel_alpha(double cm_alpha) { costmodel_alpha_ = cm_alpha; }
|
||||
void CostModelContext::set_costmodel_alpha(double cm_alpha) {
|
||||
if (cm_alpha <= 0) {
|
||||
MS_LOG(EXCEPTION) << "'costmodel_alpha' must be positive.";
|
||||
}
|
||||
costmodel_alpha_ = cm_alpha;
|
||||
}
|
||||
|
||||
void CostModelContext::set_costmodel_beta(double cm_beta) { costmodel_beta_ = cm_beta; }
|
||||
void CostModelContext::set_costmodel_beta(double cm_beta) {
|
||||
if (cm_beta <= 0) {
|
||||
MS_LOG(EXCEPTION) << "'costmodel_beta' must be positive.";
|
||||
}
|
||||
costmodel_beta_ = cm_beta;
|
||||
}
|
||||
|
||||
void CostModelContext::set_costmodel_gamma(double cm_gamma) { costmodel_gamma_ = cm_gamma; }
|
||||
void CostModelContext::set_costmodel_gamma(double cm_gamma) {
|
||||
if ((cm_gamma < 0) || (cm_gamma > 1)) {
|
||||
MS_LOG(EXCEPTION) << "'costmodel_gamma' must in [0, 1].";
|
||||
}
|
||||
costmodel_gamma_ = cm_gamma;
|
||||
}
|
||||
|
||||
void CostModelContext::set_costmodel_simplify_cal(bool cm_simplify) { costmodel_simplify_cal_ = cm_simplify; }
|
||||
void CostModelContext::set_costmodel_simplify_cal(bool cm_simplify) {
|
||||
if (cm_simplify) {
|
||||
MS_LOG(INFO) << "costmodel_simplify_cal: true.";
|
||||
} else {
|
||||
MS_LOG(INFO) << "costmodel_simplify_cal: false.";
|
||||
}
|
||||
costmodel_simplify_cal_ = cm_simplify;
|
||||
}
|
||||
|
||||
void CostModelContext::set_costmodel_communi_threshold(double cm_communi_th) {
|
||||
if (cm_communi_th < 0) {
|
||||
MS_LOG(EXCEPTION) << "'costmodel_communi_threshold' must be non-zero.";
|
||||
}
|
||||
costmodel_communi_threshold_ = cm_communi_th;
|
||||
}
|
||||
|
||||
void CostModelContext::set_costmodel_communi_const(double cm_communi_const) {
|
||||
if (cm_communi_const < 0) {
|
||||
MS_LOG(EXCEPTION) << "'costmodel_communi_const' must be non-zero.";
|
||||
}
|
||||
costmodel_communi_const_ = cm_communi_const;
|
||||
}
|
||||
|
||||
void CostModelContext::set_costmodel_communi_bias(double cm_communi_bias) { costmodel_communi_bias_ = cm_communi_bias; }
|
||||
void CostModelContext::set_costmodel_communi_bias(double cm_communi_bias) {
|
||||
if (cm_communi_bias < 0) {
|
||||
MS_LOG(EXCEPTION) << "'costmodel_communi_bias' must be non-zero.";
|
||||
}
|
||||
costmodel_communi_bias_ = cm_communi_bias;
|
||||
}
|
||||
|
||||
void CostModelContext::set_multi_subgraphs(bool multi_graphs) { is_multi_subgraphs_ = multi_graphs; }
|
||||
void CostModelContext::set_multi_subgraphs(bool multi_graphs) {
|
||||
if (multi_graphs) {
|
||||
MS_LOG(INFO) << "multi_subgraphs: true.";
|
||||
} else {
|
||||
MS_LOG(INFO) << "multi_subgraphs: false.";
|
||||
}
|
||||
is_multi_subgraphs_ = multi_graphs;
|
||||
}
|
||||
void CostModelContext::set_costmodel_allreduce_fusion_algorithm(int64_t algorithm) {
|
||||
costmodel_allreduce_fusion_algorithm_ = algorithm;
|
||||
}
|
||||
|
@ -129,25 +207,64 @@ void CostModelContext::set_costmodel_allreduce_fusion_computation_time_parameter
|
|||
costmodel_allreduce_fusion_computation_time_parameter_ = computation_time_parameter;
|
||||
}
|
||||
|
||||
void CostModelContext::set_tensor_slice_alignment_enable(bool ts_align) { tensor_slice_alignment_enable_ = ts_align; }
|
||||
void CostModelContext::set_tensor_slice_alignment_enable(bool ts_align) {
|
||||
if (ts_align) {
|
||||
MS_LOG(INFO) << "tensor_slice_align_enable: true.";
|
||||
} else {
|
||||
MS_LOG(INFO) << "tensor_slice_align_enable: false.";
|
||||
}
|
||||
tensor_slice_alignment_enable_ = ts_align;
|
||||
}
|
||||
|
||||
void CostModelContext::set_tensor_slice_alignment_size(size_t ts_align_size) {
|
||||
if (ts_align_size == 0) {
|
||||
MS_LOG(EXCEPTION) << "'tensor_slice_align_size' must be positive.";
|
||||
}
|
||||
tensor_slice_alignment_size_ = ts_align_size;
|
||||
}
|
||||
|
||||
void CostModelContext::set_fully_use_device(bool fully_use) { fully_use_device_ = fully_use; }
|
||||
void CostModelContext::set_fully_use_device(bool fully_use) {
|
||||
if (fully_use) {
|
||||
MS_LOG(INFO) << "fully_use_devices: true.";
|
||||
} else {
|
||||
MS_LOG(INFO) << "fully_use_devices: false.";
|
||||
}
|
||||
fully_use_device_ = fully_use;
|
||||
}
|
||||
|
||||
void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) {
|
||||
if (elementwise_follow) {
|
||||
MS_LOG(INFO) << "elementwise_op_strategy_follow: true.";
|
||||
} else {
|
||||
MS_LOG(INFO) << "elementwise_op_strategy_follow: false.";
|
||||
}
|
||||
elementwise_stra_follow_ = elementwise_follow;
|
||||
}
|
||||
|
||||
void CostModelContext::set_triangle_star_strategy_overwrite(bool overwrite) {
|
||||
if (overwrite) {
|
||||
MS_LOG(INFO) << "triangle_star_strategy_overwrite: true.";
|
||||
} else {
|
||||
MS_LOG(INFO) << "triangle_star_strategy_overwrite: false.";
|
||||
}
|
||||
triangle_star_strategy_overwrite_ = overwrite;
|
||||
}
|
||||
|
||||
void CostModelContext::set_run_phase(int64_t phase) { run_phase_ = phase; }
|
||||
void CostModelContext::set_run_phase(int64_t phase) {
|
||||
if (phase != 0 && phase != 1) {
|
||||
MS_LOG(EXCEPTION) << "'run_phase' must be in {0, 1}";
|
||||
}
|
||||
run_phase_ = phase;
|
||||
}
|
||||
|
||||
void CostModelContext::set_dp_algo_single_loop(bool single_loop) { dp_algo_single_loop_ = single_loop; }
|
||||
void CostModelContext::set_dp_algo_single_loop(bool single_loop) {
|
||||
if (single_loop) {
|
||||
MS_LOG(INFO) << "dp_algo_single_loop: true.";
|
||||
} else {
|
||||
MS_LOG(INFO) << "dp_algo_single_loop: false.";
|
||||
}
|
||||
dp_algo_single_loop_ = single_loop;
|
||||
}
|
||||
|
||||
struct CostRegister {
|
||||
CostRegister() {
|
||||
|
|
|
@ -41,7 +41,6 @@ namespace parallel {
|
|||
#define DEFAULT_FULLY_USE_DEVICES true
|
||||
#define DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW false
|
||||
#define DEFAULT_IS_MULTI_SUBGRAPHS false
|
||||
#define DEFAULT_RUN_PHASE 0
|
||||
#define TRAINING_PHASE 0
|
||||
#define INFERENCE_PHASE 1
|
||||
#define DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE true;
|
||||
|
@ -58,6 +57,7 @@ class CostModelContext {
|
|||
void ResetAlgoParameters();
|
||||
|
||||
static std::shared_ptr<CostModelContext> GetInstance();
|
||||
void PrintCostModel();
|
||||
|
||||
void set_costmodel_context_for_device(const std::string &);
|
||||
// DEVICE_MEMORY_CAPACITY
|
||||
|
|
|
@ -475,7 +475,8 @@ Status MatMulBase::PrepareStrategy(int64_t stage_id, size_t dev_num,
|
|||
size_t input1_shape_size, mindspore::parallel::StrategyPtr *const sp) {
|
||||
int64_t product =
|
||||
std::accumulate(combined_partitions.begin(), combined_partitions.end(), 1, std::multiplies<int64_t>());
|
||||
if (!FULLY_USE_DEVICES) {
|
||||
const auto fully_use_device = CostModelContext::GetInstance()->fully_use_device();
|
||||
if (!fully_use_device) {
|
||||
if (LongToSize(product) > dev_num) {
|
||||
return FAILED;
|
||||
}
|
||||
|
@ -564,7 +565,9 @@ void MatMulBase::InitTensorInfoForCost(std::vector<TensorInfo> *relica_inputs_te
|
|||
}
|
||||
|
||||
Status MatMulBase::CheckForTensorSliceValid() const {
|
||||
if (!TENSOR_SLICE_ALIGNMENT_ENABLE) {
|
||||
const auto align_enable = CostModelContext::GetInstance()->tensor_slice_alignment_enable();
|
||||
const auto align_size = CostModelContext::GetInstance()->tensor_slice_alignment_size();
|
||||
if (!align_enable) {
|
||||
return SUCCESS;
|
||||
}
|
||||
if (inputs_tensor_info_.empty()) {
|
||||
|
@ -572,8 +575,8 @@ Status MatMulBase::CheckForTensorSliceValid() const {
|
|||
}
|
||||
for (auto &one_input_tensor : inputs_tensor_info_) {
|
||||
auto slice_shape = one_input_tensor.slice_shape();
|
||||
if ((LongToSize(slice_shape[LAST_INDEX(slice_shape.size())]) % TENSOR_SLICE_ALIGNMENT_SIZE != 0) ||
|
||||
(LongToSize(slice_shape[SECOND_FROM_END(slice_shape.size())]) % TENSOR_SLICE_ALIGNMENT_SIZE != 0)) {
|
||||
if ((LongToSize(slice_shape[LAST_INDEX(slice_shape.size())]) % align_size != 0) ||
|
||||
(LongToSize(slice_shape[SECOND_FROM_END(slice_shape.size())]) % align_size != 0)) {
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
|
@ -608,12 +611,12 @@ Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &
|
|||
double computation_cost =
|
||||
operator_cost()->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
|
||||
double communication_cost = operator_cost()->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
|
||||
const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
|
||||
std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
|
||||
result->communication_without_parameter_ =
|
||||
operator_cost()->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
|
||||
result->communication_with_partial_para_ =
|
||||
result->communication_without_parameter_ +
|
||||
COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
|
||||
result->communication_without_parameter_ + gamma * (communication_cost - result->communication_without_parameter_);
|
||||
|
||||
// Breaking ties for preferring data parallelization
|
||||
BreakingTiesForPerferringDataParallel(strategy, result);
|
||||
|
|
|
@ -839,7 +839,8 @@ Status PrepareStrategyBase(int64_t stage_id, size_t dev_num, const Shapes &input
|
|||
for (auto &input_partition : inputs_partitions) {
|
||||
product *= std::accumulate(input_partition.begin(), input_partition.end(), 1, std::multiplies<int64_t>());
|
||||
}
|
||||
if (!FULLY_USE_DEVICES) {
|
||||
const auto fully_use_device = CostModelContext::GetInstance()->fully_use_device();
|
||||
if (!fully_use_device) {
|
||||
if (LongToSize(product) > dev_num) {
|
||||
return FAILED;
|
||||
}
|
||||
|
@ -1202,12 +1203,12 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) {
|
|||
double computation_cost =
|
||||
operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
|
||||
double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
|
||||
const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
|
||||
std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
|
||||
result->communication_without_parameter_ =
|
||||
operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
|
||||
result->communication_with_partial_para_ =
|
||||
result->communication_without_parameter_ +
|
||||
COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
|
||||
result->communication_without_parameter_ + gamma * (communication_cost - result->communication_without_parameter_);
|
||||
|
||||
// Breaking ties for preferring data parallelization
|
||||
BreakingTiesForPerferringDataParallel(strategy, result);
|
||||
|
|
|
@ -396,12 +396,12 @@ void ReshapeInfo::SetCostForReshape(const mindspore::parallel::StrategyPtr &stra
|
|||
double computation_cost =
|
||||
operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
|
||||
double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
|
||||
const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
|
||||
std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
|
||||
result->communication_without_parameter_ =
|
||||
operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
|
||||
result->communication_with_partial_para_ =
|
||||
result->communication_without_parameter_ +
|
||||
COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
|
||||
result->communication_without_parameter_ + gamma * (communication_cost - result->communication_without_parameter_);
|
||||
|
||||
// Breaking ties for preferring data parallelization
|
||||
BreakingTiesForPerferringDataParallel(strategy, result);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,6 +13,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "frontend/parallel/parallel_stub/executor_manager_stub.h"
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
|
@ -26,6 +27,5 @@ std::shared_ptr<Executor> ExecutorManager::GetExecutor(const std::string &device
|
|||
executors_[device_key] = executor;
|
||||
return executor;
|
||||
}
|
||||
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,6 +13,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PARALLEL_EXECUTOR_MANAGER_STUB_H_
|
||||
#define MINDSPORE_CCSRC_PARALLEL_EXECUTOR_MANAGER_STUB_H_
|
||||
#include <set>
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,6 +13,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PARALLEL_EXECUTOR_STUB_H
|
||||
#define MINDSPORE_CCSRC_PARALLEL_EXECUTOR_STUB_H
|
||||
|
||||
|
|
|
@ -52,13 +52,18 @@ static bool IsInWhiteList(const CNodePtr &cnode) {
|
|||
return false;
|
||||
}
|
||||
|
||||
static void SetGradTag(const AnfNodePtr &node, const FuncGraphManagerPtr &manager) {
|
||||
static void SetGradTag(const AnfNodePtr &node, const FuncGraphManagerPtr &manager, size_t curr_depth) {
|
||||
if (curr_depth > MAX_RECURSIVE_DEPTH) {
|
||||
MS_LOG(WARNING) << "When setting the tags for Grad nodes, exceeded the maximum recursion depth: "
|
||||
<< MAX_RECURSIVE_DEPTH;
|
||||
return;
|
||||
}
|
||||
const auto &node_users = manager->node_users()[node];
|
||||
for (auto &user_pair : node_users) {
|
||||
auto user_node = user_pair.first;
|
||||
if (!user_node->grad()) {
|
||||
user_node->set_grad(true);
|
||||
SetGradTag(user_node, manager);
|
||||
SetGradTag(user_node, manager, ++curr_depth);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -69,7 +74,7 @@ void PipelineTransformer::LabelRequiredGradCNode() {
|
|||
if (!ParameterRequireGrad(parameter)) {
|
||||
continue;
|
||||
}
|
||||
SetGradTag(parameter, manager_);
|
||||
SetGradTag(parameter, manager_, 0);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -234,7 +234,8 @@ void InitCostGraph() {
|
|||
if (entire_costgraph == nullptr) {
|
||||
entire_costgraph = std::make_shared<CostGraph>();
|
||||
}
|
||||
entire_costgraph->SetDeviceMemoryAndCostParameter();
|
||||
MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());
|
||||
CostModelContext::GetInstance()->PrintCostModel();
|
||||
entire_costgraph->Init();
|
||||
}
|
||||
|
||||
|
@ -252,10 +253,11 @@ void SetStrategyToOperator(const OperatorInfoPtr &operator_info, const Primitive
|
|||
if (prim->name() == RESHAPE) {
|
||||
MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!";
|
||||
}
|
||||
const auto fully_use_devices = CostModelContext::GetInstance()->fully_use_device();
|
||||
// Set cost for this configured strategy
|
||||
if (operator_info->SetCostUnderStrategy(strategyPtr) != SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " SetCostUnderStrategy failed";
|
||||
} else if (FULLY_USE_DEVICES) {
|
||||
} else if (fully_use_devices) {
|
||||
// If configured to fully use devices, then checking for the user-specified strategy
|
||||
int64_t used_devices = operator_info->used_devices();
|
||||
MS_EXCEPTION_IF_NULL(g_device_manager);
|
||||
|
@ -323,7 +325,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
|
|||
operator_info->set_cnode(cnode);
|
||||
// key of strategy map
|
||||
std::string strategy_key_name = "";
|
||||
auto param_names = NodeParameterName(cnode);
|
||||
auto param_names = NodeParameterName(cnode, -1, 0);
|
||||
if (!param_names.empty()) {
|
||||
strategy_key_name = prim->name() + "_" + param_names[0].first;
|
||||
}
|
||||
|
@ -394,7 +396,8 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
|
|||
if (search_cnode == from_cnode_to_info.end()) {
|
||||
size_t loop_index = 0;
|
||||
bool is_in_loop = GetLoopIndexFromCNode(cnode, &loop_index);
|
||||
if (DP_ALGO_SINGLE_LOOP && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size())) {
|
||||
const auto single_loop = CostModelContext::GetInstance()->dp_algo_single_loop();
|
||||
if (single_loop && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size())) {
|
||||
const auto ¤t_op_ptr = operators_in_forloop[loop_to_ops[loop_index]];
|
||||
bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) &&
|
||||
(current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) &&
|
||||
|
@ -430,7 +433,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
|
|||
<< ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
|
||||
<< " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
|
||||
(void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueId(), operator_info));
|
||||
if (DP_ALGO_SINGLE_LOOP && is_in_loop) {
|
||||
if (single_loop && is_in_loop) {
|
||||
operators_in_forloop.push_back(operator_info);
|
||||
ops_in_a_loop_.insert(operator_info->name());
|
||||
loop_to_ops[loop_index]++;
|
||||
|
@ -511,7 +514,8 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
|
|||
if (search_cnode == from_cnode_to_info.end()) {
|
||||
size_t loop_index = 0;
|
||||
bool is_in_loop = GetLoopIndexFromCNode(cnode, &loop_index);
|
||||
bool is_op_created = DP_ALGO_SINGLE_LOOP && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size());
|
||||
const auto single_loop = CostModelContext::GetInstance()->dp_algo_single_loop();
|
||||
bool is_op_created = single_loop && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size());
|
||||
if (is_op_created) {
|
||||
const auto ¤t_op_ptr = operators_in_forloop[loop_to_ops[loop_index]];
|
||||
bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) &&
|
||||
|
@ -548,7 +552,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
|
|||
<< ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
|
||||
<< " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
|
||||
(void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info));
|
||||
if (DP_ALGO_SINGLE_LOOP && is_in_loop) {
|
||||
if (single_loop && is_in_loop) {
|
||||
operators_in_forloop.push_back(operator_info);
|
||||
ops_in_a_loop_.insert(operator_info->name());
|
||||
loop_to_ops[loop_index]++;
|
||||
|
@ -581,8 +585,9 @@ void CreateEdgeBetweenTwoOps(const OperatorInfoPtr &prev_op_info, const Operator
|
|||
MS_LOG(INFO) << "The two operators in two separate for-loops, thus skip the edge.";
|
||||
return;
|
||||
}
|
||||
const auto stra_follow = CostModelContext::GetInstance()->elementwise_stra_follow();
|
||||
bool follow_strategy = (prim->name() == RESHAPE) || (prev_prim->name() == RESHAPE) ||
|
||||
(ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name()));
|
||||
(stra_follow && IsElementWiseOperator(prev_prim->name()));
|
||||
if (follow_strategy) {
|
||||
// Redistribution in not allowed on the edge.
|
||||
// Elementwise operators have the same strategy as their previous operators.
|
||||
|
@ -1031,7 +1036,7 @@ Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const
|
|||
graph = EliminateGraph(graph, eli_list, index_list);
|
||||
|
||||
size_t num_device = g_device_manager->DeviceNum();
|
||||
double device_memory = entire_costgraph->GetDeviceMemory();
|
||||
const auto device_memory = CostModelContext::GetInstance()->device_memory_capacity();
|
||||
if (PartitionForAllDevices(num_device, device_memory, graph) == SUCCESS) {
|
||||
MS_LOG(INFO) << "Partition Success With " << num_device << " devices.";
|
||||
} else {
|
||||
|
|
|
@ -1914,7 +1914,11 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
|
|||
}
|
||||
|
||||
// find previous parallel care node's next node.
|
||||
bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids, vector<size_t> *indexes) {
|
||||
bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids, vector<size_t> *indexes, size_t curr_depth) {
|
||||
if (curr_depth > MAX_RECURSIVE_DEPTH) {
|
||||
MS_LOG(WARNING) << "When find the previous node, exceeded the maximum recursion depth: " << MAX_RECURSIVE_DEPTH;
|
||||
return false;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(unique_ids);
|
||||
MS_EXCEPTION_IF_NULL(indexes);
|
||||
if (!node->isa<CNode>()) {
|
||||
|
@ -1942,7 +1946,7 @@ bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids, vecto
|
|||
find = true;
|
||||
continue;
|
||||
}
|
||||
if (FindPreNodes(cnode, unique_ids, indexes)) {
|
||||
if (FindPreNodes(cnode, unique_ids, indexes, ++curr_depth)) {
|
||||
find = true;
|
||||
continue;
|
||||
}
|
||||
|
@ -1954,7 +1958,7 @@ void FindLastNodesUniqueId(const FuncGraphPtr &root, std::vector<std::string> *u
|
|||
std::vector<size_t> *indexes) {
|
||||
MS_EXCEPTION_IF_NULL(unique_ids);
|
||||
CNodePtr cnode = root->get_return();
|
||||
if (!FindPreNodes(cnode, unique_ids, indexes)) {
|
||||
if (!FindPreNodes(cnode, unique_ids, indexes, 0)) {
|
||||
MS_LOG(WARNING) << "cannot find the last parallel care node in eval graph";
|
||||
}
|
||||
}
|
||||
|
@ -2044,7 +2048,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini
|
|||
// load strategy checkpoint
|
||||
// key of strategy map
|
||||
std::string strategy_key_name = "";
|
||||
auto param_names = NodeParameterName(cnode);
|
||||
auto param_names = NodeParameterName(cnode, -1, 0);
|
||||
if (!param_names.empty()) {
|
||||
strategy_key_name = prim->name() + "_" + param_names[0].first;
|
||||
}
|
||||
|
@ -2151,13 +2155,18 @@ std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &n
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node) {
|
||||
std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node, size_t curr_depth) {
|
||||
if (curr_depth > MAX_RECURSIVE_DEPTH) {
|
||||
MS_LOG(WARNING) << "When finding the next tensor layout for the parameter, exceeded the maximum recursion depth: "
|
||||
<< MAX_RECURSIVE_DEPTH;
|
||||
return nullptr;
|
||||
}
|
||||
FuncGraphManagerPtr manager = node->func_graph()->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
AnfNodeIndexSet node_set = manager->node_users()[node];
|
||||
for (auto &node_pair : node_set) {
|
||||
if (IsPrimitiveCNode(node_pair.first, prim::kPrimLoad)) {
|
||||
auto layout_param = FindParameterNextLayout(node_pair.first);
|
||||
auto layout_param = FindParameterNextLayout(node_pair.first, ++curr_depth);
|
||||
if (!layout_param) {
|
||||
continue;
|
||||
}
|
||||
|
@ -2184,7 +2193,7 @@ std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node) {
|
|||
|
||||
std::shared_ptr<TensorLayout> CreateParameterLayout(const AnfNodePtr &node) {
|
||||
// Create DataParallel tensor layout for parameter(support WideDeep).
|
||||
auto next_layout = FindParameterNextLayout(node);
|
||||
auto next_layout = FindParameterNextLayout(node, 0);
|
||||
if (next_layout != nullptr) {
|
||||
return next_layout;
|
||||
}
|
||||
|
@ -2329,14 +2338,19 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) {
|
|||
}
|
||||
}
|
||||
|
||||
CNodePtr HandleDependLoss(const CNodePtr &cnode) {
|
||||
CNodePtr HandleDependLoss(const CNodePtr &cnode, size_t curr_depth) {
|
||||
if (curr_depth > MAX_RECURSIVE_DEPTH) {
|
||||
MS_LOG(WARNING) << "When handling the loss node of Depend, exceeded the max recursive depth: "
|
||||
<< MAX_RECURSIVE_DEPTH;
|
||||
return nullptr;
|
||||
}
|
||||
// Handle return->depend->loss
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
if (prim->name() == DEPEND) {
|
||||
auto depend_before = cnode->input(1)->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(depend_before);
|
||||
return HandleDependLoss(depend_before);
|
||||
return HandleDependLoss(depend_before, ++curr_depth);
|
||||
}
|
||||
return cnode;
|
||||
}
|
||||
|
@ -2370,7 +2384,7 @@ LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph) {
|
|||
pre_cnode = pre_cnode->input(1)->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(pre_cnode);
|
||||
}
|
||||
pre_cnode = HandleDependLoss(pre_cnode);
|
||||
pre_cnode = HandleDependLoss(pre_cnode, 0);
|
||||
auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
|
||||
|
||||
// notice: the GetNext op has not input
|
||||
|
@ -2792,7 +2806,12 @@ bool IsCohesiveNode(const CNodePtr &cnode) {
|
|||
IsPrimitiveCNode(cnode, prim::kPrimAllGather) || IsPrimitiveCNode(cnode, prim::kPrimMiniStepAllGather);
|
||||
}
|
||||
|
||||
std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &node, int64_t index) {
|
||||
std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth) {
|
||||
if (curr_depth > MAX_RECURSIVE_DEPTH) {
|
||||
MS_LOG(WARNING) << "When finding the parameters' name of a operator, exceeded the maximum depth: "
|
||||
<< MAX_RECURSIVE_DEPTH;
|
||||
return {};
|
||||
}
|
||||
std::vector<AnfNodePtr> node_inputs{node->inputs()};
|
||||
std::vector<std::pair<std::string, int64_t>> param_names;
|
||||
for (int64_t i = 0; i < UlongToLong(node_inputs.size()); ++i) {
|
||||
|
@ -2809,7 +2828,7 @@ std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &n
|
|||
continue;
|
||||
}
|
||||
if (IsCohesiveNode(cnode) && cnode->inputs().size() >= 1) {
|
||||
auto input_param_names = NodeParameterName(cnode, idx);
|
||||
auto input_param_names = NodeParameterName(cnode, idx, 0);
|
||||
param_names.insert(param_names.end(), input_param_names.begin(), input_param_names.end());
|
||||
}
|
||||
}
|
||||
|
@ -2827,7 +2846,7 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
|
|||
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
|
||||
continue;
|
||||
}
|
||||
auto param_names = NodeParameterName(cnode);
|
||||
auto param_names = NodeParameterName(cnode, -1, 0);
|
||||
if (param_names.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
@ -2950,13 +2969,17 @@ void InsertShapeOp(const CNodePtr &node, const AnfNodePtr &pre_node, const FuncG
|
|||
InsertNode(op, node, 2, pre_node, root, "shape");
|
||||
}
|
||||
|
||||
static AnfNodePtr FindGrad(const CNodePtr &cnode) {
|
||||
static AnfNodePtr FindGrad(const CNodePtr &cnode, size_t curr_depth) {
|
||||
if (curr_depth > MAX_RECURSIVE_DEPTH) {
|
||||
MS_LOG(WARNING) << "When finding Grad nodes, exceeded the maximum recursion depth: " << MAX_RECURSIVE_DEPTH;
|
||||
return nullptr;
|
||||
}
|
||||
for (auto &node : cnode->inputs()) {
|
||||
if (!node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimEnvGetItem)) {
|
||||
return FindGrad(node->cast<CNodePtr>());
|
||||
return FindGrad(node->cast<CNodePtr>(), ++curr_depth);
|
||||
} else {
|
||||
return node;
|
||||
}
|
||||
|
@ -2995,7 +3018,7 @@ void HandleRootReshapeAndSaveStrategy(const std::vector<AnfNodePtr> &all_nodes)
|
|||
continue;
|
||||
}
|
||||
auto root = node->func_graph();
|
||||
auto grad_node = FindGrad(cnode);
|
||||
auto grad_node = FindGrad(cnode, 0);
|
||||
if (grad_node) {
|
||||
InsertShapeOp(cnode, grad_node, root);
|
||||
}
|
||||
|
|
|
@ -139,7 +139,7 @@ bool IsLastStage();
|
|||
void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
|
||||
const FuncGraphManagerPtr &manager);
|
||||
|
||||
std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &node, int64_t index = -1);
|
||||
std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth);
|
||||
|
||||
void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes);
|
||||
|
||||
|
|
|
@ -153,7 +153,6 @@ class TestDPAlgo : public UT::Common {
|
|||
|
||||
void TestDPAlgo::SetUp() {
|
||||
cost_graph = std::make_shared<CostGraph>();
|
||||
cost_graph->SetDeviceMemoryAndCostParameter();
|
||||
RankList dev_list;
|
||||
|
||||
for (int32_t i = 0; i < 10; i++) {
|
||||
|
|
|
@ -52,7 +52,6 @@ class TestCostGraph : public UT::Common {
|
|||
};
|
||||
|
||||
void TestCostGraph::SetUp() {
|
||||
cost_graph.SetDeviceMemoryAndCostParameter();
|
||||
RankList dev_list;
|
||||
|
||||
for (int32_t i = 0; i < 10; i++) {
|
||||
|
@ -305,7 +304,6 @@ TEST_F(TestCostGraph, test_ConstructConnectedComponents) {
|
|||
|
||||
TEST_F(TestCostGraph, test_SelectCostListWithMinTrainingTimeMultiple) {
|
||||
CostGraph entire_cost_graph;
|
||||
entire_cost_graph.SetDeviceMemoryAndCostParameter();
|
||||
double memory = 1024.0;
|
||||
CostPtrList clist_1, clist_2;
|
||||
std::vector<CostPtrList> all_list;
|
||||
|
@ -371,7 +369,8 @@ TEST_F(TestCostGraph, test_CreateFinalCostList_AND_Select) {
|
|||
ASSERT_EQ(edge_m1_m2->InitEdgeCost(), SUCCESS);
|
||||
cost_graph.AddEdge(matmul1, matmul2, edge_m1_m2);
|
||||
auto cost_list = cost_graph.CreateFinalCostList(matmul1, edge_m1_m2, matmul2);
|
||||
cost_graph.SelectCostWithMinInferenceTime(cost_list, cost_graph.GetDeviceMemory());
|
||||
const auto device_mem_capacity = CostModelContext::GetInstance()->device_memory_capacity();
|
||||
cost_graph.SelectCostWithMinInferenceTime(cost_list, device_mem_capacity);
|
||||
}
|
||||
|
||||
TEST_F(TestCostGraph, test_EliminationOp) {
|
||||
|
|
Loading…
Reference in New Issue