forked from mindspore-Ecosystem/mindspore
!41484 fix code check for parallel code
Merge pull request !41484 from yangzhenzhang/fix-code-check-for-parallel
This commit is contained in:
commit
3a520efaf0
|
@ -514,13 +514,13 @@ std::shared_ptr<Strategies> BatchMatMulInfo::GenerateBatchStrategies() {
|
|||
|
||||
Status MatMulBase::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
|
||||
|
||||
// if the transpose_b is false:
|
||||
// in_strategy: ((A, B, C, D), ()), inputs shape: ((a, b, c, d), (a, b, d, e)), return: ((A, B, C, D), (A, B, D, 1))
|
||||
// in_strategy: ((), (A, B, D, E)), inputs shape: ((a, b, c, d), (a, b, d, e)), return: ((A, B, 1, D), (A, B, D, E))
|
||||
// if the transpose_b is true:
|
||||
// in_strategy: ((A, B, C, D), ()), inputs shape: ((a, b, c, d), (a, b, e, d)), return: ((A, B, C, D), (A, B, 1, D))
|
||||
// in_strategy: ((), (A, B, E, D)), inputs shape: ((a, b, c, d), (a, b, e, d)), return: ((A, B, 1, D), (A, B, E, D))
|
||||
Shapes MatMulBase::InferStrategyIndividualMode(const Shapes &in_strategy) {
|
||||
// if the transpose_b is false:
|
||||
// in_strategy: ((A, B, C, D), ()), inputs shape: ((a, b, c, d), (a, b, d, e)), return: ((A, B, C, D), (A, B, D, 1))
|
||||
// in_strategy: ((), (A, B, D, E)), inputs shape: ((a, b, c, d), (a, b, d, e)), return: ((A, B, 1, D), (A, B, D, E))
|
||||
// if the transpose_b is true:
|
||||
// in_strategy: ((A, B, C, D), ()), inputs shape: ((a, b, c, d), (a, b, e, d)), return: ((A, B, C, D), (A, B, 1, D))
|
||||
// in_strategy: ((), (A, B, E, D)), inputs shape: ((a, b, c, d), (a, b, e, d)), return: ((A, B, 1, D), (A, B, E, D))
|
||||
if (in_strategy.size() != 2) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": The size of in strategy must be 2, but got " << in_strategy.size();
|
||||
}
|
||||
|
|
|
@ -122,7 +122,7 @@ Status OperatorInfo::CheckOutputStrategy(const StrategyPtr &out_strategy) {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status OperatorInfo::CheckStrategyBase(const Shapes &stra, const Shapes &inputs_shape) {
|
||||
Status OperatorInfo::CheckStrategyByVector(const Shapes &stra, const Shapes &inputs_shape) {
|
||||
size_t strategy_size = stra.size();
|
||||
size_t inputs_shape_size = inputs_shape.size();
|
||||
if (strategy_size != inputs_shape_size) {
|
||||
|
@ -180,7 +180,7 @@ Status OperatorInfo::CheckStrategyValue(const StrategyPtr &strategy, const Shape
|
|||
}
|
||||
|
||||
Strategies stra = strategy->GetInputDim();
|
||||
return CheckStrategyBase(stra, inputs_shape);
|
||||
return CheckStrategyByVector(stra, inputs_shape);
|
||||
}
|
||||
|
||||
void OperatorInfo::ResetQueueMember() {
|
||||
|
@ -2113,7 +2113,7 @@ float OperatorInfo::GetFloatAttr(const std::string &attr_name) {
|
|||
|
||||
// in_strategy: ((A, B, C, D), ()), return: ((A, B, C, D), (A, B, C, D))
|
||||
// in_strategy: ((), (A, B, C, D)), return: ((A, B, C, D), (A, B, C, D))
|
||||
Shapes OperatorInfo::InferStrategySameMode(const Shapes &in_strategy) {
|
||||
Shapes OperatorInfo::InferStrategySameMode(const Shapes &in_strategy) const {
|
||||
Shape value;
|
||||
for (auto &ele : in_strategy) {
|
||||
if (!ele.empty()) {
|
||||
|
|
|
@ -237,10 +237,10 @@ class OperatorInfo {
|
|||
virtual Status CheckOutputStrategy(const StrategyPtr &out_strategy);
|
||||
virtual Shapes InferStrategyIndividualMode(const Shapes &in_strategy);
|
||||
Shapes GenerateFullStrategyBase(const Shapes &in_strategy);
|
||||
Shapes InferStrategySameMode(const Shapes &in_strategy);
|
||||
Shapes InferStrategySameMode(const Shapes &in_strategy) const;
|
||||
Shapes InferStrategyBroadcastMode(const Shapes &in_strategy);
|
||||
Shapes InferStrategyIndependentMode(const Shapes &in_strategy);
|
||||
Status CheckStrategyBase(const Shapes &strategy, const Shapes &inputs_shape);
|
||||
Status CheckStrategyByVector(const Shapes &strategy, const Shapes &inputs_shape);
|
||||
Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape);
|
||||
void SetRepeatedCalcDevMatrix();
|
||||
void ResetTensorMapIfRepeatedCalc();
|
||||
|
|
|
@ -84,7 +84,7 @@ void StridedSliceInfo::ComputeBeginMask() {
|
|||
}
|
||||
}
|
||||
|
||||
if (begin_mask_) {
|
||||
if (begin_mask_ != 0) {
|
||||
MS_LOG(INFO) << name_ << ": The begin is modified to " << begin_;
|
||||
}
|
||||
}
|
||||
|
@ -98,7 +98,7 @@ void StridedSliceInfo::ComputeEndMask() {
|
|||
}
|
||||
}
|
||||
|
||||
if (end_mask_) {
|
||||
if (end_mask_ != 0) {
|
||||
MS_LOG(INFO) << name_ << ": The end is modified to " << end_;
|
||||
}
|
||||
}
|
||||
|
@ -145,7 +145,7 @@ void StridedSliceInfo::ComputeNewAxisMask() {
|
|||
(void)input_shape_in_process_.insert(input_shape_in_process_.end(), inputs_shape_[0].begin() + count,
|
||||
inputs_shape_[0].end());
|
||||
|
||||
if (new_axis_mask_) {
|
||||
if (new_axis_mask_ != 0) {
|
||||
MS_LOG(INFO) << name_ << ": The begin is modified to " << begin_ << ", the end is modified to " << end_
|
||||
<< ", the strides is modified to " << strides_ << ", the input shape in process is "
|
||||
<< input_shape_in_process_;
|
||||
|
|
|
@ -331,7 +331,7 @@ Shapes UnsortedSegmentOpInfo::InferStrategyIndividualMode(const Shapes &in_strat
|
|||
<< inputs_shape_[0].size() << ", the size of inputs_shape[1] is " << inputs_shape_[1].size();
|
||||
}
|
||||
size_t diff_len = inputs_shape_[0].size() - inputs_shape_[1].size();
|
||||
tmp_strategy.insert(tmp_strategy.end(), diff_len, 1);
|
||||
(void)tmp_strategy.insert(tmp_strategy.end(), diff_len, 1);
|
||||
return Shapes({tmp_strategy, in_strategy[1]});
|
||||
}
|
||||
MS_LOG(EXCEPTION) << name_ << ": The in_strategy[0] and in_strategy[1] are empty";
|
||||
|
|
|
@ -133,7 +133,7 @@ static bool CheckDeviceNum(const std::vector<std::vector<int64_t>> &strategies,
|
|||
return true;
|
||||
}
|
||||
|
||||
static void SetOutputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &out_strategy, const int64_t &device_num) {
|
||||
static void SetOutputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &out_strategy) {
|
||||
auto out_strategy_tuple = out_strategy->cast<ValueNodePtr>();
|
||||
bool need_default_strategy = false;
|
||||
size_t out_strategy_size = 0;
|
||||
|
@ -176,8 +176,7 @@ static void SetOutputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &ou
|
|||
}
|
||||
}
|
||||
|
||||
static Shapes GenerateDefaultStrategyForParam(const CNodePtr &cnode, const std::vector<AnfNodePtr> ¶meters,
|
||||
const Shapes &input_strategy) {
|
||||
static Shapes GenerateDefaultStrategyForParam(const CNodePtr &cnode, const Shapes &input_strategy) {
|
||||
auto current_inputs = cnode->inputs();
|
||||
Shapes elements;
|
||||
for (size_t i = 1; i < current_inputs.size(); ++i) {
|
||||
|
@ -268,7 +267,7 @@ static std::set<CNodePtr> SetInputLayout(const FuncGraphPtr &func_graph, const A
|
|||
}
|
||||
}
|
||||
for (auto &cnode : concerned_nodes) {
|
||||
Shapes ret_strategy = GenerateDefaultStrategyForParam(cnode, parameters, input_strategy);
|
||||
Shapes ret_strategy = GenerateDefaultStrategyForParam(cnode, input_strategy);
|
||||
// Set in_strategy
|
||||
auto strategy = ShapesToValueTuplePtr(ret_strategy);
|
||||
PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
|
@ -437,7 +436,7 @@ static bool SetStrategyForShard(const FuncGraphPtr &root, const std::vector<AnfN
|
|||
std::set_union(input_concerned_cnode.begin(), input_concerned_cnode.end(), parameter_concerned_cnode.begin(),
|
||||
parameter_concerned_cnode.end(), std::inserter(concerned_cnode, concerned_cnode.end()));
|
||||
CompleteConcernedCNodeStrategies(concerned_cnode);
|
||||
SetOutputLayout(func_graph, out_strategy, device_num); // Not in effect currently
|
||||
SetOutputLayout(func_graph, out_strategy); // Not in effect currently
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1011,14 +1011,14 @@ bool IsPynativeParallel() {
|
|||
return (execution_mode == kPynativeMode) && (parallel_mode == kSemiAutoParallel || parallel_mode == kAutoParallel);
|
||||
}
|
||||
|
||||
// compile graph order:
|
||||
// 1, ParallelParameterContextInitShape/Ckpt/Restore
|
||||
// 2, PynativeShard: find 'shard' node and set 'pynative_shard' flag for root graph
|
||||
// 3, PipelineSplit: insert virtual dataset
|
||||
// 4, StepAutoParallel
|
||||
// 5, StepParallel
|
||||
// if IsPynativeParallel() is true, it maybe has some graphs that we no care, so need to check 'pynative_shard' flag
|
||||
bool IsAutoParallelCareGraph(const FuncGraphPtr &func_graph) {
|
||||
// compile graph order:
|
||||
// 1, ParallelParameterContextInitShape/Ckpt/Restore
|
||||
// 2, PynativeShard: find 'shard' node and set 'pynative_shard' flag for root graph
|
||||
// 3, PipelineSplit: insert virtual dataset
|
||||
// 4, StepAutoParallel
|
||||
// 5, StepParallel
|
||||
// if IsPynativeParallel() is true, it maybe has some graphs that we no care, so need to check 'pynative_shard' flag
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
if (func_graph->has_flag(kSkipAutoParallelCompile)) {
|
||||
return false;
|
||||
|
|
Loading…
Reference in New Issue