forked from mindspore-Ecosystem/mindspore
calculating PEAK memory cost in the inference phase
This commit is contained in:
parent
552fc5c958
commit
a05aa21cc2
|
@ -53,7 +53,7 @@ struct Cost {
|
||||||
communication_redis_backward_ = 0.0;
|
communication_redis_backward_ = 0.0;
|
||||||
communication_forward_ = 0.0;
|
communication_forward_ = 0.0;
|
||||||
}
|
}
|
||||||
// 'memory_with_reuse_' calculates the peak memory usage in a training phase
|
// 'memory_with_reuse_' calculates the peak memory usage in a training (or inference) phase
|
||||||
double memory_with_reuse_;
|
double memory_with_reuse_;
|
||||||
// 'computation_cost_' models the training time of an iteration in a training phase. Currently, this is calculated
|
// 'computation_cost_' models the training time of an iteration in a training phase. Currently, this is calculated
|
||||||
// by ONLY forward phase
|
// by ONLY forward phase
|
||||||
|
|
|
@ -300,5 +300,20 @@ Status Edge::CalculateMemoryCost() {
|
||||||
|
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status Edge::CalculateMemoryCostForInference() {
|
||||||
|
// Currently, memory cost is NOT calculated for redistribution
|
||||||
|
if ((is_output_critical_ != 0) && (is_output_critical_ != 1)) {
|
||||||
|
MS_LOG(ERROR) << "Failure: unexpected output critical flag value: " << is_output_critical_;
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
for (auto &cost_kv : cost_map_) {
|
||||||
|
auto &cost_v = cost_kv.second;
|
||||||
|
if (!cost_v.empty()) {
|
||||||
|
cost_v[0]->memory_with_reuse_ = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -131,9 +131,13 @@ class Edge {
|
||||||
void set_selected_cost(const CostPtr &cost) { selected_cost_ = cost; }
|
void set_selected_cost(const CostPtr &cost) { selected_cost_ = cost; }
|
||||||
const CostPtr &selected_cost() const { return selected_cost_; }
|
const CostPtr &selected_cost() const { return selected_cost_; }
|
||||||
void set_parameter_involve(int para_invol) { is_output_parameter_involve_ = para_invol; }
|
void set_parameter_involve(int para_invol) { is_output_parameter_involve_ = para_invol; }
|
||||||
// When the input of a operator contains WEIGHT or a output from other operators involving WEIGHT, then these input
|
// In the training phase, when the input of a operator contains WEIGHT or a output from other operators involving
|
||||||
// should stay in memory until it is used in the backward phase, which is kept in memory at the end of forward phase.
|
// WEIGHT, then these input should stay in memory until it is used in the backward phase, which is kept in memory
|
||||||
|
// at the end of forward phase.
|
||||||
Status CalculateMemoryCost();
|
Status CalculateMemoryCost();
|
||||||
|
// In the inference phase,
|
||||||
|
Status CalculateMemoryCostForInference();
|
||||||
|
void mark_output_critical() { is_output_critical_ = 1; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::string edge_name_;
|
std::string edge_name_;
|
||||||
|
@ -156,7 +160,11 @@ class Edge {
|
||||||
// If it is true, then we should guarantee that the strategy for output tensor consistent with the input tensor.
|
// If it is true, then we should guarantee that the strategy for output tensor consistent with the input tensor.
|
||||||
bool is_identity_edge;
|
bool is_identity_edge;
|
||||||
CostPtr selected_cost_;
|
CostPtr selected_cost_;
|
||||||
|
// In the training phase, 'is_output_parameter_involve_' is used to mark whether the output of the previous operator
|
||||||
|
// is parameter-involved
|
||||||
int is_output_parameter_involve_ = -1; // -1: unset; 0: not parameter_involved; 1: parameter_involved
|
int is_output_parameter_involve_ = -1; // -1: unset; 0: not parameter_involved; 1: parameter_involved
|
||||||
|
// In the inference phase, this is used to mark whether the output of the previous operator is critical.
|
||||||
|
int is_output_critical_ = 0;
|
||||||
};
|
};
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -369,7 +369,7 @@ CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list,
|
||||||
<< ", communication_with_partial_para_: " << ret->communication_with_partial_para_
|
<< ", communication_with_partial_para_: " << ret->communication_with_partial_para_
|
||||||
<< ", communication_cost_: " << ret->communication_cost_
|
<< ", communication_cost_: " << ret->communication_cost_
|
||||||
<< ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
|
<< ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
|
||||||
MS_LOG(INFO) << "Cost 0: totoal_cost: " << minimum;
|
MS_LOG(INFO) << "Cost 0: total_cost: " << minimum;
|
||||||
for (size_t i = 1; i < after_mem_filter.size(); ++i) {
|
for (size_t i = 1; i < after_mem_filter.size(); ++i) {
|
||||||
MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
|
MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
|
||||||
MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
|
MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
|
||||||
|
@ -422,7 +422,7 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, d
|
||||||
<< ", communication_with_partial_para_: " << ret->communication_with_partial_para_
|
<< ", communication_with_partial_para_: " << ret->communication_with_partial_para_
|
||||||
<< ", communication_cost_: " << ret->communication_cost_
|
<< ", communication_cost_: " << ret->communication_cost_
|
||||||
<< ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
|
<< ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
|
||||||
MS_LOG(INFO) << "Cost 0: totoal_cost: " << minimum;
|
MS_LOG(INFO) << "Cost 0: total_cost: " << minimum;
|
||||||
for (size_t i = 1; i < after_mem_filter.size(); ++i) {
|
for (size_t i = 1; i < after_mem_filter.size(); ++i) {
|
||||||
MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
|
MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
|
||||||
MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
|
MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
|
||||||
|
@ -1351,6 +1351,14 @@ std::vector<std::shared_ptr<Edge>> CostGraph::EliminationStar(const OperatorInfo
|
||||||
return succ_edges;
|
return succ_edges;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t CostGraph::GetNumEdges() const {
|
||||||
|
size_t sum = 0;
|
||||||
|
for (const auto &kv : edges_) {
|
||||||
|
auto &edges = kv.second;
|
||||||
|
sum += edges.size();
|
||||||
|
}
|
||||||
|
return sum;
|
||||||
|
}
|
||||||
Status CostGraph::InitSelectedStrategy() {
|
Status CostGraph::InitSelectedStrategy() {
|
||||||
for (auto &op : ops_) {
|
for (auto &op : ops_) {
|
||||||
MS_EXCEPTION_IF_NULL(op);
|
MS_EXCEPTION_IF_NULL(op);
|
||||||
|
@ -1416,6 +1424,122 @@ Status CostGraph::ComputeOpsAndEdgesParameterInvolved() {
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void CostGraph::DFSForTopoOrder(const OperatorInfoPtr ¤t_op, std::map<OperatorInfoPtr, bool> *visited,
|
||||||
|
std::vector<OperatorInfoPtr> *topo_order) {
|
||||||
|
MS_EXCEPTION_IF_NULL(current_op);
|
||||||
|
MS_EXCEPTION_IF_NULL(visited);
|
||||||
|
MS_EXCEPTION_IF_NULL(topo_order);
|
||||||
|
|
||||||
|
visited->at(current_op) = true;
|
||||||
|
for (const auto &s_edge : current_op->succ_edges()) {
|
||||||
|
if (!visited->at(s_edge->next_operator())) {
|
||||||
|
DFSForTopoOrder(s_edge->next_operator(), visited, topo_order);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
topo_order->push_back(current_op);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute a topological order of the costgraph
|
||||||
|
void CostGraph::TopologyOrder(std::vector<OperatorInfoPtr> *topo_order) {
|
||||||
|
std::map<OperatorInfoPtr, bool> visited;
|
||||||
|
for (auto &op : ops_) {
|
||||||
|
visited[op] = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto &op : ops_) {
|
||||||
|
if (!visited[op]) {
|
||||||
|
DFSForTopoOrder(op, &visited, topo_order);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
void CostGraph::MarkCriticalOpsAndEdges(const std::map<OperatorInfoPtr, int> &candidate_ops) {
|
||||||
|
for (auto &op : ops_) {
|
||||||
|
auto search = candidate_ops.find(op);
|
||||||
|
if (search != candidate_ops.end()) {
|
||||||
|
// Mark the critical operators
|
||||||
|
op->mark_output_critical();
|
||||||
|
// Mark the successive edges
|
||||||
|
for (auto &s_edge : op->succ_edges()) {
|
||||||
|
s_edge->mark_output_critical();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
op->mark_output_not_critical();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Status CostGraph::DetermineCriticalOps(const std::vector<OperatorInfoPtr> &topo_order) {
|
||||||
|
if (topo_order.size() == 0) {
|
||||||
|
MS_LOG(ERROR) << "0 operator in costgraph.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
auto &first_op = topo_order[0];
|
||||||
|
if (first_op->prev_edges().size() > 0) {
|
||||||
|
MS_LOG(ERROR) << "The first operator in the first of topological order of "
|
||||||
|
"costgraph should have 0 incoming edge, but has "
|
||||||
|
<< first_op->prev_edges() << "edges.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
// The 'curr_memory_state' records <OperatorInfo, remaining_output_cnt>, where remaining_output_cnt is the number
|
||||||
|
// of the output of OperatorInfo that currently has not been used
|
||||||
|
std::map<OperatorInfoPtr, int> curr_memory_state;
|
||||||
|
(void)curr_memory_state.emplace(std::make_pair(first_op, SizeToInt(first_op->succ_edges().size())));
|
||||||
|
std::map<OperatorInfoPtr, int> max_memory_state = curr_memory_state;
|
||||||
|
// The 'curr_memory_size' records the current total memory size, which is the sum of outputs of operators that has
|
||||||
|
// not been used
|
||||||
|
double curr_memory_size = first_op->GetOutputsTotalSize();
|
||||||
|
double max_memory_size = curr_memory_size;
|
||||||
|
|
||||||
|
for (size_t finished = 1; finished < topo_order.size(); ++finished) {
|
||||||
|
// Produce
|
||||||
|
(void)curr_memory_state.emplace(
|
||||||
|
std::make_pair(topo_order[finished], SizeToInt(topo_order[finished]->succ_edges().size())));
|
||||||
|
curr_memory_size += topo_order[finished]->GetOutputsTotalSize();
|
||||||
|
// Consume
|
||||||
|
for (const auto &prev_edge : topo_order[finished]->prev_edges()) {
|
||||||
|
const auto &prev_op = prev_edge->prev_operator();
|
||||||
|
curr_memory_state[prev_op]--;
|
||||||
|
}
|
||||||
|
for (const auto &prev_edge : topo_order[finished]->prev_edges()) {
|
||||||
|
const auto &prev_op = prev_edge->prev_operator();
|
||||||
|
if (curr_memory_state[prev_op] < 0) {
|
||||||
|
MS_LOG(ERROR) << "Failure: " << prev_op->name() << "'s current output count: " << curr_memory_state[prev_op];
|
||||||
|
return FAILED;
|
||||||
|
} else if (curr_memory_state[prev_op] == 0) {
|
||||||
|
curr_memory_state.erase(prev_op);
|
||||||
|
curr_memory_size -= prev_op->GetOutputsTotalSize();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (curr_memory_size < 0) {
|
||||||
|
MS_LOG(ERROR) << "Memory size calculation failed: " << curr_memory_size;
|
||||||
|
}
|
||||||
|
// Modify the max
|
||||||
|
if (curr_memory_size > max_memory_size) {
|
||||||
|
max_memory_size = curr_memory_size;
|
||||||
|
max_memory_state = curr_memory_state;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Mark those critical operators
|
||||||
|
MarkCriticalOpsAndEdges(max_memory_state);
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status CostGraph::ComputeOpsAndEdgesOutputCritical() {
|
||||||
|
// Two steps to do:
|
||||||
|
// 1. Compute a topological order of the costgraph
|
||||||
|
// 2. Determine and mark the operators (and necessary edges) that are critical
|
||||||
|
std::vector<OperatorInfoPtr> topo_order;
|
||||||
|
TopologyOrder(&topo_order);
|
||||||
|
std::reverse(std::begin(topo_order), std::end(topo_order));
|
||||||
|
|
||||||
|
if (DetermineCriticalOps(topo_order) != SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << "Determining critical operators failed.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
Status CostGraph::CalculateOpsMemoryCost() {
|
Status CostGraph::CalculateOpsMemoryCost() {
|
||||||
for (auto &op : ops_) {
|
for (auto &op : ops_) {
|
||||||
MS_EXCEPTION_IF_NULL(op);
|
MS_EXCEPTION_IF_NULL(op);
|
||||||
|
@ -1427,6 +1551,17 @@ Status CostGraph::CalculateOpsMemoryCost() {
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status CostGraph::CalculateOpsMemoryCostForInference() {
|
||||||
|
for (auto &op : ops_) {
|
||||||
|
MS_EXCEPTION_IF_NULL(op);
|
||||||
|
if (op->CalculateMemoryCostForInference() != SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << "Calculate Operator: " << op->name() << " cost for memory usage failed.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
Status CostGraph::CalculateEdgesMemoryCost() {
|
Status CostGraph::CalculateEdgesMemoryCost() {
|
||||||
for (auto &edge_pair : edges_) {
|
for (auto &edge_pair : edges_) {
|
||||||
const auto &edges = edge_pair.second;
|
const auto &edges = edge_pair.second;
|
||||||
|
@ -1440,6 +1575,19 @@ Status CostGraph::CalculateEdgesMemoryCost() {
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status CostGraph::CalculateEdgesMemoryCostForInference() {
|
||||||
|
for (auto &edge_pair : edges_) {
|
||||||
|
const auto &edges = edge_pair.second;
|
||||||
|
for (auto &one_edge : edges) {
|
||||||
|
if (one_edge->CalculateMemoryCostForInference() != SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << "Calculate Edge: " << one_edge->edge_name() << " cost for memory usage failed.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(std::string &p_name) const {
|
OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(std::string &p_name) const {
|
||||||
for (auto one_op : ops_) {
|
for (auto one_op : ops_) {
|
||||||
if (one_op->name().find(IDENTITY_INFO) != std::string::npos) {
|
if (one_op->name().find(IDENTITY_INFO) != std::string::npos) {
|
||||||
|
@ -1480,5 +1628,49 @@ Status CostGraph::CorrectOpsMemoryCost() {
|
||||||
}
|
}
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status CostGraph::CalculateMemoryCost() {
|
||||||
|
if (RUN_PHASE == TRAINING_PHASE) {
|
||||||
|
// training phase
|
||||||
|
if (ComputeOpsAndEdgesParameterInvolved() == SUCCESS) {
|
||||||
|
// Calculate operators' memory usage
|
||||||
|
if (CalculateOpsMemoryCost() != SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << "Calculating operators' cost for memory cost failed.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
// Calculate edges' memory usage
|
||||||
|
if (CalculateEdgesMemoryCost() != SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << "Calculating edges' cost for memory cost failed.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
// Correct memory usage caused by TmpIdentity
|
||||||
|
if (CorrectOpsMemoryCost() != SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << "Correcting operators' cost for memory cost failed.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << "Computing operators' parameter_involved failed.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// inference phase
|
||||||
|
if (ComputeOpsAndEdgesOutputCritical() == SUCCESS) {
|
||||||
|
// Calculate operators' memory usage
|
||||||
|
if (CalculateOpsMemoryCostForInference() != SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << "Calculating operators' memory cost for inference failed.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
// Calculate edges's memory usage
|
||||||
|
if (CalculateEdgesMemoryCostForInference() != SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << "Calculating operators' memory cost for inference failed.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << "Computing operators' critical flag failed.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -179,16 +179,24 @@ class CostGraph {
|
||||||
void CreateStarEliminationSubCostList(const StrategyPtr &, const CostPtrList &, const CostPtrList &,
|
void CreateStarEliminationSubCostList(const StrategyPtr &, const CostPtrList &, const CostPtrList &,
|
||||||
const StrategyPtr &, const CostPtrList &, std::vector<StrategyPtr>,
|
const StrategyPtr &, const CostPtrList &, std::vector<StrategyPtr>,
|
||||||
CostPtrList &, CostPtrList &, CostPtrList *);
|
CostPtrList &, CostPtrList &, CostPtrList *);
|
||||||
|
// Calculate memory cost for training phase or inference phase.
|
||||||
|
Status CalculateMemoryCost();
|
||||||
// When the input of a operator is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then
|
// When the input of a operator is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then
|
||||||
// the memory cost can be resused.
|
// the memory cost can be resused. This is used to calculate memory in the training phase.
|
||||||
Status CalculateOpsMemoryCost();
|
Status CalculateOpsMemoryCost();
|
||||||
// When the input of the edge is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then
|
// When the input of the edge is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then
|
||||||
// the memory cost can be resused.
|
// the memory cost can be reused. This is used to calculate memory in the training phase.
|
||||||
Status CalculateEdgesMemoryCost();
|
Status CalculateEdgesMemoryCost();
|
||||||
|
// Calculate memory cost of operators in the inference phase.
|
||||||
|
Status CalculateOpsMemoryCostForInference();
|
||||||
|
// Calculate memory cost of edges in the inference phase.
|
||||||
|
Status CalculateEdgesMemoryCostForInference();
|
||||||
Status ComputeOpsAndEdgesParameterInvolved();
|
Status ComputeOpsAndEdgesParameterInvolved();
|
||||||
|
// Compute for each operator whether the output is critical.
|
||||||
|
Status ComputeOpsAndEdgesOutputCritical();
|
||||||
|
|
||||||
std::vector<OperatorInfoPtr> GetOperators() const { return ops_; }
|
std::vector<OperatorInfoPtr> GetOperators() const { return ops_; }
|
||||||
size_t GetNumPairs() const { return edges_.size(); }
|
size_t GetNumEdges() const;
|
||||||
Status InitSelectedStrategy();
|
Status InitSelectedStrategy();
|
||||||
OperatorInfoPtr FindTmpIdentityByParameterName(std::string &) const;
|
OperatorInfoPtr FindTmpIdentityByParameterName(std::string &) const;
|
||||||
// When TmpIdentity is used by mulitple operators, the corresponding parameter's memory cost should be calculated only
|
// When TmpIdentity is used by mulitple operators, the corresponding parameter's memory cost should be calculated only
|
||||||
|
@ -208,6 +216,10 @@ class CostGraph {
|
||||||
const std::map<std::string, std::string> get_tuple_getitem_list() const { return tuple_getitem_list_; }
|
const std::map<std::string, std::string> get_tuple_getitem_list() const { return tuple_getitem_list_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
void TopologyOrder(std::vector<OperatorInfoPtr> *);
|
||||||
|
void DFSForTopoOrder(const OperatorInfoPtr &, std::map<OperatorInfoPtr, bool> *, std::vector<OperatorInfoPtr> *);
|
||||||
|
Status DetermineCriticalOps(const std::vector<OperatorInfoPtr> &);
|
||||||
|
void MarkCriticalOpsAndEdges(const std::map<OperatorInfoPtr, int> &);
|
||||||
// Needed by rec_parser
|
// Needed by rec_parser
|
||||||
std::vector<std::vector<std::string>> inputs_tensor_name_list_;
|
std::vector<std::vector<std::string>> inputs_tensor_name_list_;
|
||||||
std::map<std::string, std::string> tuple_getitem_list_;
|
std::map<std::string, std::string> tuple_getitem_list_;
|
||||||
|
|
|
@ -37,6 +37,8 @@ void OperatorCost::SetInputAndOutputTypeLength(const std::vector<size_t> &input_
|
||||||
outputs_type_lengths_ = output_lengths;
|
outputs_type_lengths_ = output_lengths;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void OperatorCost::set_output_critical(int critical) { is_outputs_critical_ = critical; }
|
||||||
|
|
||||||
double OperatorCost::GetMemoryCost(const std::vector<TensorInfo> &inputs,
|
double OperatorCost::GetMemoryCost(const std::vector<TensorInfo> &inputs,
|
||||||
const std::vector<TensorInfo> &outputs) const {
|
const std::vector<TensorInfo> &outputs) const {
|
||||||
double result = 0.0;
|
double result = 0.0;
|
||||||
|
@ -63,6 +65,20 @@ double OperatorCost::GetMemoryCost(const std::vector<TensorInfo> &inputs,
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
double OperatorCost::GetMemoryCostForInference(const std::vector<TensorInfo> &,
|
||||||
|
const std::vector<TensorInfo> &outputs) const {
|
||||||
|
double result = 0.0;
|
||||||
|
if (is_outputs_critical_ == -1) {
|
||||||
|
MS_LOG(EXCEPTION) << "The critical flag is not set.";
|
||||||
|
}
|
||||||
|
if (is_outputs_critical_ == 1) {
|
||||||
|
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||||
|
result += ListProduct(outputs[i].slice_shape()) * static_cast<double>(outputs_type_lengths_[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
// return the per device communication cost in the forward phase.
|
// return the per device communication cost in the forward phase.
|
||||||
double MatMulCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
|
double MatMulCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
|
||||||
int32_t) const {
|
int32_t) const {
|
||||||
|
|
|
@ -70,6 +70,7 @@ class OperatorCost {
|
||||||
void set_is_parameter(const std::vector<bool> &is_parameter);
|
void set_is_parameter(const std::vector<bool> &is_parameter);
|
||||||
void set_is_parameter_involve(const std::vector<bool> &);
|
void set_is_parameter_involve(const std::vector<bool> &);
|
||||||
void set_output_parameter_involve(int);
|
void set_output_parameter_involve(int);
|
||||||
|
void set_output_critical(int);
|
||||||
void SetInputAndOutputTypeLength(const std::vector<size_t> &input_lengths, const std::vector<size_t> &output_lengths);
|
void SetInputAndOutputTypeLength(const std::vector<size_t> &input_lengths, const std::vector<size_t> &output_lengths);
|
||||||
std::vector<size_t> inputs_type_lengths() const { return inputs_type_lengths_; }
|
std::vector<size_t> inputs_type_lengths() const { return inputs_type_lengths_; }
|
||||||
std::vector<size_t> outputs_type_lengths() const { return outputs_type_lengths_; }
|
std::vector<size_t> outputs_type_lengths() const { return outputs_type_lengths_; }
|
||||||
|
@ -92,6 +93,8 @@ class OperatorCost {
|
||||||
// Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-invovled),
|
// Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-invovled),
|
||||||
// plus necessary inputs.
|
// plus necessary inputs.
|
||||||
virtual double GetMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
|
virtual double GetMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
|
||||||
|
// per device memory cost in a inference phase
|
||||||
|
double GetMemoryCostForInference(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &) const;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of
|
// For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of
|
||||||
|
@ -106,6 +109,9 @@ class OperatorCost {
|
||||||
// for each input and output, the followings record the number of bytes of each element
|
// for each input and output, the followings record the number of bytes of each element
|
||||||
std::vector<size_t> inputs_type_lengths_;
|
std::vector<size_t> inputs_type_lengths_;
|
||||||
std::vector<size_t> outputs_type_lengths_;
|
std::vector<size_t> outputs_type_lengths_;
|
||||||
|
// Whether the output is critical, which means that this output is included in calculating peak memory cost
|
||||||
|
// in the inference phase.
|
||||||
|
int is_outputs_critical_ = -1;
|
||||||
};
|
};
|
||||||
|
|
||||||
using OperatorCostPtr = std::shared_ptr<OperatorCost>;
|
using OperatorCostPtr = std::shared_ptr<OperatorCost>;
|
||||||
|
|
|
@ -1119,6 +1119,21 @@ Status OperatorInfo::CalculateMemoryCost() {
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status OperatorInfo::CalculateMemoryCostForInference() {
|
||||||
|
// First, set the 'is_outputs_critical_' flag into OperatorCost.
|
||||||
|
if (is_output_critical_ == -1) {
|
||||||
|
MS_LOG(EXCEPTION) << "The critical flag is not set.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
operator_cost()->set_output_critical(is_output_critical_);
|
||||||
|
// Set the memory cost in the 'strategy_cost_'
|
||||||
|
for (auto &swc : strategy_cost_) {
|
||||||
|
auto mem_cost = operator_cost()->GetMemoryCostForInference(swc->inputs_ptr, swc->outputs_ptr);
|
||||||
|
swc->cost_list[0]->memory_with_reuse_ = mem_cost;
|
||||||
|
}
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
Status OperatorInfo::CorrectMemoryCost(size_t input_index) {
|
Status OperatorInfo::CorrectMemoryCost(size_t input_index) {
|
||||||
for (auto &swc : strategy_cost_) {
|
for (auto &swc : strategy_cost_) {
|
||||||
double parameter_mem_cost = ListProduct(swc->inputs_ptr[input_index].slice_shape()) *
|
double parameter_mem_cost = ListProduct(swc->inputs_ptr[input_index].slice_shape()) *
|
||||||
|
@ -1230,6 +1245,25 @@ Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector<size_t> &inpu
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
double OperatorInfo::GetOutputsTotalSize() {
|
||||||
|
if (is_calculated_outputs_size_) {
|
||||||
|
return outputs_total_size_;
|
||||||
|
}
|
||||||
|
if (outputs_type_lengths_.size() != outputs_shape_.size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Output_lengths: " << outputs_type_lengths_.size()
|
||||||
|
<< " do not have the same number of outputs shape: " << outputs_shape_.size();
|
||||||
|
}
|
||||||
|
double sum = 0.0;
|
||||||
|
for (size_t i = 0; i < outputs_type_lengths_.size(); ++i) {
|
||||||
|
auto size = std::accumulate(outputs_shape_[i].begin(), outputs_shape_[i].end(), static_cast<double>(1.0),
|
||||||
|
std::multiplies<double>());
|
||||||
|
sum += size * static_cast<double>(outputs_type_lengths_[i]);
|
||||||
|
}
|
||||||
|
is_calculated_outputs_size_ = true;
|
||||||
|
outputs_total_size_ = sum;
|
||||||
|
return outputs_total_size_;
|
||||||
|
}
|
||||||
|
|
||||||
Status OperatorInfo::set_outputs_type(const std::vector<TypePtr> &outputs_type) {
|
Status OperatorInfo::set_outputs_type(const std::vector<TypePtr> &outputs_type) {
|
||||||
if (outputs_type.size() != outputs_shape_.size()) {
|
if (outputs_type.size() != outputs_shape_.size()) {
|
||||||
MS_LOG(ERROR) << "Outputs type: " << outputs_type.size()
|
MS_LOG(ERROR) << "Outputs type: " << outputs_type.size()
|
||||||
|
|
|
@ -72,6 +72,7 @@ class OperatorInfo {
|
||||||
Status set_is_parameter(const std::vector<bool> &is_parameter);
|
Status set_is_parameter(const std::vector<bool> &is_parameter);
|
||||||
Status SetInputAndOutputTypeLength(const std::vector<size_t> &input_lengths,
|
Status SetInputAndOutputTypeLength(const std::vector<size_t> &input_lengths,
|
||||||
const std::vector<size_t> &output_lengths);
|
const std::vector<size_t> &output_lengths);
|
||||||
|
double GetOutputsTotalSize();
|
||||||
// Set outputs dtype.
|
// Set outputs dtype.
|
||||||
// If only one output, outputs_type.size() is 1.
|
// If only one output, outputs_type.size() is 1.
|
||||||
// If output is tuple, outputs_type.size() is greater than 1.
|
// If output is tuple, outputs_type.size() is greater than 1.
|
||||||
|
@ -96,9 +97,13 @@ class OperatorInfo {
|
||||||
// is checked
|
// is checked
|
||||||
Status SetCostUnderStrategyBase(const StrategyPtr &strategy);
|
Status SetCostUnderStrategyBase(const StrategyPtr &strategy);
|
||||||
std::vector<std::shared_ptr<StrategyWithCost>> GetStrategyCost() { return strategy_cost_; }
|
std::vector<std::shared_ptr<StrategyWithCost>> GetStrategyCost() { return strategy_cost_; }
|
||||||
// When the input of a operator contains WEIGHT or a output from other operators involving WEIGHT, then these input
|
// In the training phase, when the input of a operator contains WEIGHT or a output from other operators involving
|
||||||
// should stay in memory until it is used in the backward phase, which is kept in memory at the end of forward phase.
|
// WEIGHT, then these input should stay in memory until it is used in the backward phase, which is kept in memory
|
||||||
|
// at the end of forward phase.
|
||||||
Status CalculateMemoryCost();
|
Status CalculateMemoryCost();
|
||||||
|
// In the inference phase, the memory cost is incurred only when the operator is critical. The size is calculated
|
||||||
|
// by the output
|
||||||
|
Status CalculateMemoryCostForInference();
|
||||||
int ComputeOpAndPrevEdgeParameterInvolved();
|
int ComputeOpAndPrevEdgeParameterInvolved();
|
||||||
|
|
||||||
ForwardOp forward_op() const { return forward_op_; }
|
ForwardOp forward_op() const { return forward_op_; }
|
||||||
|
@ -147,6 +152,9 @@ class OperatorInfo {
|
||||||
// multiple times. This method is to correct this, and makes the cost is calulated only once.
|
// multiple times. This method is to correct this, and makes the cost is calulated only once.
|
||||||
Status CorrectMemoryCost(size_t input_index);
|
Status CorrectMemoryCost(size_t input_index);
|
||||||
int is_output_parameter_involve() const { return is_output_parameter_involve_; }
|
int is_output_parameter_involve() const { return is_output_parameter_involve_; }
|
||||||
|
int is_output_critical() const { return is_output_critical_; }
|
||||||
|
void mark_output_critical() { is_output_critical_ = 1; }
|
||||||
|
void mark_output_not_critical() { is_output_critical_ = 0; }
|
||||||
int used_devices() const { return used_devices_; }
|
int used_devices() const { return used_devices_; }
|
||||||
// needed by rec_parser
|
// needed by rec_parser
|
||||||
void set_type(const std::string &type) { type_ = type; }
|
void set_type(const std::string &type) { type_ = type; }
|
||||||
|
@ -220,7 +228,16 @@ class OperatorInfo {
|
||||||
// For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of
|
// For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of
|
||||||
// pre-operator that has parameters as input.
|
// pre-operator that has parameters as input.
|
||||||
std::vector<bool> is_parameter_involve_;
|
std::vector<bool> is_parameter_involve_;
|
||||||
int is_output_parameter_involve_ = -1; // -1: unset; 0: not parameter_involved; 1: parameter_involved
|
// If any input is parameter-involved, the output is parameter-involved. This variable is used in calculating
|
||||||
|
// peak memory cost in the training phase.
|
||||||
|
// -1: unset; 0: not parameter_involved; 1: parameter_involved
|
||||||
|
int is_output_parameter_involve_ = -1;
|
||||||
|
// Whether this output is critical, which means that this output is included in calculating peak memory cost
|
||||||
|
// in the inference phase.
|
||||||
|
// -1 : unset; 0: not critical; 1: critical
|
||||||
|
int is_output_critical_ = -1;
|
||||||
|
double outputs_total_size_ = 0.0;
|
||||||
|
bool is_calculated_outputs_size_ = false;
|
||||||
// for each input and output, the followings record the number of bytes of each element
|
// for each input and output, the followings record the number of bytes of each element
|
||||||
std::vector<size_t> inputs_type_lengths_;
|
std::vector<size_t> inputs_type_lengths_;
|
||||||
std::vector<size_t> outputs_type_lengths_;
|
std::vector<size_t> outputs_type_lengths_;
|
||||||
|
|
|
@ -1055,6 +1055,9 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
|
||||||
// Step 1: Traverse the ANF graph, and create NODEs for costgraph:
|
// Step 1: Traverse the ANF graph, and create NODEs for costgraph:
|
||||||
// create the OperatorInfo object for each primitive, and enumerate the parallelization strategies
|
// create the OperatorInfo object for each primitive, and enumerate the parallelization strategies
|
||||||
// for each OperatorInfo;
|
// for each OperatorInfo;
|
||||||
|
// Step 1.1: Deal with 'Reshape':
|
||||||
|
// For 'Reshape', it takes its previous operator's layout as its input layout, and takes its next operator's
|
||||||
|
// layout as its output layout.
|
||||||
// Step 2: Traverse the ANF graph, and create EDGES for costgraph:
|
// Step 2: Traverse the ANF graph, and create EDGES for costgraph:
|
||||||
// create the Edge object for each pair of OperatorInfo, and enumerate the parallelization strategies
|
// create the Edge object for each pair of OperatorInfo, and enumerate the parallelization strategies
|
||||||
// for each edge, based on the strategies of two OperatorInfos;
|
// for each edge, based on the strategies of two OperatorInfos;
|
||||||
|
@ -1062,7 +1065,8 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
|
||||||
// taking care for the case of a single Parameter being used by multiple operators. Create a TmpIdentity
|
// taking care for the case of a single Parameter being used by multiple operators. Create a TmpIdentity
|
||||||
// operator for this Parameter, and add an edge for the use of this Parameter by each
|
// operator for this Parameter, and add an edge for the use of this Parameter by each
|
||||||
// subsequent operator;
|
// subsequent operator;
|
||||||
// Step 3.1: Calculate memory usage
|
// Step 3.1: Calculate memory usage:
|
||||||
|
// note the memory usage calculation is different in training phase and inference phase.
|
||||||
// Step 4: Run the Dynamic Programming algorithm:
|
// Step 4: Run the Dynamic Programming algorithm:
|
||||||
// in this process, cost is calculated based on not only the operators, but also the edges. Here, the edge
|
// in this process, cost is calculated based on not only the operators, but also the edges. Here, the edge
|
||||||
// cost is caused by the redistribution of a operator's output tensor layout to the next operator's input
|
// cost is caused by the redistribution of a operator's output tensor layout to the next operator's input
|
||||||
|
@ -1087,35 +1091,21 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
|
||||||
MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
|
MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// reshape operator needs the next node's input_layout as its output_layout.
|
// Step 1.1
|
||||||
// and needs the previous node's output_layout as its input_layout.
|
|
||||||
ReshapeCostCompute(all_nodes);
|
ReshapeCostCompute(all_nodes);
|
||||||
// Step 2
|
// Step 2
|
||||||
ConstructCostGraphEdges(all_nodes);
|
ConstructCostGraphEdges(all_nodes);
|
||||||
MS_LOG(INFO) << "Constructing edges for cost graph succeeded. There are " << entire_costgraph->GetOperators().size()
|
MS_LOG(INFO) << "Constructing edges for cost graph succeeded. There are " << entire_costgraph->GetOperators().size()
|
||||||
<< " operators, and " << entire_costgraph->GetNumPairs() << " edges.",
|
<< " operators, and " << entire_costgraph->GetNumEdges() << " edges.";
|
||||||
|
|
||||||
// Step 3: Augment the costgraph.
|
// Step 3: Augment the costgraph.
|
||||||
AugmentCostGraph(all_nodes);
|
AugmentCostGraph(all_nodes);
|
||||||
MS_LOG(INFO) << "After the augmenting procedure, there are " << entire_costgraph->GetOperators().size()
|
MS_LOG(INFO) << "After the augmenting procedure, there are " << entire_costgraph->GetOperators().size()
|
||||||
<< " operators, and " << entire_costgraph->GetNumPairs() << " edges.";
|
<< " operators, and " << entire_costgraph->GetNumEdges() << " edges.";
|
||||||
|
|
||||||
// Step 3.1: Calculate the memory usage
|
// Step 3.1: Calculate the memory usage
|
||||||
if (entire_costgraph->ComputeOpsAndEdgesParameterInvolved() == SUCCESS) {
|
if (entire_costgraph->CalculateMemoryCost() != SUCCESS) {
|
||||||
// Calculate operators' memory usage
|
MS_LOG(EXCEPTION) << "Calculating memory cost failed.";
|
||||||
if (entire_costgraph->CalculateOpsMemoryCost() != SUCCESS) {
|
|
||||||
MS_LOG(EXCEPTION) << "Calculating operators' cost for memory cost failed.";
|
|
||||||
}
|
|
||||||
// Calculate edges' memory usage
|
|
||||||
if (entire_costgraph->CalculateEdgesMemoryCost() != SUCCESS) {
|
|
||||||
MS_LOG(EXCEPTION) << "Calculating edges' cost for memory cost failed.";
|
|
||||||
}
|
|
||||||
// Correct memory usage caused by TmpIdentity
|
|
||||||
if (entire_costgraph->CorrectOpsMemoryCost() != SUCCESS) {
|
|
||||||
MS_LOG(EXCEPTION) << "Correcting operators' cost for memory cost failed.";
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
MS_LOG(EXCEPTION) << "Computing operators' parameter_involved failed.";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Step 4: run DP algorithm on the costgraph.
|
// Step 4: run DP algorithm on the costgraph.
|
||||||
|
|
|
@ -32,5 +32,6 @@ def test_inference_phase():
|
||||||
net_with_loss = WithLossCell(net, loss)
|
net_with_loss = WithLossCell(net, loss)
|
||||||
train_network = TrainOneStepCell(net_with_loss, optimizer)
|
train_network = TrainOneStepCell(net_with_loss, optimizer)
|
||||||
train_network.set_train()
|
train_network.set_train()
|
||||||
|
train_network.set_auto_parallel()
|
||||||
|
|
||||||
output = train_network(predict, label)
|
output = train_network(predict, label)
|
Loading…
Reference in New Issue