diff --git a/mindspore/ccsrc/parallel/auto_parallel/costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/costmodel.h index adae0688d61..8f0a79d0be2 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/costmodel.h @@ -53,7 +53,7 @@ struct Cost { communication_redis_backward_ = 0.0; communication_forward_ = 0.0; } - // 'memory_with_reuse_' calculates the peak memory usage in a training phase + // 'memory_with_reuse_' calculates the peak memory usage in a training (or inference) phase double memory_with_reuse_; // 'computation_cost_' models the training time of an iteration in a training phase. Currently, this is calculated // by ONLY forward phase diff --git a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc index 9189689f525..16393eb1736 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc @@ -300,5 +300,20 @@ Status Edge::CalculateMemoryCost() { return SUCCESS; } + +Status Edge::CalculateMemoryCostForInference() { + // Currently, memory cost is NOT calculated for redistribution + if ((is_output_critical_ != 0) && (is_output_critical_ != 1)) { + MS_LOG(ERROR) << "Failure: unexpected output critical flag value: " << is_output_critical_; + return FAILED; + } + for (auto &cost_kv : cost_map_) { + auto &cost_v = cost_kv.second; + if (!cost_v.empty()) { + cost_v[0]->memory_with_reuse_ = 0; + } + } + return SUCCESS; +} } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h index e760c24c345..2a5ed3b2a40 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h @@ -131,9 +131,13 @@ class Edge { void set_selected_cost(const CostPtr &cost) { selected_cost_ = cost; } const CostPtr 
&selected_cost() const { return selected_cost_; } void set_parameter_involve(int para_invol) { is_output_parameter_involve_ = para_invol; } - // When the input of a operator contains WEIGHT or a output from other operators involving WEIGHT, then these input - // should stay in memory until it is used in the backward phase, which is kept in memory at the end of forward phase. + // In the training phase, when the input of a operator contains WEIGHT or a output from other operators involving + // WEIGHT, then these input should stay in memory until it is used in the backward phase, which is kept in memory + // at the end of forward phase. Status CalculateMemoryCost(); + // In the inference phase, memory cost is NOT calculated for redistribution, so the edge's memory cost is zero. + Status CalculateMemoryCostForInference(); + void mark_output_critical() { is_output_critical_ = 1; } private: std::string edge_name_; @@ -156,7 +160,11 @@ class Edge { // If it is true, then we should guarantee that the strategy for output tensor consistent with the input tensor. bool is_identity_edge; CostPtr selected_cost_; + // In the training phase, 'is_output_parameter_involve_' is used to mark whether the output of the previous operator + // is parameter-involved int is_output_parameter_involve_ = -1; // -1: unset; 0: not parameter_involved; 1: parameter_involved + // In the inference phase, this is used to mark whether the output of the previous operator is critical. 
+ int is_output_critical_ = 0; }; } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc index 9930cf7046d..5419ab69c99 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc @@ -369,7 +369,7 @@ CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list, << ", communication_with_partial_para_: " << ret->communication_with_partial_para_ << ", communication_cost_: " << ret->communication_cost_ << ", communication_without_parameter_: " << ret->communication_without_parameter_ << "."; - MS_LOG(INFO) << "Cost 0: totoal_cost: " << minimum; + MS_LOG(INFO) << "Cost 0: total_cost: " << minimum; for (size_t i = 1; i < after_mem_filter.size(); ++i) { MS_EXCEPTION_IF_NULL(after_mem_filter[i]); MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_ @@ -422,7 +422,7 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, d << ", communication_with_partial_para_: " << ret->communication_with_partial_para_ << ", communication_cost_: " << ret->communication_cost_ << ", communication_without_parameter_: " << ret->communication_without_parameter_ << "."; - MS_LOG(INFO) << "Cost 0: totoal_cost: " << minimum; + MS_LOG(INFO) << "Cost 0: total_cost: " << minimum; for (size_t i = 1; i < after_mem_filter.size(); ++i) { MS_EXCEPTION_IF_NULL(after_mem_filter[i]); MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_ @@ -1351,6 +1351,14 @@ std::vector> CostGraph::EliminationStar(const OperatorInfo return succ_edges; } +size_t CostGraph::GetNumEdges() const { + size_t sum = 0; + for (const auto &kv : edges_) { + auto &edges = kv.second; + sum += edges.size(); + } + return sum; +} Status CostGraph::InitSelectedStrategy() { for (auto &op : ops_) { MS_EXCEPTION_IF_NULL(op); 
@@ -1416,6 +1424,122 @@ Status CostGraph::ComputeOpsAndEdgesParameterInvolved() { return SUCCESS; } +void CostGraph::DFSForTopoOrder(const OperatorInfoPtr &current_op, std::map<OperatorInfoPtr, bool> *visited, + std::vector<OperatorInfoPtr> *topo_order) { + MS_EXCEPTION_IF_NULL(current_op); + MS_EXCEPTION_IF_NULL(visited); + MS_EXCEPTION_IF_NULL(topo_order); + + visited->at(current_op) = true; + for (const auto &s_edge : current_op->succ_edges()) { + if (!visited->at(s_edge->next_operator())) { + DFSForTopoOrder(s_edge->next_operator(), visited, topo_order); + } + } + topo_order->push_back(current_op); +} + +// Compute a topological order of the costgraph +void CostGraph::TopologyOrder(std::vector<OperatorInfoPtr> *topo_order) { + std::map<OperatorInfoPtr, bool> visited; + for (auto &op : ops_) { + visited[op] = false; + } + + for (auto &op : ops_) { + if (!visited[op]) { + DFSForTopoOrder(op, &visited, topo_order); + } + } +} +void CostGraph::MarkCriticalOpsAndEdges(const std::map<OperatorInfoPtr, int> &candidate_ops) { + for (auto &op : ops_) { + auto search = candidate_ops.find(op); + if (search != candidate_ops.end()) { + // Mark the critical operators + op->mark_output_critical(); + // Mark the successive edges + for (auto &s_edge : op->succ_edges()) { + s_edge->mark_output_critical(); + } + } else { + op->mark_output_not_critical(); + } + } +} + +Status CostGraph::DetermineCriticalOps(const std::vector<OperatorInfoPtr> &topo_order) { + if (topo_order.size() == 0) { + MS_LOG(ERROR) << "0 operator in costgraph."; + return FAILED; + } + auto &first_op = topo_order[0]; + if (first_op->prev_edges().size() > 0) { + MS_LOG(ERROR) << "The first operator in the first of topological order of " + "costgraph should have 0 incoming edge, but has " + << first_op->prev_edges().size() << " edges."; + return FAILED; + } + // The 'curr_memory_state' records <OperatorInfoPtr, remaining_output_cnt>, where remaining_output_cnt is the number + // of the output of OperatorInfo that currently has not been used + std::map<OperatorInfoPtr, int> curr_memory_state; + (void)curr_memory_state.emplace(std::make_pair(first_op, SizeToInt(first_op->succ_edges().size()))); + std::map<OperatorInfoPtr, int> 
max_memory_state = curr_memory_state; + // The 'curr_memory_size' records the current total memory size, which is the sum of outputs of operators that has + // not been used + double curr_memory_size = first_op->GetOutputsTotalSize(); + double max_memory_size = curr_memory_size; + + for (size_t finished = 1; finished < topo_order.size(); ++finished) { + // Produce + (void)curr_memory_state.emplace( + std::make_pair(topo_order[finished], SizeToInt(topo_order[finished]->succ_edges().size()))); + curr_memory_size += topo_order[finished]->GetOutputsTotalSize(); + // Consume + for (const auto &prev_edge : topo_order[finished]->prev_edges()) { + const auto &prev_op = prev_edge->prev_operator(); + curr_memory_state[prev_op]--; + } + for (const auto &prev_edge : topo_order[finished]->prev_edges()) { + const auto &prev_op = prev_edge->prev_operator(); + if (curr_memory_state[prev_op] < 0) { + MS_LOG(ERROR) << "Failure: " << prev_op->name() << "'s current output count: " << curr_memory_state[prev_op]; + return FAILED; + } else if (curr_memory_state[prev_op] == 0) { + curr_memory_state.erase(prev_op); + curr_memory_size -= prev_op->GetOutputsTotalSize(); + } + } + + if (curr_memory_size < 0) { + MS_LOG(ERROR) << "Memory size calculation failed: " << curr_memory_size; + } + // Modify the max + if (curr_memory_size > max_memory_size) { + max_memory_size = curr_memory_size; + max_memory_state = curr_memory_state; + } + } + // Mark those critical operators + MarkCriticalOpsAndEdges(max_memory_state); + return SUCCESS; +} + +Status CostGraph::ComputeOpsAndEdgesOutputCritical() { + // Two steps to do: + // 1. Compute a topological order of the costgraph + // 2. 
Determine and mark the operators (and necessary edges) that are critical + std::vector topo_order; + TopologyOrder(&topo_order); + std::reverse(std::begin(topo_order), std::end(topo_order)); + + if (DetermineCriticalOps(topo_order) != SUCCESS) { + MS_LOG(ERROR) << "Determining critical operators failed."; + return FAILED; + } + return SUCCESS; +} + Status CostGraph::CalculateOpsMemoryCost() { for (auto &op : ops_) { MS_EXCEPTION_IF_NULL(op); @@ -1427,6 +1551,17 @@ Status CostGraph::CalculateOpsMemoryCost() { return SUCCESS; } +Status CostGraph::CalculateOpsMemoryCostForInference() { + for (auto &op : ops_) { + MS_EXCEPTION_IF_NULL(op); + if (op->CalculateMemoryCostForInference() != SUCCESS) { + MS_LOG(ERROR) << "Calculate Operator: " << op->name() << " cost for memory usage failed."; + return FAILED; + } + } + return SUCCESS; +} + Status CostGraph::CalculateEdgesMemoryCost() { for (auto &edge_pair : edges_) { const auto &edges = edge_pair.second; @@ -1440,6 +1575,19 @@ Status CostGraph::CalculateEdgesMemoryCost() { return SUCCESS; } +Status CostGraph::CalculateEdgesMemoryCostForInference() { + for (auto &edge_pair : edges_) { + const auto &edges = edge_pair.second; + for (auto &one_edge : edges) { + if (one_edge->CalculateMemoryCostForInference() != SUCCESS) { + MS_LOG(ERROR) << "Calculate Edge: " << one_edge->edge_name() << " cost for memory usage failed."; + return FAILED; + } + } + } + return SUCCESS; +} + OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(std::string &p_name) const { for (auto one_op : ops_) { if (one_op->name().find(IDENTITY_INFO) != std::string::npos) { @@ -1480,5 +1628,49 @@ Status CostGraph::CorrectOpsMemoryCost() { } return SUCCESS; } + +Status CostGraph::CalculateMemoryCost() { + if (RUN_PHASE == TRAINING_PHASE) { + // training phase + if (ComputeOpsAndEdgesParameterInvolved() == SUCCESS) { + // Calculate operators' memory usage + if (CalculateOpsMemoryCost() != SUCCESS) { + MS_LOG(ERROR) << "Calculating operators' cost for memory 
cost failed."; + return FAILED; + } + // Calculate edges' memory usage + if (CalculateEdgesMemoryCost() != SUCCESS) { + MS_LOG(ERROR) << "Calculating edges' cost for memory cost failed."; + return FAILED; + } + // Correct memory usage caused by TmpIdentity + if (CorrectOpsMemoryCost() != SUCCESS) { + MS_LOG(ERROR) << "Correcting operators' cost for memory cost failed."; + return FAILED; + } + } else { + MS_LOG(ERROR) << "Computing operators' parameter_involved failed."; + return FAILED; + } + } else { + // inference phase + if (ComputeOpsAndEdgesOutputCritical() == SUCCESS) { + // Calculate operators' memory usage + if (CalculateOpsMemoryCostForInference() != SUCCESS) { + MS_LOG(ERROR) << "Calculating operators' memory cost for inference failed."; + return FAILED; + } + // Calculate edges' memory usage + if (CalculateEdgesMemoryCostForInference() != SUCCESS) { + MS_LOG(ERROR) << "Calculating edges' memory cost for inference failed."; + return FAILED; + } + } else { + MS_LOG(ERROR) << "Computing operators' critical flag failed."; + return FAILED; + } + } + return SUCCESS; +} } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h index 6c58ac79573..3b8b389d813 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h @@ -179,16 +179,24 @@ class CostGraph { void CreateStarEliminationSubCostList(const StrategyPtr &, const CostPtrList &, const CostPtrList &, const StrategyPtr &, const CostPtrList &, std::vector, CostPtrList &, CostPtrList &, CostPtrList *); + // Calculate memory cost for training phase or inference phase. + Status CalculateMemoryCost(); // When the input of a operator is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then - // the memory cost can be resused. + // the memory cost can be reused. 
This is used to calculate memory in the training phase. Status CalculateOpsMemoryCost(); // When the input of the edge is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then - // the memory cost can be resused. + // the memory cost can be reused. This is used to calculate memory in the training phase. Status CalculateEdgesMemoryCost(); + // Calculate memory cost of operators in the inference phase. + Status CalculateOpsMemoryCostForInference(); + // Calculate memory cost of edges in the inference phase. + Status CalculateEdgesMemoryCostForInference(); Status ComputeOpsAndEdgesParameterInvolved(); + // Compute for each operator whether the output is critical. + Status ComputeOpsAndEdgesOutputCritical(); std::vector GetOperators() const { return ops_; } - size_t GetNumPairs() const { return edges_.size(); } + size_t GetNumEdges() const; Status InitSelectedStrategy(); OperatorInfoPtr FindTmpIdentityByParameterName(std::string &) const; // When TmpIdentity is used by mulitple operators, the corresponding parameter's memory cost should be calculated only @@ -208,6 +216,10 @@ class CostGraph { const std::map get_tuple_getitem_list() const { return tuple_getitem_list_; } private: + void TopologyOrder(std::vector *); + void DFSForTopoOrder(const OperatorInfoPtr &, std::map *, std::vector *); + Status DetermineCriticalOps(const std::vector &); + void MarkCriticalOpsAndEdges(const std::map &); // Needed by rec_parser std::vector> inputs_tensor_name_list_; std::map tuple_getitem_list_; diff --git a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc index 8f47d30273d..cbf5bc40c8f 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc @@ -37,6 +37,8 @@ void OperatorCost::SetInputAndOutputTypeLength(const std::vector &input_ outputs_type_lengths_ = output_lengths; } +void 
OperatorCost::set_output_critical(int critical) { is_outputs_critical_ = critical; } + double OperatorCost::GetMemoryCost(const std::vector &inputs, const std::vector &outputs) const { double result = 0.0; @@ -63,6 +65,20 @@ double OperatorCost::GetMemoryCost(const std::vector &inputs, return result; } +double OperatorCost::GetMemoryCostForInference(const std::vector &, + const std::vector &outputs) const { + double result = 0.0; + if (is_outputs_critical_ == -1) { + MS_LOG(EXCEPTION) << "The critical flag is not set."; + } + if (is_outputs_critical_ == 1) { + for (size_t i = 0; i < outputs.size(); ++i) { + result += ListProduct(outputs[i].slice_shape()) * static_cast(outputs_type_lengths_[i]); + } + } + return result; +} + // return the per device communication cost in the forward phase. double MatMulCost::GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t) const { diff --git a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h index 4b482826129..a08a4dbb131 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h @@ -70,6 +70,7 @@ class OperatorCost { void set_is_parameter(const std::vector &is_parameter); void set_is_parameter_involve(const std::vector &); void set_output_parameter_involve(int); + void set_output_critical(int); void SetInputAndOutputTypeLength(const std::vector &input_lengths, const std::vector &output_lengths); std::vector inputs_type_lengths() const { return inputs_type_lengths_; } std::vector outputs_type_lengths() const { return outputs_type_lengths_; } @@ -92,6 +93,8 @@ class OperatorCost { // Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-invovled), // plus necessary inputs. 
virtual double GetMemoryCost(const std::vector &inputs, const std::vector &outputs) const; + // per device memory cost in a inference phase + double GetMemoryCostForInference(const std::vector &, const std::vector &) const; protected: // For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of @@ -106,6 +109,9 @@ class OperatorCost { // for each input and output, the followings record the number of bytes of each element std::vector inputs_type_lengths_; std::vector outputs_type_lengths_; + // Whether the output is critical, which means that this output is included in calculating peak memory cost + // in the inference phase. + int is_outputs_critical_ = -1; }; using OperatorCostPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/parallel/ops_info/operator_info.cc index b1db6cda4db..4a680a7c362 100644 --- a/mindspore/ccsrc/parallel/ops_info/operator_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/operator_info.cc @@ -1119,6 +1119,21 @@ Status OperatorInfo::CalculateMemoryCost() { return SUCCESS; } +Status OperatorInfo::CalculateMemoryCostForInference() { + // First, set the 'is_outputs_critical_' flag into OperatorCost. 
+ if (is_output_critical_ == -1) { + MS_LOG(ERROR) << "The critical flag is not set."; + return FAILED; + } + operator_cost()->set_output_critical(is_output_critical_); + // Set the memory cost in the 'strategy_cost_' + for (auto &swc : strategy_cost_) { + auto mem_cost = operator_cost()->GetMemoryCostForInference(swc->inputs_ptr, swc->outputs_ptr); + swc->cost_list[0]->memory_with_reuse_ = mem_cost; + } + return SUCCESS; +} + Status OperatorInfo::CorrectMemoryCost(size_t input_index) { for (auto &swc : strategy_cost_) { double parameter_mem_cost = ListProduct(swc->inputs_ptr[input_index].slice_shape()) * @@ -1230,6 +1245,25 @@ Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector &inpu return SUCCESS; } +double OperatorInfo::GetOutputsTotalSize() { + if (is_calculated_outputs_size_) { + return outputs_total_size_; + } + if (outputs_type_lengths_.size() != outputs_shape_.size()) { + MS_LOG(EXCEPTION) << "Output_lengths: " << outputs_type_lengths_.size() + << " do not have the same number of outputs shape: " << outputs_shape_.size(); + } + double sum = 0.0; + for (size_t i = 0; i < outputs_type_lengths_.size(); ++i) { + auto size = std::accumulate(outputs_shape_[i].begin(), outputs_shape_[i].end(), static_cast<double>(1.0), + std::multiplies<double>()); + sum += size * static_cast<double>(outputs_type_lengths_[i]); + } + is_calculated_outputs_size_ = true; + outputs_total_size_ = sum; + return outputs_total_size_; +} + Status OperatorInfo::set_outputs_type(const std::vector<TypePtr> &outputs_type) { if (outputs_type.size() != outputs_shape_.size()) { MS_LOG(ERROR) << "Outputs type: " << outputs_type.size() diff --git a/mindspore/ccsrc/parallel/ops_info/operator_info.h b/mindspore/ccsrc/parallel/ops_info/operator_info.h index 8855e70b41f..c44f1f09fae 100644 --- a/mindspore/ccsrc/parallel/ops_info/operator_info.h +++ b/mindspore/ccsrc/parallel/ops_info/operator_info.h @@ -72,6 +72,7 @@ class OperatorInfo { Status set_is_parameter(const std::vector<bool> &is_parameter); Status 
SetInputAndOutputTypeLength(const std::vector &input_lengths, const std::vector &output_lengths); + double GetOutputsTotalSize(); // Set outputs dtype. // If only one output, outputs_type.size() is 1. // If output is tuple, outputs_type.size() is greater than 1. @@ -96,9 +97,13 @@ class OperatorInfo { // is checked Status SetCostUnderStrategyBase(const StrategyPtr &strategy); std::vector> GetStrategyCost() { return strategy_cost_; } - // When the input of a operator contains WEIGHT or a output from other operators involving WEIGHT, then these input - // should stay in memory until it is used in the backward phase, which is kept in memory at the end of forward phase. + // In the training phase, when the input of a operator contains WEIGHT or a output from other operators involving + // WEIGHT, then these input should stay in memory until it is used in the backward phase, which is kept in memory + // at the end of forward phase. Status CalculateMemoryCost(); + // In the inference phase, the memory cost is incurred only when the operator is critical. The size is calculated + // by the output + Status CalculateMemoryCostForInference(); int ComputeOpAndPrevEdgeParameterInvolved(); ForwardOp forward_op() const { return forward_op_; } @@ -147,6 +152,9 @@ class OperatorInfo { // multiple times. This method is to correct this, and makes the cost is calulated only once. 
Status CorrectMemoryCost(size_t input_index); int is_output_parameter_involve() const { return is_output_parameter_involve_; } + int is_output_critical() const { return is_output_critical_; } + void mark_output_critical() { is_output_critical_ = 1; } + void mark_output_not_critical() { is_output_critical_ = 0; } int used_devices() const { return used_devices_; } // needed by rec_parser void set_type(const std::string &type) { type_ = type; } @@ -220,7 +228,16 @@ class OperatorInfo { // For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of // pre-operator that has parameters as input. std::vector is_parameter_involve_; - int is_output_parameter_involve_ = -1; // -1: unset; 0: not parameter_involved; 1: parameter_involved + // If any input is parameter-involved, the output is parameter-involved. This variable is used in calculating + // peak memory cost in the training phase. + // -1: unset; 0: not parameter_involved; 1: parameter_involved + int is_output_parameter_involve_ = -1; + // Whether this output is critical, which means that this output is included in calculating peak memory cost + // in the inference phase. 
+ // -1 : unset; 0: not critical; 1: critical + int is_output_critical_ = -1; + double outputs_total_size_ = 0.0; + bool is_calculated_outputs_size_ = false; // for each input and output, the followings record the number of bytes of each element std::vector inputs_type_lengths_; std::vector outputs_type_lengths_; diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.cc b/mindspore/ccsrc/parallel/step_auto_parallel.cc index 10794dba274..7a3d5a38fb7 100644 --- a/mindspore/ccsrc/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/parallel/step_auto_parallel.cc @@ -1055,6 +1055,9 @@ Status ParallelStrategySearch(const std::vector &all_nodes, const Fu // Step 1: Traverse the ANF graph, and create NODEs for costgraph: // create the OperatorInfo object for each primitive, and enumerate the parallelization strategies // for each OperatorInfo; + // Step 1.1: Deal with 'Reshape': + // For 'Reshape', it takes its previous operator's layout as its input layout, and takes its next operator's + // layout as its output layout. // Step 2: Traverse the ANF graph, and create EDGES for costgraph: // create the Edge object for each pair of OperatorInfo, and enumerate the parallelization strategies // for each edge, based on the strategies of two OperatorInfos; @@ -1062,7 +1065,8 @@ Status ParallelStrategySearch(const std::vector &all_nodes, const Fu // taking care for the case of a single Parameter being used by multiple operators. Create a TmpIdentity // operator for this Parameter, and add an edge for the use of this Parameter by each // subsequent operator; - // Step 3.1: Calculate memory usage + // Step 3.1: Calculate memory usage: + // note the memory usage calculation is different in training phase and inference phase. // Step 4: Run the Dynamic Programming algorithm: // in this process, cost is calculated based on not only the operators, but also the edges. 
Here, the edge // cost is caused by the redistribution of a operator's output tensor layout to the next operator's input @@ -1087,35 +1091,21 @@ Status ParallelStrategySearch(const std::vector &all_nodes, const Fu MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed."; } } - // reshape operator needs the next node's input_layout as its output_layout. - // and needs the previous node's output_layout as its input_layout. + // Step 1.1 ReshapeCostCompute(all_nodes); // Step 2 ConstructCostGraphEdges(all_nodes); MS_LOG(INFO) << "Constructing edges for cost graph succeeded. There are " << entire_costgraph->GetOperators().size() - << " operators, and " << entire_costgraph->GetNumPairs() << " edges.", + << " operators, and " << entire_costgraph->GetNumEdges() << " edges."; - // Step 3: Augment the costgraph. - AugmentCostGraph(all_nodes); + // Step 3: Augment the costgraph. + AugmentCostGraph(all_nodes); MS_LOG(INFO) << "After the augmenting procedure, there are " << entire_costgraph->GetOperators().size() - << " operators, and " << entire_costgraph->GetNumPairs() << " edges."; + << " operators, and " << entire_costgraph->GetNumEdges() << " edges."; // Step 3.1: Calculate the memory usage - if (entire_costgraph->ComputeOpsAndEdgesParameterInvolved() == SUCCESS) { - // Calculate operators' memory usage - if (entire_costgraph->CalculateOpsMemoryCost() != SUCCESS) { - MS_LOG(EXCEPTION) << "Calculating operators' cost for memory cost failed."; - } - // Calculate edges' memory usage - if (entire_costgraph->CalculateEdgesMemoryCost() != SUCCESS) { - MS_LOG(EXCEPTION) << "Calculating edges' cost for memory cost failed."; - } - // Correct memory usage caused by TmpIdentity - if (entire_costgraph->CorrectOpsMemoryCost() != SUCCESS) { - MS_LOG(EXCEPTION) << "Correcting operators' cost for memory cost failed."; - } - } else { - MS_LOG(EXCEPTION) << "Computing operators' parameter_involved failed."; + if (entire_costgraph->CalculateMemoryCost() != SUCCESS) { + 
MS_LOG(EXCEPTION) << "Calculating memory cost failed."; } // Step 4: run DP algorithm on the costgraph. diff --git a/tests/ut/python/parallel/test_auto_parallel_inference.py b/tests/ut/python/parallel/test_auto_parallel_inference.py index 3aad5f62c6c..0d8cdabb6bf 100644 --- a/tests/ut/python/parallel/test_auto_parallel_inference.py +++ b/tests/ut/python/parallel/test_auto_parallel_inference.py @@ -32,5 +32,6 @@ def test_inference_phase(): net_with_loss = WithLossCell(net, loss) train_network = TrainOneStepCell(net_with_loss, optimizer) train_network.set_train() + train_network.set_auto_parallel() - output = train_network(predict, label) \ No newline at end of file + output = train_network(predict, label)