!41484 fix code check for parallel code

Merge pull request !41484 from yangzhenzhang/fix-code-check-for-parallel
i-robot 2022-09-06 10:18:10 +00:00 committed by Gitee
commit 3a520efaf0
7 changed files with 26 additions and 27 deletions


@@ -514,13 +514,13 @@ std::shared_ptr<Strategies> BatchMatMulInfo::GenerateBatchStrategies() {
 Status MatMulBase::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
-// if the transpose_b is false:
-// in_strategy: ((A, B, C, D), ()), inputs shape: ((a, b, c, d), (a, b, d, e)), return: ((A, B, C, D), (A, B, D, 1))
-// in_strategy: ((), (A, B, D, E)), inputs shape: ((a, b, c, d), (a, b, d, e)), return: ((A, B, 1, D), (A, B, D, E))
-// if the transpose_b is true:
-// in_strategy: ((A, B, C, D), ()), inputs shape: ((a, b, c, d), (a, b, e, d)), return: ((A, B, C, D), (A, B, 1, D))
-// in_strategy: ((), (A, B, E, D)), inputs shape: ((a, b, c, d), (a, b, e, d)), return: ((A, B, 1, D), (A, B, E, D))
 Shapes MatMulBase::InferStrategyIndividualMode(const Shapes &in_strategy) {
+  // if the transpose_b is false:
+  // in_strategy: ((A, B, C, D), ()), inputs shape: ((a, b, c, d), (a, b, d, e)), return: ((A, B, C, D), (A, B, D, 1))
+  // in_strategy: ((), (A, B, D, E)), inputs shape: ((a, b, c, d), (a, b, d, e)), return: ((A, B, 1, D), (A, B, D, E))
+  // if the transpose_b is true:
+  // in_strategy: ((A, B, C, D), ()), inputs shape: ((a, b, c, d), (a, b, e, d)), return: ((A, B, C, D), (A, B, 1, D))
+  // in_strategy: ((), (A, B, E, D)), inputs shape: ((a, b, c, d), (a, b, e, d)), return: ((A, B, 1, D), (A, B, E, D))
   if (in_strategy.size() != 2) {
     MS_LOG(EXCEPTION) << name_ << ": The size of in strategy must be 2, but got " << in_strategy.size();
   }
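
The relocated comment documents the filling rule for a partially specified MatMul strategy: the batch dimensions (A, B) and the contracted dimension D propagate to the input whose strategy was left empty, and that input's free dimension stays unsplit (1). A minimal standalone sketch of the rule, assuming `Shape`/`Shapes` are nested `std::vector<int64_t>` (an illustration, not the MindSpore implementation):

#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;
using Shapes = std::vector<Shape>;

// Fill the second input's strategy from the first, per the comment above.
// Assumes a 4-dimensional strategy (A, B, C, D):
// transpose_b == false -> (A, B, D, 1); transpose_b == true -> (A, B, 1, D).
Shapes FillSecondFromFirst(const Shape &s0, bool transpose_b) {
  Shape s1 = transpose_b ? Shape{s0[0], s0[1], 1, s0[3]}   // (A, B, 1, D)
                         : Shape{s0[0], s0[1], s0[3], 1};  // (A, B, D, 1)
  return {s0, s1};
}
// Example: s0 = (2, 1, 4, 8), transpose_b = false  ->  ((2, 1, 4, 8), (2, 1, 8, 1)).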


@@ -122,7 +122,7 @@ Status OperatorInfo::CheckOutputStrategy(const StrategyPtr &out_strategy) {
   return SUCCESS;
 }
-Status OperatorInfo::CheckStrategyBase(const Shapes &stra, const Shapes &inputs_shape) {
+Status OperatorInfo::CheckStrategyByVector(const Shapes &stra, const Shapes &inputs_shape) {
   size_t strategy_size = stra.size();
   size_t inputs_shape_size = inputs_shape.size();
   if (strategy_size != inputs_shape_size) {
@@ -180,7 +180,7 @@ Status OperatorInfo::CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape) {
   }
   Strategies stra = strategy->GetInputDim();
-  return CheckStrategyBase(stra, inputs_shape);
+  return CheckStrategyByVector(stra, inputs_shape);
 }
 void OperatorInfo::ResetQueueMember() {
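
The rename distinguishes the two validation entry points by what they accept: `CheckStrategyValue` takes a `StrategyPtr`, extracts its dimension vector, and delegates to the vector-based `CheckStrategyByVector`. A simplified sketch of that delegation, with hypothetical types standing in for the real `Status`-returning signatures:

#include <cstdint>
#include <memory>
#include <vector>

using Shape = std::vector<int64_t>;
using Shapes = std::vector<Shape>;

struct Strategy {
  Shapes input_dim;
  Shapes GetInputDim() const { return input_dim; }
};
using StrategyPtr = std::shared_ptr<Strategy>;

// Vector-based check: operates directly on raw strategy dimensions.
bool CheckStrategyByVector(const Shapes &stra, const Shapes &inputs_shape) {
  return stra.size() == inputs_shape.size();  // one sub-strategy per input
}

// Pointer-based entry point: unwraps the strategy and delegates.
bool CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape) {
  return strategy != nullptr && CheckStrategyByVector(strategy->GetInputDim(), inputs_shape);
}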
@@ -2113,7 +2113,7 @@ float OperatorInfo::GetFloatAttr(const std::string &attr_name) {
 // in_strategy: ((A, B, C, D), ()), return: ((A, B, C, D), (A, B, C, D))
 // in_strategy: ((), (A, B, C, D)), return: ((A, B, C, D), (A, B, C, D))
-Shapes OperatorInfo::InferStrategySameMode(const Shapes &in_strategy) {
+Shapes OperatorInfo::InferStrategySameMode(const Shapes &in_strategy) const {
   Shape value;
   for (auto &ele : in_strategy) {
     if (!ele.empty()) {
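
Adding `const` to `InferStrategySameMode` records that the method only reads object state; code checks commonly flag member functions that could be `const` but are not, since `const` methods can also be called through const references and pointers. A minimal illustration with a hypothetical class:

#include <vector>

class Example {
 public:
  // const member function: promises not to modify *this.
  std::vector<int> Scaled(const std::vector<int> &in) const {
    std::vector<int> out;
    for (int v : in) out.push_back(v * scale_);  // reads scale_, never writes it
    return out;
  }

 private:
  int scale_ = 2;
};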


@@ -237,10 +237,10 @@ class OperatorInfo {
   virtual Status CheckOutputStrategy(const StrategyPtr &out_strategy);
   virtual Shapes InferStrategyIndividualMode(const Shapes &in_strategy);
   Shapes GenerateFullStrategyBase(const Shapes &in_strategy);
-  Shapes InferStrategySameMode(const Shapes &in_strategy);
+  Shapes InferStrategySameMode(const Shapes &in_strategy) const;
   Shapes InferStrategyBroadcastMode(const Shapes &in_strategy);
   Shapes InferStrategyIndependentMode(const Shapes &in_strategy);
-  Status CheckStrategyBase(const Shapes &strategy, const Shapes &inputs_shape);
+  Status CheckStrategyByVector(const Shapes &strategy, const Shapes &inputs_shape);
   Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape);
   void SetRepeatedCalcDevMatrix();
   void ResetTensorMapIfRepeatedCalc();


@@ -84,7 +84,7 @@ void StridedSliceInfo::ComputeBeginMask() {
     }
   }
-  if (begin_mask_) {
+  if (begin_mask_ != 0) {
     MS_LOG(INFO) << name_ << ": The begin is modified to " << begin_;
   }
 }
@@ -98,7 +98,7 @@ void StridedSliceInfo::ComputeEndMask() {
     }
   }
-  if (end_mask_) {
+  if (end_mask_ != 0) {
     MS_LOG(INFO) << name_ << ": The end is modified to " << end_;
   }
 }
@@ -145,7 +145,7 @@ void StridedSliceInfo::ComputeNewAxisMask() {
   (void)input_shape_in_process_.insert(input_shape_in_process_.end(), inputs_shape_[0].begin() + count,
                                        inputs_shape_[0].end());
-  if (new_axis_mask_) {
+  if (new_axis_mask_ != 0) {
     MS_LOG(INFO) << name_ << ": The begin is modified to " << begin_ << ", the end is modified to " << end_
                  << ", the strides is modified to " << strides_ << ", the input shape in process is "
                  << input_shape_in_process_;
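
All three mask fields (`begin_mask_`, `end_mask_`, `new_axis_mask_`) are integer bitmasks, so the check asks for an explicit comparison rather than an implicit int-to-bool conversion. A small sketch of the difference, with a hypothetical mask value:

#include <cstdint>
#include <iostream>

int main() {
  int64_t begin_mask = 0b0101;  // bits 0 and 2 set

  if (begin_mask) {  // implicit conversion: legal, but hides the bitmask intent
    std::cout << "implicit: some bit set\n";
  }
  if (begin_mask != 0) {  // explicit comparison: what the code check asks for
    std::cout << "explicit: some bit set\n";
  }
  if ((begin_mask & (int64_t{1} << 2)) != 0) {  // testing one specific bit
    std::cout << "bit 2 set\n";
  }
  return 0;
}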


@@ -331,7 +331,7 @@ Shapes UnsortedSegmentOpInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
                         << inputs_shape_[0].size() << ", the size of inputs_shape[1] is " << inputs_shape_[1].size();
     }
     size_t diff_len = inputs_shape_[0].size() - inputs_shape_[1].size();
-    tmp_strategy.insert(tmp_strategy.end(), diff_len, 1);
+    (void)tmp_strategy.insert(tmp_strategy.end(), diff_len, 1);
     return Shapes({tmp_strategy, in_strategy[1]});
   }
   MS_LOG(EXCEPTION) << name_ << ": The in_strategy[0] and in_strategy[1] are empty";
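
The `(void)` cast explicitly discards the iterator returned by `std::vector::insert`; some static checks require unused return values to be dropped deliberately rather than silently. A minimal illustration of the same padding step:

#include <cstddef>
#include <cstdint>
#include <vector>

// Pad a strategy with `diff_len` trailing ones, as the surrounding code does.
void PadWithOnes(std::vector<int64_t> *strategy, std::size_t diff_len) {
  // insert returns an iterator to the first inserted element; it is unused
  // here, and the (void) cast documents that the discard is intentional.
  (void)strategy->insert(strategy->end(), diff_len, 1);
}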


@@ -133,7 +133,7 @@ static bool CheckDeviceNum(const std::vector<std::vector<int64_t>> &strategies,
   return true;
 }
-static void SetOutputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &out_strategy, const int64_t &device_num) {
+static void SetOutputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &out_strategy) {
   auto out_strategy_tuple = out_strategy->cast<ValueNodePtr>();
   bool need_default_strategy = false;
   size_t out_strategy_size = 0;
@@ -176,8 +176,7 @@ static void SetOutputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &out_strategy
   }
 }
-static Shapes GenerateDefaultStrategyForParam(const CNodePtr &cnode, const std::vector<AnfNodePtr> &parameters,
-                                              const Shapes &input_strategy) {
+static Shapes GenerateDefaultStrategyForParam(const CNodePtr &cnode, const Shapes &input_strategy) {
   auto current_inputs = cnode->inputs();
   Shapes elements;
   for (size_t i = 1; i < current_inputs.size(); ++i) {
@@ -268,7 +267,7 @@ static std::set<CNodePtr> SetInputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr
     }
   }
   for (auto &cnode : concerned_nodes) {
-    Shapes ret_strategy = GenerateDefaultStrategyForParam(cnode, parameters, input_strategy);
+    Shapes ret_strategy = GenerateDefaultStrategyForParam(cnode, input_strategy);
     // Set in_strategy
     auto strategy = ShapesToValueTuplePtr(ret_strategy);
     PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
@@ -437,7 +436,7 @@ static bool SetStrategyForShard(const FuncGraphPtr &root, const std::vector<AnfNodePtr>
       std::set_union(input_concerned_cnode.begin(), input_concerned_cnode.end(), parameter_concerned_cnode.begin(),
                      parameter_concerned_cnode.end(), std::inserter(concerned_cnode, concerned_cnode.end()));
       CompleteConcernedCNodeStrategies(concerned_cnode);
-      SetOutputLayout(func_graph, out_strategy, device_num); // Not in effect currently
+      SetOutputLayout(func_graph, out_strategy); // Not in effect currently
       return true;
     }
   }
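
Both signature changes in this file remove parameters that the function bodies never read (`device_num` in `SetOutputLayout`, `parameters` in `GenerateDefaultStrategyForParam`); this is what warnings in the spirit of `-Wunused-parameter` flag. A tiny before/after illustration with a hypothetical function:

// Before: `scale` is accepted but never read, so the unused-parameter check fires.
//   static int Area(int width, int height, int scale) { return width * height; }

// After: the dead parameter is removed and every call site is updated to match.
static int Area(int width, int height) { return width * height; }

int main() { return Area(3, 4) == 12 ? 0 : 1; }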


@@ -1011,14 +1011,14 @@ bool IsPynativeParallel() {
   return (execution_mode == kPynativeMode) && (parallel_mode == kSemiAutoParallel || parallel_mode == kAutoParallel);
 }
-// compile graph order:
-// 1, ParallelParameterContextInitShape/Ckpt/Restore
-// 2, PynativeShard: find 'shard' node and set 'pynative_shard' flag for root graph
-// 3, PipelineSplit: insert virtual dataset
-// 4, StepAutoParallel
-// 5, StepParallel
-// if IsPynativeParallel() is true, it maybe has some graphs that we no care, so need to check 'pynative_shard' flag
 bool IsAutoParallelCareGraph(const FuncGraphPtr &func_graph) {
+  // compile graph order:
+  // 1, ParallelParameterContextInitShape/Ckpt/Restore
+  // 2, PynativeShard: find 'shard' node and set 'pynative_shard' flag for root graph
+  // 3, PipelineSplit: insert virtual dataset
+  // 4, StepAutoParallel
+  // 5, StepParallel
+  // if IsPynativeParallel() is true, it maybe has some graphs that we no care, so need to check 'pynative_shard' flag
   MS_EXCEPTION_IF_NULL(func_graph);
   if (func_graph->has_flag(kSkipAutoParallelCompile)) {
     return false;