forked from OSSInnovation/mindspore
!5582 update CheckStrategyValue
Merge pull request !5582 from yangzhenzhang/update-check-strategy-value
This commit is contained in:
commit
79b117fe02
|
@ -30,26 +30,12 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
Status Activation::SetCostUnderStrategy(const StrategyPtr &strategy) {
|
||||
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << " : Set cost under strategy failed.";
|
||||
return FAILED;
|
||||
}
|
||||
Status Activation::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status Activation::CheckStrategy(const StrategyPtr &strategy) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << " : Invalid strategy.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
Status Activation::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); }
|
||||
|
||||
Status DropoutInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << " : Invalid strategy.";
|
||||
return FAILED;
|
||||
}
|
||||
|
@ -153,7 +139,7 @@ Status DropoutInfo::GenerateStrategies(int32_t stage_id) {
|
|||
}
|
||||
|
||||
Status Softmax::CheckStrategy(const StrategyPtr &strategy) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << " : Invalid strategy.";
|
||||
return FAILED;
|
||||
}
|
||||
|
@ -229,14 +215,7 @@ Status Softmax::GetAttrs() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status Softmax::SetCostUnderStrategy(const StrategyPtr &strategy) {
|
||||
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << " : Set cost under strategy failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
Status Softmax::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
|
||||
|
||||
Status Softmax::GenerateStrategies(int32_t stage_id) {
|
||||
if (GetAttrs() != SUCCESS) {
|
||||
|
|
|
@ -73,7 +73,7 @@ Strategys ExpendStrategy(const StrategyPtr &strategy) {
|
|||
}
|
||||
|
||||
Status ArithmeticBase::CheckStrategy(const StrategyPtr &strategy) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << " : Invalid strategy.";
|
||||
return FAILED;
|
||||
}
|
||||
|
@ -290,14 +290,7 @@ Status ArithmeticBase::InferTensorInfo() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status ArithmeticBase::SetCostUnderStrategy(const StrategyPtr &strategy) {
|
||||
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << " : Set cost under strategy failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
Status ArithmeticBase::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
|
||||
|
||||
Status ArithmeticBase::GenerateStrategies(int32_t stage_id) {
|
||||
Shape input0_split(inputs_shape_[0].size(), 1);
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
namespace mindspore {
|
||||
namespace parallel {
|
||||
Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << " : Invalid strategy.";
|
||||
return FAILED;
|
||||
}
|
||||
|
@ -172,11 +172,7 @@ Status BatchParallelInfo::InitForCostModel(const StrategyPtr &strategy) {
|
|||
}
|
||||
|
||||
Status BatchParallelInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
|
||||
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << " : Set cost under strategy failed.";
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
return SetCostUnderStrategyBase(strategy);
|
||||
}
|
||||
|
||||
Status BatchParallelInfo::GenerateStrategies(int32_t stage_id) {
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
namespace mindspore {
|
||||
namespace parallel {
|
||||
Status BiasAddInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << " : Invalid strategy.";
|
||||
return FAILED;
|
||||
}
|
||||
|
@ -176,14 +176,7 @@ Status BiasAddInfo::InferTensorInfo() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status BiasAddInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
|
||||
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
Status BiasAddInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
|
||||
|
||||
Status BiasAddInfo::GenerateStrategies(int32_t stage_id) {
|
||||
Shape input0_split(inputs_shape_[0].size(), 1);
|
||||
|
|
|
@ -60,7 +60,7 @@ Status ConcatInfo::GetAttrs() {
|
|||
|
||||
Status ConcatInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
MS_EXCEPTION_IF_NULL(strategy);
|
||||
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy";
|
||||
return FAILED;
|
||||
}
|
||||
|
@ -197,14 +197,7 @@ void ConcatInfo::ReComputeBatchSplitFlagList() {
|
|||
}
|
||||
}
|
||||
|
||||
Status ConcatInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
|
||||
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
Status ConcatInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
|
||||
|
||||
Status ConcatInfo::GenerateStrategies(int32_t stage_id) {
|
||||
if (InferAttrs() != SUCCESS) {
|
||||
|
|
|
@ -50,11 +50,7 @@ Status DropoutDoMaskInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|||
|
||||
// only check the input[0]
|
||||
Shapes input_shape = {inputs_shape_[0]};
|
||||
if (CheckStrategyValue(strategy, input_shape, is_auto_parallel_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy";
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
return CheckStrategyValue(strategy, input_shape);
|
||||
}
|
||||
|
||||
Status DropoutDoMaskInfo::InferDevMatrixShape() {
|
||||
|
@ -125,12 +121,7 @@ Status DropoutDoMaskInfo::InferTensorInfo() {
|
|||
}
|
||||
|
||||
Status DropoutDoMaskInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
|
||||
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
return SetCostUnderStrategyBase(strategy);
|
||||
}
|
||||
|
||||
Status DropoutDoMaskInfo::GenerateStrategies(int32_t stage_id) {
|
||||
|
|
|
@ -82,7 +82,7 @@ Status GatherV2Info::CheckStrategy(const StrategyPtr &strategy) {
|
|||
return FAILED;
|
||||
}
|
||||
// Only strategy of the first input should be set.
|
||||
if (CheckStrategyValue(strategy, {inputs_shape_.at(0)}, is_auto_parallel_) != SUCCESS) {
|
||||
if (CheckStrategyValue(strategy, {inputs_shape_.at(0)}) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy.";
|
||||
return FAILED;
|
||||
}
|
||||
|
@ -301,13 +301,7 @@ Status GatherV2Info::GenerateStrategies(int32_t stage_id) {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherV2Info::SetCostUnderStrategy(const StrategyPtr &strategy) {
|
||||
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
Status GatherV2Info::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
|
||||
|
||||
std::shared_ptr<Strategys> GatherV2Info::GenerateBatchStrategies() {
|
||||
if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
|
||||
|
|
|
@ -213,12 +213,7 @@ Status GatherV2PInfo::CheckManualSplit(const Strategys &strategy) {
|
|||
}
|
||||
|
||||
Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
|
||||
if (is_auto_parallel_) {
|
||||
MS_LOG(DEBUG) << name_ << ": Invalid strategy.";
|
||||
} else {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy.";
|
||||
}
|
||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
|
@ -716,17 +711,7 @@ Status GatherV2PInfo::InitForCostModel(const StrategyPtr &strategy) {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherV2PInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
|
||||
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
|
||||
if (is_auto_parallel_) {
|
||||
MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed.";
|
||||
} else {
|
||||
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
|
||||
}
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
Status GatherV2PInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
|
||||
|
||||
Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) {
|
||||
if (GetAttrs() != SUCCESS) {
|
||||
|
|
|
@ -240,13 +240,7 @@ Status GetNextInfo::InitForCostModel(const StrategyPtr &strategy) {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GetNextInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
|
||||
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << " : Set cost under strategy failed.";
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
Status GetNextInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
|
||||
|
||||
Status GetNextInfo::GenerateStrategies(int32_t stage_id) {
|
||||
Strategys stra;
|
||||
|
|
|
@ -27,8 +27,7 @@
|
|||
namespace mindspore {
|
||||
namespace parallel {
|
||||
Status L2NormalizeInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
|
||||
MS_LOG(INFO) << name_ << " : Init success.";
|
||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
|
|
|
@ -55,7 +55,7 @@ Status LayerNormInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|||
return FAILED;
|
||||
}
|
||||
|
||||
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy value";
|
||||
return FAILED;
|
||||
}
|
||||
|
@ -207,13 +207,7 @@ Status LayerNormInfo::InferAsLossDivisor() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status LayerNormInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
|
||||
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << " : Set cost failed";
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
Status LayerNormInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
|
||||
|
||||
Status LayerNormInfo::GenerateGammaAndBetaStrategies(const std::vector<StrategyPtr> &sp_vector) {
|
||||
if ((gamma_shape_.size() > input_shape_.size()) || (beta_shape_.size() > input_shape_.size())) {
|
||||
|
|
|
@ -28,7 +28,7 @@
|
|||
namespace mindspore {
|
||||
namespace parallel {
|
||||
Status SoftmaxCrossEntropyWithLogitsInfo::CheckStrategy(const mindspore::parallel::StrategyPtr &strategy) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << " : Invalid strategy.";
|
||||
return FAILED;
|
||||
}
|
||||
|
@ -200,12 +200,7 @@ Status SoftmaxCrossEntropyWithLogitsInfo::GenerateStrategies(int32_t stage_id) {
|
|||
}
|
||||
|
||||
Status SoftmaxCrossEntropyWithLogitsInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
|
||||
PrintStrategy(strategy);
|
||||
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << " : Set cost under strategy failed.";
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
return SetCostUnderStrategyBase(strategy);
|
||||
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -150,7 +150,7 @@ Status CheckRelevantDimension(const Dimensions &long_strategy, const Dimensions
|
|||
}
|
||||
|
||||
Status MatMul::CheckStrategy(const StrategyPtr &strategy) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << " : Invalid strategy.";
|
||||
return FAILED;
|
||||
}
|
||||
|
|
|
@ -55,21 +55,7 @@ Status OneHotInfo::GetAttrs() {
|
|||
}
|
||||
|
||||
Status OneHotInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
if (inputs_shape_.size() != 3) {
|
||||
MS_LOG(ERROR) << name_ << ": inputs_shape_ size must be 3, but is " << inputs_shape_.size();
|
||||
return FAILED;
|
||||
}
|
||||
if (outputs_shape_.size() != 1) {
|
||||
MS_LOG(ERROR) << name_ << ": outputs_shape_ size must be 1, but is " << outputs_shape_.size();
|
||||
return FAILED;
|
||||
}
|
||||
if (CheckStrategyValue(strategy, {outputs_shape_.at(0), inputs_shape_.at(1), inputs_shape_.at(2)},
|
||||
is_auto_parallel_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
return CheckStrategyValue(strategy, {outputs_shape_.at(0), inputs_shape_.at(1), inputs_shape_.at(2)});
|
||||
}
|
||||
|
||||
Status OneHotInfo::InferDevMatrixShape() {
|
||||
|
@ -278,13 +264,7 @@ Status OneHotInfo::GenerateStrategies(int32_t stage_id) {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
|
||||
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
|
||||
|
||||
std::shared_ptr<Strategys> OneHotInfo::GenerateBatchStrategies() {
|
||||
CheckGlobalDeviceManager();
|
||||
|
|
|
@ -33,19 +33,21 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape, bool is_auto_parallel) {
|
||||
Status OperatorInfo::CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape) {
|
||||
if (strategy == nullptr) {
|
||||
MS_LOG(ERROR) << "The strategy is null.";
|
||||
MS_LOG(ERROR) << name_ << ": The strategy is null.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
size_t strategy_size = strategy->GetInputNumber();
|
||||
size_t inputs_shape_size = inputs_shape.size();
|
||||
if (strategy_size != inputs_shape_size) {
|
||||
if (is_auto_parallel) {
|
||||
MS_LOG(DEBUG) << "Strategy size: " << strategy_size << " is not equal to inputs size: " << inputs_shape_size;
|
||||
if (is_auto_parallel_) {
|
||||
MS_LOG(DEBUG) << name_ << ": Strategy size: " << strategy_size
|
||||
<< " is not equal to inputs size: " << inputs_shape_size;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Strategy size: " << strategy_size << " is not equal to inputs size: " << inputs_shape_size;
|
||||
MS_LOG(ERROR) << name_ << ": Strategy size: " << strategy_size
|
||||
<< " is not equal to inputs size: " << inputs_shape_size;
|
||||
}
|
||||
return FAILED;
|
||||
}
|
||||
|
@ -57,11 +59,11 @@ Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shap
|
|||
size_t strategy_len = sub_strategy.size();
|
||||
size_t inputs_len = sub_input_shape.size();
|
||||
if (strategy_len != inputs_len) {
|
||||
if (is_auto_parallel) {
|
||||
MS_LOG(DEBUG) << "Strategy len: " << strategy_len << " is not equal to inputs len: " << inputs_len
|
||||
if (is_auto_parallel_) {
|
||||
MS_LOG(DEBUG) << name_ << ": Strategy len: " << strategy_len << " is not equal to inputs len: " << inputs_len
|
||||
<< ", index: " << i;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Strategy len: " << strategy_len << " is not equal to inputs len: " << inputs_len
|
||||
MS_LOG(ERROR) << name_ << ": Strategy len: " << strategy_len << " is not equal to inputs len: " << inputs_len
|
||||
<< ", index: " << i;
|
||||
}
|
||||
return FAILED;
|
||||
|
@ -70,29 +72,29 @@ Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shap
|
|||
for (size_t j = 0; j < strategy_len; ++j) {
|
||||
int64_t strategy_value = sub_strategy.at(j);
|
||||
if (strategy_value < MIN_SLICE_NUM) {
|
||||
if (is_auto_parallel) {
|
||||
MS_LOG(DEBUG) << "Invalid strategy value: " << strategy_value;
|
||||
if (is_auto_parallel_) {
|
||||
MS_LOG(DEBUG) << name_ << ": Invalid strategy value: " << strategy_value;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Invalid strategy value: " << strategy_value;
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy value: " << strategy_value;
|
||||
}
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if ((IntToUint(strategy_value) & IntToUint(strategy_value - 1)) != 0) {
|
||||
if (is_auto_parallel) {
|
||||
MS_LOG(DEBUG) << "Invalid Strategy value it is not the power of 2, " << strategy_value;
|
||||
if (is_auto_parallel_) {
|
||||
MS_LOG(DEBUG) << name_ << ": Invalid Strategy value it is not the power of 2, " << strategy_value;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Invalid Strategy value it is not the power of 2, " << strategy_value;
|
||||
MS_LOG(ERROR) << name_ << ": Invalid Strategy value it is not the power of 2, " << strategy_value;
|
||||
}
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
int64_t shape_value = sub_input_shape.at(j);
|
||||
if ((shape_value % strategy_value) != 0) {
|
||||
if (is_auto_parallel) {
|
||||
MS_LOG(DEBUG) << "Shape " << shape_value << " cannot be divisible by strategy " << strategy_value;
|
||||
if (is_auto_parallel_) {
|
||||
MS_LOG(DEBUG) << name_ << ": Shape " << shape_value << " cannot be divisible by strategy " << strategy_value;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Shape " << shape_value << " cannot be divisible by strategy " << strategy_value;
|
||||
MS_LOG(ERROR) << name_ << ": Shape " << shape_value << " cannot be divisible by strategy " << strategy_value;
|
||||
}
|
||||
return FAILED;
|
||||
}
|
||||
|
|
|
@ -176,6 +176,7 @@ class OperatorInfo {
|
|||
virtual Status GetAttrs() = 0;
|
||||
virtual Status InferTensorInfo() = 0;
|
||||
virtual Status InferDevMatrixShape() = 0;
|
||||
Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape);
|
||||
void SetDeviceListByStrategy();
|
||||
void SetRepeatedCalcDevMatrix();
|
||||
Status CreateGroupByTensorMap(const Shape &tensor_map, std::vector<Group> *group);
|
||||
|
|
|
@ -34,7 +34,7 @@ namespace parallel {
|
|||
* the strategy of w should equal to the channel dimension of strategy of A, or equal to 1
|
||||
*/
|
||||
Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy.";
|
||||
return FAILED;
|
||||
}
|
||||
|
@ -220,12 +220,6 @@ Status PReLUInfo::GenerateStrategies(int32_t stage_id) {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status PReLUInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
|
||||
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
Status PReLUInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -29,14 +29,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
Status ReduceMethod::CheckStrategy(const StrategyPtr &strategy) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
Status ReduceMethod::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); }
|
||||
|
||||
Status ReduceMethod::InferDevMatrixShape() {
|
||||
Strategys stra = strategy_->GetInputDim();
|
||||
|
@ -354,14 +347,7 @@ Status ReduceMethod::InferTensorInfo() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status ReduceMethod::SetCostUnderStrategy(const StrategyPtr &strategy) {
|
||||
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
Status ReduceMethod::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
|
||||
|
||||
Status ReduceMethod::GenerateStrategies(int32_t stage_id) {
|
||||
if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) {
|
||||
|
|
|
@ -29,14 +29,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); }
|
||||
|
||||
/*
|
||||
* support parallel degree smaller than device number, set the duplicate device dimension to the first dimension of
|
||||
|
@ -394,12 +387,7 @@ Status ReshapeInfo::InitForCostModel(const StrategyPtr &strategy) {
|
|||
}
|
||||
|
||||
Status ReshapeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) {
|
||||
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
return SetCostUnderStrategyBase(strategy);
|
||||
}
|
||||
|
||||
void ReshapeInfo::SetCostForReshapeWithParameter() {
|
||||
|
|
|
@ -98,7 +98,7 @@ Status StridedSliceInfo::GetAttrs() {
|
|||
|
||||
Status StridedSliceInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
MS_EXCEPTION_IF_NULL(strategy);
|
||||
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy";
|
||||
return FAILED;
|
||||
}
|
||||
|
@ -232,12 +232,7 @@ std::shared_ptr<Strategys> StridedSliceInfo::GenerateBatchStrategies() {
|
|||
}
|
||||
|
||||
Status StridedSliceInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
|
||||
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
return SetCostUnderStrategyBase(strategy);
|
||||
}
|
||||
|
||||
Status StridedSliceInfo::GenerateStrategies(int32_t stage_id) {
|
||||
|
|
|
@ -67,12 +67,7 @@ Status TileInfo::GetAttrs() {
|
|||
|
||||
Status TileInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
Shapes multiples = {full_multiples_};
|
||||
if (CheckStrategyValue(strategy, multiples, is_auto_parallel_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
return CheckStrategyValue(strategy, multiples);
|
||||
}
|
||||
|
||||
Status TileInfo::InferDevMatrixShape() {
|
||||
|
@ -197,14 +192,7 @@ std::shared_ptr<Strategys> TileInfo::GenerateBatchStrategies() {
|
|||
return GenerateBatchStrategiesBySplitFlag(multiples_shape, split_flag_list_);
|
||||
}
|
||||
|
||||
Status TileInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
|
||||
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
Status TileInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
|
||||
|
||||
Status TileInfo::GenerateStrategies(int32_t stage_id) {
|
||||
if (InferAttrs() != SUCCESS) {
|
||||
|
|
|
@ -25,11 +25,7 @@
|
|||
namespace mindspore {
|
||||
namespace parallel {
|
||||
Status TmpIdentityInfo::CheckStrategy(const mindspore::parallel::StrategyPtr &strategy) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": invalid strategy.";
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
return CheckStrategyValue(strategy, inputs_shape_);
|
||||
}
|
||||
|
||||
Status TmpIdentityInfo::InferDevMatrixShape() {
|
||||
|
@ -98,14 +94,7 @@ Status TmpIdentityInfo::InitForCostModel(const StrategyPtr &strategy) {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status TmpIdentityInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
|
||||
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
Status TmpIdentityInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
|
||||
|
||||
Status TmpIdentityInfo::GenerateStrategies(int32_t stage_id) {
|
||||
if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) {
|
||||
|
|
|
@ -27,14 +27,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
Status TransposeInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
Status TransposeInfo::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); }
|
||||
|
||||
Status TransposeInfo::InferDevMatrixShape() {
|
||||
Strategys stra = strategy_->GetInputDim();
|
||||
|
@ -195,12 +188,7 @@ Status TransposeInfo::InitForCostModel(const StrategyPtr &strategy) {
|
|||
}
|
||||
|
||||
Status TransposeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) {
|
||||
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
return SetCostUnderStrategyBase(strategy);
|
||||
}
|
||||
|
||||
Status TransposeInfo::GenerateStrategies(int32_t stage_id) {
|
||||
|
|
|
@ -29,7 +29,7 @@
|
|||
namespace mindspore {
|
||||
namespace parallel {
|
||||
Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy.";
|
||||
return FAILED;
|
||||
}
|
||||
|
@ -143,12 +143,7 @@ void VirtualDatasetInfo::ReComputeBatchSplitFlagList() {
|
|||
}
|
||||
|
||||
Status VirtualDatasetInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
|
||||
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
return SetCostUnderStrategyBase(strategy);
|
||||
}
|
||||
|
||||
Status VirtualDatasetInfo::GenerateStrategies(int32_t stage_id) {
|
||||
|
|
Loading…
Reference in New Issue