From 048b88c41c9ea71dcdb7245ff87fd643ddc16eb8 Mon Sep 17 00:00:00 2001 From: yangzhenzhang <285824651@qq.com> Date: Mon, 31 Aug 2020 17:03:04 +0800 Subject: [PATCH] update check strategy value --- .../parallel/ops_info/activation_info.cc | 31 +++------------- .../parallel/ops_info/arithmetic_info.cc | 11 ++---- .../parallel/ops_info/batch_parallel_info.cc | 8 ++--- .../parallel/ops_info/bias_add_info.cc | 11 ++---- .../frontend/parallel/ops_info/concat_info.cc | 11 ++---- .../parallel/ops_info/dropout_do_mask_info.cc | 13 ++----- .../parallel/ops_info/gather_v2_info.cc | 10 ++---- .../parallel/ops_info/gather_v2_p_info.cc | 19 ++-------- .../parallel/ops_info/get_next_info.cc | 8 +---- .../parallel/ops_info/l2_normalize_info.cc | 3 +- .../parallel/ops_info/layer_norm_info.cc | 10 ++---- .../frontend/parallel/ops_info/loss_info.cc | 9 ++--- .../frontend/parallel/ops_info/matmul_info.cc | 2 +- .../frontend/parallel/ops_info/onehot_info.cc | 24 ++----------- .../parallel/ops_info/operator_info.cc | 36 ++++++++++--------- .../parallel/ops_info/operator_info.h | 1 + .../frontend/parallel/ops_info/prelu_info.cc | 10 ++---- .../parallel/ops_info/reduce_method_info.cc | 18 ++-------- .../parallel/ops_info/reshape_info.cc | 16 ++------- .../parallel/ops_info/strided_slice_info.cc | 9 ++--- .../frontend/parallel/ops_info/tile_info.cc | 16 ++------- .../parallel/ops_info/tmp_identity_info.cc | 15 ++------ .../parallel/ops_info/transpose_info.cc | 16 ++------- .../parallel/ops_info/virtual_dataset_info.cc | 9 ++--- 24 files changed, 64 insertions(+), 252 deletions(-) diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc index 7e196ba77f2..b6bd16fae07 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc @@ -30,26 +30,12 @@ namespace mindspore { namespace parallel { -Status Activation::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; - return FAILED; - } +Status Activation::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } - return SUCCESS; -} - -Status Activation::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Invalid strategy."; - return FAILED; - } - - return SUCCESS; -} +Status Activation::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); } Status DropoutInfo::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { MS_LOG(ERROR) << name_ << " : Invalid strategy."; return FAILED; } @@ -153,7 +139,7 @@ Status DropoutInfo::GenerateStrategies(int32_t stage_id) { } Status Softmax::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { MS_LOG(ERROR) << name_ << " : Invalid strategy."; return FAILED; } @@ -229,14 +215,7 @@ Status Softmax::GetAttrs() { return SUCCESS; } -Status Softmax::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; - return FAILED; - } - - return SUCCESS; -} +Status Softmax::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } Status Softmax::GenerateStrategies(int32_t stage_id) { if (GetAttrs() != SUCCESS) { diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.cc index f6b21b0750f..ba5f991261b 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.cc @@ -73,7 +73,7 @@ Strategys ExpendStrategy(const StrategyPtr &strategy) { } Status ArithmeticBase::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { MS_LOG(ERROR) << name_ << " : Invalid strategy."; return FAILED; } @@ -290,14 +290,7 @@ Status ArithmeticBase::InferTensorInfo() { return SUCCESS; } -Status ArithmeticBase::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; - return FAILED; - } - - return SUCCESS; -} +Status ArithmeticBase::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } Status ArithmeticBase::GenerateStrategies(int32_t stage_id) { Shape input0_split(inputs_shape_[0].size(), 1); diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.cc index ad74d9f3ef7..16cb93270cf 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.cc @@ -27,7 +27,7 @@ namespace mindspore { namespace parallel { Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { MS_LOG(ERROR) << name_ << " : Invalid strategy."; return FAILED; } @@ -172,11 +172,7 @@ Status BatchParallelInfo::InitForCostModel(const StrategyPtr &strategy) { } Status BatchParallelInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; - return FAILED; - } - return SUCCESS; + return SetCostUnderStrategyBase(strategy); } Status BatchParallelInfo::GenerateStrategies(int32_t stage_id) { diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.cc index bf2deb3c3db..2b1b7a19b2b 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.cc @@ -27,7 +27,7 @@ namespace mindspore { namespace parallel { Status BiasAddInfo::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { MS_LOG(ERROR) << name_ << " : Invalid strategy."; return FAILED; } @@ -176,14 +176,7 @@ Status BiasAddInfo::InferTensorInfo() { return SUCCESS; } -Status BiasAddInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - return FAILED; - } - - return SUCCESS; -} +Status BiasAddInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } Status BiasAddInfo::GenerateStrategies(int32_t stage_id) { Shape input0_split(inputs_shape_[0].size(), 1); diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/concat_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/concat_info.cc index 2b20b27324a..89cae3d181a 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/concat_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/concat_info.cc @@ -60,7 +60,7 @@ Status ConcatInfo::GetAttrs() { Status ConcatInfo::CheckStrategy(const StrategyPtr &strategy) { MS_EXCEPTION_IF_NULL(strategy); - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { MS_LOG(ERROR) << name_ << ": Invalid strategy"; return FAILED; } @@ -197,14 +197,7 @@ void ConcatInfo::ReComputeBatchSplitFlagList() { } } -Status ConcatInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - return FAILED; - } - - return SUCCESS; -} +Status ConcatInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } Status ConcatInfo::GenerateStrategies(int32_t stage_id) { if (InferAttrs() != SUCCESS) { diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.cc index cdbd8f5915c..eb09bd6ca04 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.cc @@ -50,11 +50,7 @@ Status DropoutDoMaskInfo::CheckStrategy(const StrategyPtr &strategy) { // only check the input[0] Shapes input_shape = {inputs_shape_[0]}; - if (CheckStrategyValue(strategy, input_shape, is_auto_parallel_) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Invalid strategy"; - return FAILED; - } - return SUCCESS; + return CheckStrategyValue(strategy, input_shape); } Status DropoutDoMaskInfo::InferDevMatrixShape() { @@ -125,12 +121,7 @@ Status DropoutDoMaskInfo::InferTensorInfo() { } Status DropoutDoMaskInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed"; - return FAILED; - } - - return SUCCESS; + return SetCostUnderStrategyBase(strategy); } Status DropoutDoMaskInfo::GenerateStrategies(int32_t stage_id) { diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.cc index 7e938993747..2b3b9cef45e 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.cc @@ -82,7 +82,7 @@ Status GatherV2Info::CheckStrategy(const StrategyPtr &strategy) { return FAILED; } // Only strategy of the first input should be set. - if (CheckStrategyValue(strategy, {inputs_shape_.at(0)}, is_auto_parallel_) != SUCCESS) { + if (CheckStrategyValue(strategy, {inputs_shape_.at(0)}) != SUCCESS) { MS_LOG(ERROR) << name_ << ": Invalid strategy."; return FAILED; } @@ -301,13 +301,7 @@ Status GatherV2Info::GenerateStrategies(int32_t stage_id) { return SUCCESS; } -Status GatherV2Info::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - return FAILED; - } - return SUCCESS; -} +Status GatherV2Info::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } std::shared_ptr GatherV2Info::GenerateBatchStrategies() { if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc index d0fa5d2ab7b..64607cd7b80 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc @@ -213,12 +213,7 @@ Status GatherV2PInfo::CheckManualSplit(const Strategys &strategy) { } Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << ": Invalid strategy."; - } + if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { return FAILED; } @@ -716,17 +711,7 @@ Status GatherV2PInfo::InitForCostModel(const StrategyPtr &strategy) { return SUCCESS; } -Status GatherV2PInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - } - return FAILED; - } - return SUCCESS; -} +Status GatherV2PInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) { if (GetAttrs() != SUCCESS) { diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc index 234882c9b87..5e4d5fa2fd7 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc @@ -240,13 +240,7 @@ Status GetNextInfo::InitForCostModel(const StrategyPtr &strategy) { return SUCCESS; } -Status GetNextInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; - return FAILED; - } - return SUCCESS; -} +Status GetNextInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } Status GetNextInfo::GenerateStrategies(int32_t stage_id) { Strategys stra; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.cc index e662d0e9a40..f6326b39e82 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.cc @@ -27,8 +27,7 @@ namespace mindspore { namespace parallel { Status L2NormalizeInfo::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - MS_LOG(INFO) << name_ << " : Init success."; + if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { return FAILED; } diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.cc index 0d761e4e450..8fbe857f819 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.cc @@ -55,7 +55,7 @@ Status LayerNormInfo::CheckStrategy(const StrategyPtr &strategy) { return FAILED; } - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { MS_LOG(ERROR) << name_ << ": Invalid strategy value"; return FAILED; } @@ -207,13 +207,7 @@ Status LayerNormInfo::InferAsLossDivisor() { return SUCCESS; } -Status LayerNormInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Set cost failed"; - return FAILED; - } - return SUCCESS; -} +Status LayerNormInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } Status LayerNormInfo::GenerateGammaAndBetaStrategies(const std::vector &sp_vector) { if ((gamma_shape_.size() > input_shape_.size()) || (beta_shape_.size() > input_shape_.size())) { diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.cc index f8c92fe5f1b..ce7c360b8d5 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.cc @@ -28,7 +28,7 @@ namespace mindspore { namespace parallel { Status SoftmaxCrossEntropyWithLogitsInfo::CheckStrategy(const mindspore::parallel::StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { MS_LOG(ERROR) << name_ << " : Invalid strategy."; return FAILED; } @@ -200,12 +200,7 @@ Status SoftmaxCrossEntropyWithLogitsInfo::GenerateStrategies(int32_t stage_id) { } Status SoftmaxCrossEntropyWithLogitsInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - PrintStrategy(strategy); - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; - return FAILED; - } - return SUCCESS; + return SetCostUnderStrategyBase(strategy); } } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc index 9e4411bee31..dd54b0ddd86 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc @@ -150,7 +150,7 @@ Status CheckRelevantDimension(const Dimensions &long_strategy, const Dimensions } Status MatMul::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { MS_LOG(ERROR) << name_ << " : Invalid strategy."; return FAILED; } diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc index 797bdc8719d..a9c48d6351a 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc @@ -55,21 +55,7 @@ Status OneHotInfo::GetAttrs() { } Status OneHotInfo::CheckStrategy(const StrategyPtr &strategy) { - if (inputs_shape_.size() != 3) { - MS_LOG(ERROR) << name_ << ": inputs_shape_ size must be 3, but is " << inputs_shape_.size(); - return FAILED; - } - if (outputs_shape_.size() != 1) { - MS_LOG(ERROR) << name_ << ": outputs_shape_ size must be 1, but is " << outputs_shape_.size(); - return FAILED; - } - if (CheckStrategyValue(strategy, {outputs_shape_.at(0), inputs_shape_.at(1), inputs_shape_.at(2)}, - is_auto_parallel_) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Invalid strategy."; - return FAILED; - } - - return SUCCESS; + return CheckStrategyValue(strategy, {outputs_shape_.at(0), inputs_shape_.at(1), inputs_shape_.at(2)}); } Status OneHotInfo::InferDevMatrixShape() { @@ -278,13 +264,7 @@ Status OneHotInfo::GenerateStrategies(int32_t stage_id) { return SUCCESS; } -Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - return FAILED; - } - return SUCCESS; -} +Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } std::shared_ptr OneHotInfo::GenerateBatchStrategies() { CheckGlobalDeviceManager(); diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc index b7ecb7e34b8..60a1c783ecf 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc @@ -33,19 +33,21 @@ namespace mindspore { namespace parallel { -Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape, bool is_auto_parallel) { +Status OperatorInfo::CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape) { if (strategy == nullptr) { - MS_LOG(ERROR) << "The strategy is null."; + MS_LOG(ERROR) << name_ << ": The strategy is null."; return FAILED; } size_t strategy_size = strategy->GetInputNumber(); size_t inputs_shape_size = inputs_shape.size(); if (strategy_size != inputs_shape_size) { - if (is_auto_parallel) { - MS_LOG(DEBUG) << "Strategy size: " << strategy_size << " is not equal to inputs size: " << inputs_shape_size; + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Strategy size: " << strategy_size + << " is not equal to inputs size: " << inputs_shape_size; } else { - MS_LOG(ERROR) << "Strategy size: " << strategy_size << " is not equal to inputs size: " << inputs_shape_size; + MS_LOG(ERROR) << name_ << ": Strategy size: " << strategy_size + << " is not equal to inputs size: " << inputs_shape_size; } return FAILED; } @@ -57,11 +59,11 @@ Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shap size_t strategy_len = sub_strategy.size(); size_t inputs_len = sub_input_shape.size(); if (strategy_len != inputs_len) { - if (is_auto_parallel) { - MS_LOG(DEBUG) << "Strategy len: " << strategy_len << " is not equal to inputs len: " << inputs_len + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Strategy len: " << strategy_len << " is not equal to inputs len: " << inputs_len << ", index: " << i; } else { - MS_LOG(ERROR) << "Strategy len: " << strategy_len << " is not equal to inputs len: " << inputs_len + MS_LOG(ERROR) << name_ << ": Strategy len: " << strategy_len << " is not equal to inputs len: " << inputs_len << ", index: " << i; } return FAILED; @@ -70,29 +72,29 @@ Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shap for (size_t j = 0; j < strategy_len; ++j) { int64_t strategy_value = sub_strategy.at(j); if (strategy_value < MIN_SLICE_NUM) { - if (is_auto_parallel) { - MS_LOG(DEBUG) << "Invalid strategy value: " << strategy_value; + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Invalid strategy value: " << strategy_value; } else { - MS_LOG(ERROR) << "Invalid strategy value: " << strategy_value; + MS_LOG(ERROR) << name_ << ": Invalid strategy value: " << strategy_value; } return FAILED; } if ((IntToUint(strategy_value) & IntToUint(strategy_value - 1)) != 0) { - if (is_auto_parallel) { - MS_LOG(DEBUG) << "Invalid Strategy value it is not the power of 2, " << strategy_value; + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Invalid Strategy value it is not the power of 2, " << strategy_value; } else { - MS_LOG(ERROR) << "Invalid Strategy value it is not the power of 2, " << strategy_value; + MS_LOG(ERROR) << name_ << ": Invalid Strategy value it is not the power of 2, " << strategy_value; } return FAILED; } int64_t shape_value = sub_input_shape.at(j); if ((shape_value % strategy_value) != 0) { - if (is_auto_parallel) { - MS_LOG(DEBUG) << "Shape " << shape_value << " cannot be divisible by strategy " << strategy_value; + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Shape " << shape_value << " cannot be divisible by strategy " << strategy_value; } else { - MS_LOG(ERROR) << "Shape " << shape_value << " cannot be divisible by strategy " << strategy_value; + MS_LOG(ERROR) << name_ << ": Shape " << shape_value << " cannot be divisible by strategy " << strategy_value; } return FAILED; } diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h index 7801482c564..ce5dc591318 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h @@ -176,6 +176,7 @@ class OperatorInfo { virtual Status GetAttrs() = 0; virtual Status InferTensorInfo() = 0; virtual Status InferDevMatrixShape() = 0; + Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape); void SetDeviceListByStrategy(); void SetRepeatedCalcDevMatrix(); Status CreateGroupByTensorMap(const Shape &tensor_map, std::vector *group); diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.cc index a196860fe7c..53c08716e95 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.cc @@ -34,7 +34,7 @@ namespace parallel { * the strategy of w should equal to the channel dimension of strategy of A, or equal to 1 */ Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { MS_LOG(ERROR) << name_ << ": Invalid strategy."; return FAILED; } @@ -220,12 +220,6 @@ Status PReLUInfo::GenerateStrategies(int32_t stage_id) { return SUCCESS; } -Status PReLUInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - return FAILED; - } - return SUCCESS; -} +Status PReLUInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc index 7459a04dee6..d8982b2176b 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc @@ -29,14 +29,7 @@ namespace mindspore { namespace parallel { -Status ReduceMethod::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Invalid strategy."; - return FAILED; - } - - return SUCCESS; -} +Status ReduceMethod::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); } Status ReduceMethod::InferDevMatrixShape() { Strategys stra = strategy_->GetInputDim(); @@ -354,14 +347,7 @@ Status ReduceMethod::InferTensorInfo() { return SUCCESS; } -Status ReduceMethod::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - return FAILED; - } - - return SUCCESS; -} +Status ReduceMethod::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } Status ReduceMethod::GenerateStrategies(int32_t stage_id) { if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) { diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc index cca9e761bd7..6d7c7fe3507 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc @@ -29,14 +29,7 @@ namespace mindspore { namespace parallel { -Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Invalid strategy."; - return FAILED; - } - - return SUCCESS; -} +Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); } /* * support parallel degree smaller than device number, set the duplicate device dimension to the first dimension of @@ -394,12 +387,7 @@ Status ReshapeInfo::InitForCostModel(const StrategyPtr &strategy) { } Status ReshapeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - return FAILED; - } - - return SUCCESS; + return SetCostUnderStrategyBase(strategy); } void ReshapeInfo::SetCostForReshapeWithParameter() { diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.cc index f6b1a188c3b..31be2a09b86 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.cc @@ -98,7 +98,7 @@ Status StridedSliceInfo::GetAttrs() { Status StridedSliceInfo::CheckStrategy(const StrategyPtr &strategy) { MS_EXCEPTION_IF_NULL(strategy); - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { MS_LOG(ERROR) << name_ << ": Invalid strategy"; return FAILED; } @@ -232,12 +232,7 @@ std::shared_ptr StridedSliceInfo::GenerateBatchStrategies() { } Status StridedSliceInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - return FAILED; - } - - return SUCCESS; + return SetCostUnderStrategyBase(strategy); } Status StridedSliceInfo::GenerateStrategies(int32_t stage_id) { diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/tile_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/tile_info.cc index 890402980c9..d8ec6e587c0 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/tile_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/tile_info.cc @@ -67,12 +67,7 @@ Status TileInfo::GetAttrs() { Status TileInfo::CheckStrategy(const StrategyPtr &strategy) { Shapes multiples = {full_multiples_}; - if (CheckStrategyValue(strategy, multiples, is_auto_parallel_) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Invalid strategy."; - return FAILED; - } - - return SUCCESS; + return CheckStrategyValue(strategy, multiples); } Status TileInfo::InferDevMatrixShape() { @@ -197,14 +192,7 @@ std::shared_ptr TileInfo::GenerateBatchStrategies() { return GenerateBatchStrategiesBySplitFlag(multiples_shape, split_flag_list_); } -Status TileInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - return FAILED; - } - - return SUCCESS; -} +Status TileInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } Status TileInfo::GenerateStrategies(int32_t stage_id) { if (InferAttrs() != SUCCESS) { diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.cc index 55bfa9810b9..3425e95f4c0 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.cc @@ -25,11 +25,7 @@ namespace mindspore { namespace parallel { Status TmpIdentityInfo::CheckStrategy(const mindspore::parallel::StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": invalid strategy."; - return FAILED; - } - return SUCCESS; + return CheckStrategyValue(strategy, inputs_shape_); } Status TmpIdentityInfo::InferDevMatrixShape() { @@ -98,14 +94,7 @@ Status TmpIdentityInfo::InitForCostModel(const StrategyPtr &strategy) { return SUCCESS; } -Status TmpIdentityInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - return FAILED; - } - - return SUCCESS; -} +Status TmpIdentityInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } Status TmpIdentityInfo::GenerateStrategies(int32_t stage_id) { if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) { diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.cc index 22046b9009a..99cc0dfd986 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.cc @@ -27,14 +27,7 @@ namespace mindspore { namespace parallel { -Status TransposeInfo::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Invalid strategy."; - return FAILED; - } - - return SUCCESS; -} +Status TransposeInfo::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); } Status TransposeInfo::InferDevMatrixShape() { Strategys stra = strategy_->GetInputDim(); @@ -195,12 +188,7 @@ Status TransposeInfo::InitForCostModel(const StrategyPtr &strategy) { } Status TransposeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - return FAILED; - } - - return SUCCESS; + return SetCostUnderStrategyBase(strategy); } Status TransposeInfo::GenerateStrategies(int32_t stage_id) { diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc index b59aeda4ca4..308a6642405 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc @@ -29,7 +29,7 @@ namespace mindspore { namespace parallel { Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { MS_LOG(ERROR) << name_ << ": Invalid strategy."; return FAILED; } @@ -143,12 +143,7 @@ void VirtualDatasetInfo::ReComputeBatchSplitFlagList() { } Status VirtualDatasetInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - return FAILED; - } - - return SUCCESS; + return SetCostUnderStrategyBase(strategy); } Status VirtualDatasetInfo::GenerateStrategies(int32_t stage_id) {