!45631 Insert in_strategy between input and target op

Merge pull request !45631 from liuluobin/shard_in_stra_identity
i-robot 2022-11-30 01:50:12 +00:00 committed by Gitee
commit 6733c6a529
51 changed files with 118 additions and 1439 deletions

View File

@ -40,10 +40,7 @@ class AddNInfo : public OperatorInfo {
void ReComputeBatchSplitFlagList() override;
protected:
Status GetAttrs() override {
infer_strategy_mode_ = SAME_MODE;
return SUCCESS;
}
Status GetAttrs() override { return SUCCESS; }
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferDevMatrixShape() override;

View File

@ -40,10 +40,7 @@ class ArithmeticBase : public OperatorInfo {
void ReComputeBatchSplitFlagList() override;
protected:
Status GetAttrs() override {
infer_strategy_mode_ = BROADCAST_MODE;
return SUCCESS;
}
Status GetAttrs() override { return SUCCESS; }
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferDevMatrixShape() override;

View File

@ -74,7 +74,6 @@ Status BatchNormInfo::GetAttrs() {
MS_LOG(INFO) << name_ << ": The is_training is " << is_training_ << ", epsilon is " << epsilon_ << ", momentum is "
<< momentum_ << ", data format is " << format_;
infer_strategy_mode_ = INDIVIDUAL_MODE;
return SUCCESS;
}
@ -335,38 +334,6 @@ std::vector<StrategyPtr> BatchNormInfo::GenerateOpStrategies(int64_t stage_id) {
return sp_vector;
}
// in_strategy: ((N, C, H, W), (), (), (), ()) return: ((N, C, H, W), (C), (C), (C), (C))
// in_strategy: ((), (C), (C), (C), (C)) return: ((1, C, 1, 1), (C), (C), (C), (C))
// in_strategy: ((), (C), (), (C), (C)) throw exception
Shapes BatchNormInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
if (in_strategy.size() != 5) {
MS_LOG(EXCEPTION) << name_ << ": The size of in strategy must be 5, but got " << in_strategy.size();
}
if ((in_strategy[2] != in_strategy[1]) || (in_strategy[3] != in_strategy[1]) || (in_strategy[4] != in_strategy[1])) {
MS_LOG(EXCEPTION) << name_ << ": The last 4 strategy must be equal, but got " << in_strategy;
}
Shape channel_strategy;
if (!in_strategy[0].empty()) {
if (in_strategy[0].size() != 4 && in_strategy[0].size() != 2) {
MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy[0] must be 4 or 2, but got " << in_strategy[0].size();
}
channel_strategy = {in_strategy[0][1]};
return Shapes({in_strategy[0], channel_strategy, channel_strategy, channel_strategy, channel_strategy});
}
for (size_t i = 1; i < in_strategy.size(); ++i) {
if (!in_strategy[i].empty()) {
channel_strategy = in_strategy[i];
break;
}
}
Shape tmp_strategy(inputs_shape_[0].size(), 1);
tmp_strategy[1] = channel_strategy[0];
return Shapes({tmp_strategy, channel_strategy, channel_strategy, channel_strategy, channel_strategy});
}
REGISTER(BatchNormInfo);
} // namespace parallel
} // namespace mindspore

View File

@ -47,7 +47,6 @@ class BatchNormInfo : public OperatorInfo {
Status InferTensorMap() override;
void InferReplaceOps() override;
Status InferAsLossDivisor() override;
Shapes InferStrategyIndividualMode(const Shapes &in_strategy) override;
private:
bool is_training_ = false;

View File

@ -102,30 +102,6 @@ std::vector<StrategyPtr> BiasAddInfo::GenerateOpStrategies(int64_t stage_id) {
return sp_vector;
}
// in_strategy: ((N, C, H, W), ()), inputs shapes: ((n, c, h, w), (c)), return: ((N, C, H, W), (C))
// in_strategy: ((), (C)), inputs shapes: ((n, c, h, w), (c)), return: ((1, C, 1, 1), (C))
// in_strategy: ((N, C), ()), inputs shapes: ((n, c), (c)), return: ((N, C), (C))
// in_strategy: ((), (C)), inputs shapes: ((n, c), (c)), return: ((1, C), (C))
Shapes BiasAddInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
if (in_strategy.size() != 2) {
MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy must be 2, but got " << in_strategy.size();
}
if (!in_strategy[0].empty()) {
if (in_strategy[0].size() != 2 && in_strategy[0].size() != 4) {
MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy[0] must be 2 or 4, but got " << in_strategy[0].size();
}
return Shapes({in_strategy[0], {in_strategy[0][1]}});
}
if (in_strategy[1].size() != 1) {
MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy[1] must be 1, but got " << in_strategy[1].size();
}
Shape tmp(inputs_shape_[0].size(), 1);
tmp[1] = in_strategy[1][0];
return Shapes({tmp, in_strategy[1]});
}
REGISTER(BiasAddInfo);
} // namespace parallel
} // namespace mindspore

View File

@ -42,15 +42,11 @@ class BiasAddInfo : public OperatorInfo {
void ReComputeBatchSplitFlagList() override;
protected:
Status GetAttrs() override {
infer_strategy_mode_ = INDIVIDUAL_MODE;
return SUCCESS;
}
Status GetAttrs() override { return SUCCESS; }
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Shapes InferStrategyIndividualMode(const Shapes &in_strategy) override;
};
} // namespace parallel
} // namespace mindspore

View File

@ -42,10 +42,7 @@ class BoundingBoxEncodeInfo : public OperatorInfo {
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status GetAttrs() override {
infer_strategy_mode_ = SAME_MODE;
return SUCCESS;
}
Status GetAttrs() override { return SUCCESS; }
Status InferForwardCommunication() override { return SUCCESS; };
private:

View File

@ -28,7 +28,6 @@ Status CdistInfo::GetAttrs() {
MS_LOG(ERROR) << "Dimension of each input must be 2 or 3, but got dimension is " << input_dims_ << ".";
return FAILED;
}
infer_strategy_mode_ = INDIVIDUAL_MODE;
return SUCCESS;
}
@ -117,29 +116,6 @@ void CdistInfo::ReComputeBatchSplitFlagList() {
}
}
// in_strategy: ((B, P, 1), ()), inputs shape: ((b, p, m), (b, r, m)), return: ((B, P, 1), (B, 1, 1))
// in_strategy: ((), (B, R, 1)), inputs shape: ((b, p, m), (b, r, m)), return: ((B, 1, 1), (B, R, 1))
Shapes CdistInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
if (in_strategy.size() != 2) {
MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy must be 2, but got " << in_strategy.size();
}
Shape tmp_strategy;
if (!in_strategy[0].empty()) {
tmp_strategy = Shape(inputs_shape_[1].size(), 1);
tmp_strategy[0] = in_strategy[0][0];
return Shapes({in_strategy[0], tmp_strategy});
}
if (!in_strategy[1].empty()) {
tmp_strategy = Shape(inputs_shape_[0].size(), 1);
tmp_strategy[0] = in_strategy[1][0];
return Shapes({tmp_strategy, in_strategy[1]});
}
MS_LOG(EXCEPTION) << name_ << ": The in_strategy[0] and in_strategy[1] are empty";
}
REGISTER(CdistInfo);
} // namespace parallel
} // namespace mindspore

View File

@ -43,7 +43,6 @@ class CdistInfo : public OperatorInfo {
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status InferForwardCommunication() override { return SUCCESS; }
Shapes InferStrategyIndividualMode(const Shapes &in_strategy) override;
private:
size_t input_dims_ = 0;

View File

@ -56,7 +56,6 @@ Status ConcatInfo::GetAttrs() {
}
axis_ = LongToSize(axis);
infer_strategy_mode_ = SAME_MODE;
return SUCCESS;
}

View File

@ -134,7 +134,6 @@ Status Conv2DInfo::GetAttrsBase() {
// group
group_ = GetIntAttr(GROUP);
infer_strategy_mode_ = INDIVIDUAL_MODE;
return SUCCESS;
}
@ -1039,68 +1038,6 @@ std::vector<StrategyPtr> Conv2DInfo::GenerateOpStrategies(int64_t stage_id) {
return sp_vector;
}
// conv2d: ((N, C-in, H, W), (C-out, C-in, k1, k2))
// in_strategy: ((N, C, H, W), ()), return: ((N, C, H, W), (1, C, 1, 1))
// in_strategy: ((), (C0, C1, 1, 1)), return: ((1, C1, 1, 1), (C0, C1, 1, 1))
Shapes Conv2DInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
if (in_strategy.size() != 2) {
MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy must be 2, but got " << in_strategy.size();
}
Shape tmp_strategy;
if (!in_strategy[0].empty()) {
if (in_strategy[0].size() != inputs_shape_[0].size()) {
MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy[0] must be " << inputs_shape_[0].size() << ", but got "
<< in_strategy[0].size();
}
tmp_strategy = Shape(inputs_shape_[1].size(), 1);
tmp_strategy[1] = in_strategy[0][1];
return Shapes({in_strategy[0], tmp_strategy});
}
if (!in_strategy[1].empty()) {
if (in_strategy[1].size() != inputs_shape_[1].size()) {
MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy[1] must be " << inputs_shape_[1].size() << ", but got "
<< in_strategy[1].size();
}
tmp_strategy = Shape(inputs_shape_[0].size(), 1);
tmp_strategy[1] = in_strategy[1][1];
return Shapes({tmp_strategy, in_strategy[1]});
}
MS_LOG(EXCEPTION) << name_ << ": The in_strategy[0] and in_strategy[1] are empty";
}
// conv2d-transpose: ((N, C-out, H, W), (C-out, C-in, k1, k2))
// in_strategy: ((N, C, H, W), ()), return: ((N, C, H, W), (C, 1, 1, 1))
// in_strategy: ((), (C0, C1, 1, 1)), return: ((1, C0, 1, 1), (C0, C1, 1, 1))
Shapes Conv2DBackpropInputInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
if (in_strategy.size() != 2) {
MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy must be 2, but got " << in_strategy.size();
}
Shape tmp_strategy;
if (!in_strategy[0].empty()) {
if (in_strategy[0].size() != 4) {
MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy[0] must be 4, but got " << in_strategy[0].size();
}
tmp_strategy = Shape(inputs_shape_[1].size(), 1);
tmp_strategy[0] = in_strategy[0][1];
return Shapes({in_strategy[0], tmp_strategy});
}
if (!in_strategy[1].empty()) {
if (in_strategy[1].size() != 4) {
MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy[1] must be 4, but got " << in_strategy[1].size();
}
tmp_strategy = Shape(inputs_shape_[0].size(), 1);
tmp_strategy[1] = in_strategy[1][0];
return Shapes({tmp_strategy, in_strategy[1]});
}
MS_LOG(EXCEPTION) << name_ << ": The in_strategy[0] and in_strategy[1] are empty";
}
Status Conv2DBackpropInputInfo::GetOutShape() {
if (input_value_.size() != 3) {
MS_LOG(ERROR) << name_ << ": The size of input value must be 3, but got " << input_value_.size();

View File

@ -146,7 +146,6 @@ class Conv2DInfo : public OperatorInfo {
virtual int64_t ComputeOverlapBottomSizeByRankBias(int64_t rank_bias);
virtual int64_t ComputeOverlapLeftSizeByRankBias(int64_t rank_bias);
virtual int64_t ComputeOverlapRightSizeByRankBias(int64_t rank_bias);
Shapes InferStrategyIndividualMode(const Shapes &in_strategy) override;
private:
Status CheckHWStrategyValidMode(int64_t h_strategy, int64_t w_strategy);
@ -179,7 +178,6 @@ class Conv2DBackpropInputInfo : public Conv2DInfo {
int64_t ComputeOverlapBottomSizeByRankBias(int64_t rank_bias) override;
int64_t ComputeOverlapLeftSizeByRankBias(int64_t rank_bias) override;
int64_t ComputeOverlapRightSizeByRankBias(int64_t rank_bias) override;
Shapes InferStrategyIndividualMode(const Shapes &in_strategy) override;
private:
Shape out_shape_;

View File

@ -51,7 +51,6 @@ Status GammaInfo::GetAttrs() {
MS_LOG(ERROR) << name_ << ": Seed2 must be greater or equal to zero, bug got " << seed2_;
return FAILED;
}
infer_strategy_mode_ = INDEPENDENT_MODE;
return SUCCESS;
}

View File

@ -167,7 +167,6 @@ Status GatherInfo::GetAttrs() {
dynamic_shape_indices_ = true;
}
#endif
infer_strategy_mode_ = INDEPENDENT_MODE;
return SUCCESS;
}

View File

@ -54,7 +54,6 @@ Status GatherDInfo::GetAttrs() {
dim_ = LongToSize(dim);
MS_LOG(INFO) << name_ << ": The dim is " << dim_;
infer_strategy_mode_ = SAME_MODE;
return SUCCESS;
}

View File

@ -41,10 +41,7 @@ class GatherNdInfo : public OperatorInfo {
void ReComputeBatchSplitFlagList() override;
protected:
Status GetAttrs() override {
infer_strategy_mode_ = INDEPENDENT_MODE;
return SUCCESS;
}
Status GetAttrs() override { return SUCCESS; }
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferDevMatrixShape() override;

View File

@ -38,10 +38,7 @@ class InplaceOpBase : public OperatorInfo {
void ReComputeBatchSplitFlagList() override;
protected:
Status GetAttrs() override {
infer_strategy_mode_ = SAME_MODE;
return SUCCESS;
}
Status GetAttrs() override { return SUCCESS; }
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;

View File

@ -36,10 +36,7 @@ class IOUInfo : public OperatorInfo {
Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); }
protected:
Status GetAttrs() override {
infer_strategy_mode_ = INDEPENDENT_MODE;
return SUCCESS;
}
Status GetAttrs() override { return SUCCESS; }
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;

View File

@ -25,7 +25,6 @@ namespace mindspore {
namespace parallel {
Status KLDivLossInfo::GetAttrs() {
reduction_ = GetStringAttr(REDUCTION);
infer_strategy_mode_ = SAME_MODE;
return SUCCESS;
}

View File

@ -59,7 +59,6 @@ Status LayerNormInfo::GetAttrs() {
axis = axis + dim;
}
begin_norm_axis_ = LongToSize(axis);
infer_strategy_mode_ = INDIVIDUAL_MODE;
return SUCCESS;
}
@ -249,67 +248,6 @@ Status LayerNormInfo::InitShapes() {
return SUCCESS;
}
// in_strategy: ((A, B, 1, 1), (), ()), begin_norm_axis_: 2, Shapes: ((a, b, c, d), (b, c, d), (b, c, d))
// return: ((A, B, 1, 1), (B, 1, 1), (B, 1, 1))
// in_strategy: ((), (B, 1, 1), (B, 1, 1)), begin_norm_axis_: 2, Shapes: ((a, b, c, d), (b, c, d), (b, c, d))
// return: ((1, B, 1, 1), (B, 1, 1), (B, 1, 1))
Shapes LayerNormInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
if (in_strategy.size() != 3) {
MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy must be 3, but got " << in_strategy.size();
}
if (in_strategy[1] != in_strategy[2]) {
MS_LOG(EXCEPTION) << name_ << ": The in_strategy[1] must be equal to the in_strategy[2], but the in_strategy[1] is"
<< in_strategy[1] << ", the in_strategy[2] is " << in_strategy[2];
}
if (gamma_shape_ != beta_shape_) {
MS_LOG(EXCEPTION) << name_ << ": The gamma's shape must be equal to the beta's shape, but the gamma's shape is"
<< gamma_shape_ << ", the beta's shape is " << beta_shape_;
}
if (input_shape_.size() < gamma_shape_.size()) {
MS_LOG(EXCEPTION)
<< name_ << ": The input's shape size cannot smaller than gamma's shape size, but the input's shape size is"
<< input_shape_.size() << ", the gamma's shape size is " << gamma_shape_.size();
}
if (!in_strategy[0].empty()) {
if (in_strategy[0].size() != input_shape_.size()) {
MS_LOG(EXCEPTION)
<< name_
<< ": The size of in_strategy[0] must be equal to the size of inputs_shape[0], but the in_strategy[0] is"
<< in_strategy[0] << ", the inputs_shape[0] is " << input_shape_;
}
if (input_shape_.size() == gamma_shape_.size()) {
return Shapes({in_strategy[0], in_strategy[0], in_strategy[0]});
} else {
size_t diff_len = input_shape_.size() - gamma_shape_.size();
Shape gamma_strategy(in_strategy[0].begin() + diff_len, in_strategy[0].end());
return Shapes({in_strategy[0], gamma_strategy, gamma_strategy});
}
}
if (!in_strategy[1].empty()) {
if (in_strategy[1].size() != gamma_shape_.size()) {
MS_LOG(EXCEPTION)
<< name_
<< ": The size of in_strategy[1] must be equal to the size of inputs_shape[1], but the in_strategy[1] is"
<< in_strategy[1] << ", the inputs_shape[1] is " << gamma_shape_;
}
if (input_shape_.size() == gamma_shape_.size()) {
return Shapes({in_strategy[1], in_strategy[1], in_strategy[1]});
} else {
size_t diff_len = input_shape_.size() - gamma_shape_.size();
Shape tmp_strategy = in_strategy[1];
(void)tmp_strategy.insert(tmp_strategy.begin(), diff_len, 1);
return Shapes({tmp_strategy, in_strategy[1], in_strategy[1]});
}
}
MS_LOG(EXCEPTION) << name_ << ": The in_strategy[0], in_strategy[1] and in_strategy[2] are empty";
}
REGISTER(LayerNormInfo);
} // namespace parallel
} // namespace mindspore

View File

@ -57,7 +57,6 @@ class LayerNormInfo : public OperatorInfo {
Status CreateInputTensorMap(size_t input_index);
Status GenerateGammaAndBetaStrategies(const std::vector<StrategyPtr> &sp_vector);
Status InitShapes();
Shapes InferStrategyIndividualMode(const Shapes &in_strategy) override;
private:
size_t begin_norm_axis_;

View File

@ -65,7 +65,6 @@ Status SoftmaxCrossEntropyWithLogitsInfo::GetAttrs() {
MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong.";
return FAILED;
}
infer_strategy_mode_ = SAME_MODE;
return SUCCESS;
}

View File

@ -125,7 +125,6 @@ Status MatMulBase::GetAttrs() {
<< mat_a_dimension_ << ", the dim of mat_b is " << mat_b_dimension_;
}
infer_strategy_mode_ = INDIVIDUAL_MODE;
return SUCCESS;
}
@ -514,54 +513,6 @@ std::shared_ptr<Strategies> BatchMatMulInfo::GenerateBatchStrategies() {
Status MatMulBase::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
Shapes MatMulBase::InferStrategyIndividualMode(const Shapes &in_strategy) {
// if the transpose_b is false:
// in_strategy: ((A, B, C, D), ()), inputs shape: ((a, b, c, d), (a, b, d, e)), return: ((A, B, C, D), (A, B, D, 1))
// in_strategy: ((), (A, B, D, E)), inputs shape: ((a, b, c, d), (a, b, d, e)), return: ((A, B, 1, D), (A, B, D, E))
// if the transpose_b is true:
// in_strategy: ((A, B, C, D), ()), inputs shape: ((a, b, c, d), (a, b, e, d)), return: ((A, B, C, D), (A, B, 1, D))
// in_strategy: ((), (A, B, E, D)), inputs shape: ((a, b, c, d), (a, b, e, d)), return: ((A, B, 1, D), (A, B, E, D))
if (in_strategy.size() != 2) {
MS_LOG(EXCEPTION) << name_ << ": The size of in strategy must be 2, but got " << in_strategy.size();
}
if (in_strategy[0].empty() && in_strategy[1].empty()) {
MS_LOG(EXCEPTION) << name_ << ": The in strategy is empty";
}
if (!in_strategy[0].empty() && !in_strategy[1].empty()) {
return in_strategy;
}
if (!in_strategy[0].empty() && in_strategy[0].size() != inputs_shape_[0].size()) {
MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy[0] is " << in_strategy[0].size()
<< ", but the size of the inputs_shape_[0] is " << inputs_shape_[0].size();
}
if (!in_strategy[1].empty() && in_strategy[1].size() != inputs_shape_[1].size()) {
MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy[1] is " << in_strategy[1].size()
<< ", but the size of the inputs_shape_[1] is " << inputs_shape_[1].size();
}
Shapes ret_strategy = InferStrategyBroadcastMode(in_strategy);
if (transpose_b_) {
if (in_strategy[0].empty()) {
ret_strategy[0][ret_strategy[0].size() - 2] = 1;
} else {
ret_strategy[1][ret_strategy[1].size() - 2] = 1;
}
} else {
if (in_strategy[0].empty()) {
ret_strategy[0][ret_strategy[0].size() - 2] = 1;
ret_strategy[0][ret_strategy[0].size() - 1] = ret_strategy[1][ret_strategy[1].size() - 2];
} else {
ret_strategy[1][ret_strategy[1].size() - 2] = ret_strategy[0][ret_strategy[0].size() - 1];
ret_strategy[1][ret_strategy[1].size() - 1] = 1;
}
}
return ret_strategy;
}
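The transpose handling above reduces to a small fix-up on the broadcast-completed strategies. Below is a minimal standalone sketch of that fix-up, compilable outside MindSpore; the Shape/Shapes aliases, the function name FixMatMulStrategy, and the sample numbers are illustrative assumptions, not MindSpore API.

#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;
using Shapes = std::vector<Shape>;

// After broadcast completion has copied the known side's strategy to the empty
// side, reconcile the contracted dimension of the filled-in side.
Shapes FixMatMulStrategy(Shapes ret, bool first_was_empty, bool transpose_b) {
  Shape &a = ret[0];
  Shape &b = ret[1];
  if (transpose_b) {
    // b is (..., e, d): only d is shared with a, so the copied second-to-last
    // dim of the filled-in side cannot be derived and is reset to 1.
    if (first_was_empty) {
      a[a.size() - 2] = 1;
    } else {
      b[b.size() - 2] = 1;
    }
  } else {
    // b is (..., d, e): a's last dim and b's second-to-last dim share d.
    if (first_was_empty) {
      a[a.size() - 2] = 1;
      a[a.size() - 1] = b[b.size() - 2];
    } else {
      b[b.size() - 2] = a[a.size() - 1];
      b[b.size() - 1] = 1;
    }
  }
  return ret;
}

int main() {
  // transpose_b = false, strategy given only for the first input: the broadcast
  // step yields ((2,1,4,8), (2,1,4,8)); the fix-up makes it ((2,1,4,8), (2,1,8,1)).
  Shapes full = FixMatMulStrategy({{2, 1, 4, 8}, {2, 1, 4, 8}}, false, false);
  for (const auto &s : full) {
    for (int64_t d : s) std::cout << d << ' ';
    std::cout << '\n';
  }
  return 0;
}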
// PCL matmul
ReplaceGraphPtr MatMul::replace_graph(const CNodePtr &cnode) {
if (!candidate_flag_) {

View File

@ -48,7 +48,6 @@ class MatMulBase : public OperatorInfo {
Status InferTensorMap() override;
Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout);
Status GetAttrs() override;
Shapes InferStrategyIndividualMode(const Shapes &in_strategy) override;
bool candidate_flag_ = false;
bool transpose_a_ = false;

View File

@ -47,7 +47,6 @@ Status OneHotInfo::GetAttrs() {
MS_LOG(ERROR) << name_ << ": Axis " << axis_ << " is out of range[-1, 1].";
return FAILED;
}
infer_strategy_mode_ = INDIVIDUAL_MODE;
return SUCCESS;
}
@ -248,24 +247,6 @@ std::shared_ptr<Strategies> OneHotInfo::GenerateBatchStrategies() {
return std::make_shared<Strategies>(strategy_v);
}
Shapes OneHotInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
if (in_strategy.size() != 3) {
MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy must be 3, but got " << in_strategy.size();
}
if (in_strategy[0].empty()) {
Shape strategy = {stage_device_size_, 1};
Shape empty_strategy;
return Shapes({strategy, empty_strategy, empty_strategy});
}
if (in_strategy[0].size() != 1) {
MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy[0] must be 1, but got " << in_strategy[0].size();
}
return Shapes({{in_strategy[0][0], 1}, {}, {}});
}
REGISTER(OneHotInfo);
} // namespace parallel
} // namespace mindspore

View File

@ -49,7 +49,6 @@ class OneHotInfo : public OperatorInfo {
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status ExtractInputInfo();
Shapes InferStrategyIndividualMode(const Shapes &in_strategy) override;
private:
Status ComputeReplaceGraph(const CNodePtr &cnode);

View File

@ -2150,137 +2150,6 @@ float OperatorInfo::GetFloatAttr(const std::string &attr_name) {
return attr_iter->second->cast<FP32ImmPtr>()->value();
}
// in_strategy: ((A, B, C, D), ()), return: ((A, B, C, D), (A, B, C, D))
// in_strategy: ((), (A, B, C, D)), return: ((A, B, C, D), (A, B, C, D))
Shapes OperatorInfo::InferStrategySameMode(const Shapes &in_strategy) const {
Shape value;
for (auto &ele : in_strategy) {
if (!ele.empty()) {
value = ele;
break;
}
}
return Shapes(in_strategy.size(), value);
}
// in_strategy: ((A, B, C, D), ()), inputs shape: ((a, b, c, d), (e, f, g)), return: ((A, B, C, D), (1, 1, 1))
// in_strategy: ((), (E, F, G)), inputs shape: ((a, b, c, d), (e, f, g)), return: ((1, 1, 1, 1), (E, F, G))
Shapes OperatorInfo::InferStrategyIndependentMode(const Shapes &in_strategy) {
if (in_strategy.size() != inputs_shape_.size()) {
MS_LOG(EXCEPTION)
<< name_ << ": The size of strategies must be equal to the size of inputs shape, but the size of strategies is "
<< in_strategy.size() << ", and the size of inputs shape is " << inputs_shape_.size();
}
Shapes ret;
for (size_t i = 0; i < in_strategy.size(); ++i) {
if (in_strategy[i].empty()) {
(void)ret.emplace_back(Shape(inputs_shape_[i].size(), 1));
continue;
}
(void)ret.emplace_back(in_strategy[i]);
}
return ret;
}
// in_strategy: ((A, B, C, D), ()), inputs shape: ((a, b, c, d), (a, b, c, d)), return: ((A, B, C, D), (A, B, C, D))
// in_strategy: ((A, B, C, D), ()), inputs shape: ((a, b, c, d), (b, c, d)), return: ((A, B, C, D), (B, C, D))
// in_strategy: ((), (B, C, D)), inputs shape: ((a, b, c, d), (b, c, d)), return: ((1, B, C, D), (B, C, D))
// in_strategy: ((A, B, C, D), ()), inputs shape: ((a, b, c, d), (1, c, d)), return: ((A, B, C, D), (1, C, D))
Shapes OperatorInfo::InferStrategyBroadcastMode(const Shapes &in_strategy) {
Shapes ret = InferStrategySameMode(in_strategy);
if (ret.size() != inputs_shape_.size()) {
MS_LOG(EXCEPTION)
<< name_ << ": The size of strategies must be equal to the size of inputs shape, but the size of strategies is "
<< ret.size() << ", and the size of inputs shape is " << inputs_shape_.size();
}
// handle the broadcast
// alignment length
for (size_t i = 0; i < ret.size(); ++i) {
size_t strategy_size = ret[i].size();
size_t shape_size = inputs_shape_[i].size();
size_t diff_len = strategy_size > shape_size ? strategy_size - shape_size : shape_size - strategy_size;
if (strategy_size > shape_size) {
// strategy is (A, B, C, D), and shape is (c, d) -> updated strategy is (C, D)
(void)ret[i].erase(ret[i].begin(), ret[i].begin() + static_cast<different_type>(diff_len));
} else if (strategy_size < shape_size) {
// strategy is (C, D), and shape is (a, b, c, d) -> updated strategy is (1, 1, C, D)
(void)ret[i].insert(ret[i].begin(), diff_len, 1);
}
}
// handle the 1 shape value
for (size_t i = 0; i < ret.size(); ++i) {
// strategy is (A, B, C, D), and shape is (1, b, c, d) -> updated strategy is (1, B, C, D)
for (size_t j = 0; j < ret[i].size(); ++j) {
if (inputs_shape_[i][j] == 1) {
ret[i][j] = 1;
}
}
}
return ret;
}
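For reference, here is a minimal standalone sketch of the three non-individual inference modes above, compilable outside MindSpore; the Shape/Shapes aliases and the hard-coded inputs are illustrative assumptions, and the broadcast demo mirrors one of the unit tests removed later in this diff.

#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;
using Shapes = std::vector<Shape>;

// SAME_MODE: copy the first non-empty strategy to every input.
Shapes InferSame(const Shapes &in) {
  Shape value;
  for (const auto &ele : in) {
    if (!ele.empty()) { value = ele; break; }
  }
  return Shapes(in.size(), value);
}

// INDEPENDENT_MODE: fill each empty strategy with all-ones of the input's rank.
Shapes InferIndependent(const Shapes &in, const Shapes &inputs_shape) {
  Shapes ret;
  for (size_t i = 0; i < in.size(); ++i) {
    ret.push_back(in[i].empty() ? Shape(inputs_shape[i].size(), 1) : in[i]);
  }
  return ret;
}

// BROADCAST_MODE: copy like SAME_MODE, then align ranks and reset every dim
// whose shape value is 1 back to strategy 1.
Shapes InferBroadcast(const Shapes &in, const Shapes &inputs_shape) {
  Shapes ret = InferSame(in);
  for (size_t i = 0; i < ret.size(); ++i) {
    size_t ss = ret[i].size();
    size_t hs = inputs_shape[i].size();
    if (ss > hs) {
      // strategy (A, B, C, D) vs shape (c, d) -> (C, D)
      ret[i].erase(ret[i].begin(), ret[i].begin() + (ss - hs));
    } else if (ss < hs) {
      // strategy (C, D) vs shape (a, b, c, d) -> (1, 1, C, D)
      ret[i].insert(ret[i].begin(), hs - ss, 1);
    }
    for (size_t j = 0; j < ret[i].size(); ++j) {
      if (inputs_shape[i][j] == 1) ret[i][j] = 1;
    }
  }
  return ret;
}

int main() {
  // Strategy ((2,4,4), ()) with shapes ((16,32,48), (32,48)) completes
  // to ((2,4,4), (4,4)), as in the deleted TestInferStrategyBroadcastMode.
  Shapes full = InferBroadcast({{2, 4, 4}, {}}, {{16, 32, 48}, {32, 48}});
  for (const auto &s : full) {
    for (int64_t d : s) std::cout << d << ' ';
    std::cout << '\n';
  }
  return 0;
}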
Shapes OperatorInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
MS_LOG(EXCEPTION) << name_ << ": The in strategy is " << in_strategy << ", need to override this function ";
}
Shapes OperatorInfo::GenerateFullStrategyBase(const Shapes &in_strategy) {
// none of the strategies in in_strategy is empty
auto it = std::find_if(in_strategy.begin(), in_strategy.end(), [](const Shape &ele) { return ele.empty(); });
if (it == in_strategy.end()) {
MS_LOG(INFO) << name_ << ": There is no empty in the input strategy, return to itself: " << in_strategy;
return in_strategy;
}
// all strategies in in_strategy are empty, generate the data-parallel strategy
auto item = std::find_if(in_strategy.begin(), in_strategy.end(), [](const Shape &ele) { return !ele.empty(); });
if (item == in_strategy.end()) {
std::shared_ptr<Strategies> dp_strategy_ptr = GenerateBatchStrategies();
MS_EXCEPTION_IF_NULL(dp_strategy_ptr);
MS_LOG(INFO) << name_
<< ": The in strategy are all empty, generate the data parallel strategy: " << *dp_strategy_ptr;
return *dp_strategy_ptr;
}
// generate the full strategy from the non-empty strategies
Shapes ret;
switch (infer_strategy_mode_) {
case SAME_MODE:
ret = InferStrategySameMode(in_strategy);
break;
case BROADCAST_MODE:
ret = InferStrategyBroadcastMode(in_strategy);
break;
case INDEPENDENT_MODE:
ret = InferStrategyIndependentMode(in_strategy);
break;
case INDIVIDUAL_MODE:
ret = InferStrategyIndividualMode(in_strategy);
break;
case INVALID_MODE:
default:
MS_LOG(EXCEPTION) << name_ << ": The invalid mode for infer strategy";
}
MS_LOG(INFO) << name_ << ": The origin strategy is " << in_strategy << ", and the return strategy is " << ret;
return ret;
}
Shapes OperatorInfo::GenerateFullStrategy(const Shapes &in_strategy) {
if (InferAttrs() != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed";
}
Shapes ret = GenerateFullStrategyBase(in_strategy);
StrategyPtr strategy_ptr = NewStrategy(0, ret);
if (CheckStrategy(strategy_ptr) != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Invalid strategy";
}
return ret;
}
std::vector<ValuePtr> GetValueSequence(const ValuePtr &sequence) {
MS_EXCEPTION_IF_NULL(sequence);
std::vector<ValuePtr> ret;

View File

@ -103,7 +103,6 @@ class OperatorInfo {
const OperatorCostPtr &operator_cost() const { return operator_cost_; }
void set_cost(const OperatorCostPtr &cost) { operator_cost_ = cost; }
virtual Status SetCostUnderStrategy(const StrategyPtr &strategy) = 0;
Shapes GenerateFullStrategy(const Shapes &in_strategy);
virtual std::shared_ptr<Strategies> GenerateBatchStrategies();
virtual void ReComputeBatchSplitFlagList();
@ -226,7 +225,6 @@ class OperatorInfo {
// needed by rec_parser
std::string type_;
bool is_last_node_ = false;
InferStrategyMode infer_strategy_mode_ = INVALID_MODE;
virtual Status CheckStrategy(const StrategyPtr &strategy) = 0;
virtual Status InferTensorMap() = 0;
virtual Status InferForwardCommunication() = 0;
@ -236,11 +234,6 @@ class OperatorInfo {
virtual Status InferTensorInfo();
virtual void InferReplaceOps() {}
virtual Status CheckOutputStrategy(const StrategyPtr &out_strategy);
virtual Shapes InferStrategyIndividualMode(const Shapes &in_strategy);
Shapes GenerateFullStrategyBase(const Shapes &in_strategy);
Shapes InferStrategySameMode(const Shapes &in_strategy) const;
Shapes InferStrategyBroadcastMode(const Shapes &in_strategy);
Shapes InferStrategyIndependentMode(const Shapes &in_strategy);
Status CheckStrategyByVector(const Shapes &strategy, const Shapes &inputs_shape);
Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape);
void SetRepeatedCalcDevMatrix();

View File

@ -55,7 +55,6 @@ Status StackInfo::GetAttrs() {
axis = axis + dim;
}
axis_ = LongToSize(axis);
infer_strategy_mode_ = SAME_MODE;
return SUCCESS;
}

View File

@ -89,7 +89,6 @@ Status PReLUInfo::GetAttrs() {
<< outputs_shape_.size() << " is wrong.";
return FAILED;
}
infer_strategy_mode_ = INDIVIDUAL_MODE;
return SUCCESS;
}
@ -109,34 +108,6 @@ std::vector<StrategyPtr> PReLUInfo::GenerateOpStrategies(int64_t stage_id) {
Status PReLUInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
// in_strategy: ((A, B, C), ()), Shapes: ((a, b, c), (b)), return: ((A, B, C), (B))
// in_strategy: ((A, B, C), ()), Shapes: ((a, b, c), (1)), return: ((A, B, C), (1))
// in_strategy: ((), (B)), Shapes: ((a, b, c), (b)), return: ((1, B, 1), (B))
Shapes PReLUInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
if (in_strategy.size() != 2) {
MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy must be 3, but got " << in_strategy.size();
}
if (!in_strategy[0].empty()) {
if (in_strategy[0].size() < 2) {
MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy[0] must be larger than 1, but got "
<< in_strategy[0].size();
}
if (inputs_shape_[1][0] > 1) {
return Shapes({in_strategy[0], {in_strategy[0][1]}});
} else {
return Shapes({in_strategy[0], {1}});
}
}
if (!in_strategy[1].empty()) {
Shape tmp(inputs_shape_[0].size(), 1);
tmp[1] = in_strategy[1][0];
return Shapes({tmp, in_strategy[1]});
}
MS_LOG(EXCEPTION) << name_ << ": The in_strategy[0] and in_strategy[1] are empty";
}
REGISTER(PReLUInfo);
} // namespace parallel
} // namespace mindspore

View File

@ -47,7 +47,6 @@ class PReLUInfo : public OperatorInfo {
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status GetAttrs() override;
Shapes InferStrategyIndividualMode(const Shapes &in_strategy) override;
private:
Dimensions input_strategy_;

View File

@ -29,7 +29,6 @@ Status RandomChoiceWithMaskInfo::GetAttrs() {
if (attrs_.find(SEED2) != attrs_.end()) {
seed2_ = GetValue<int64_t>(attrs_[SEED2]);
}
infer_strategy_mode_ = INDEPENDENT_MODE;
return SUCCESS;
}

View File

@ -34,7 +34,6 @@ Status ROIAlignInfo::GetAttrs() {
}
(void)roi_align_attrs.emplace_back(std::make_pair(attr_key, attrs_[attr_key]));
}
infer_strategy_mode_ = INDEPENDENT_MODE;
return SUCCESS;
}

View File

@ -169,27 +169,6 @@ void ScatterNdOpsInfo::ReComputeBatchSplitFlagList() {
}
}
// in_strategy: ((A, B, C, D), (), ()), Shapes: ((a, b, c, d), (e, f), (e, f, c, d))
// return: ((A, B, C, D), (1, 1), (1, 1, C, D))
// in_strategy: ((), ...), i.e. in_strategy[0] is empty: throw exception
Shapes ScatterNdOpsInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
if (in_strategy.size() != SIZE_THREE) {
MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy must be 3, but got " << in_strategy.size();
}
if (in_strategy[INDEX_ZERO].empty()) {
MS_LOG(EXCEPTION) << name_ << ": The in_strategy[0] must be not empty, you should set the first input strategy.";
}
Shape indices_strategy, updates_strategy;
indices_strategy = Shape(inputs_shape_[INDEX_ONE].size(), 1);
updates_strategy = Shape(inputs_shape_[INDEX_TWO].size(), 1);
for (size_t i = 0; i < in_strategy[0].size() - gather_dims_size_; ++i) {
updates_strategy[indices_strategy.size() - 1 + i] = in_strategy[INDEX_ZERO][i + gather_dims_size_];
}
return Shapes({in_strategy[INDEX_ZERO], indices_strategy, updates_strategy});
}
ReplaceGraphPtr ScatterNdOpsInfo::replace_graph(const CNodePtr &cnode) {
if (ComputeReplaceGraph(cnode) != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << " replace graph failed";

View File

@ -58,7 +58,6 @@ class ScatterNdOpsInfo : public OperatorInfo {
int64_t slice_size_ = 0;
size_t gather_dims_size_ = 1;
GenerateGraph gen_g_ = GenerateGraph(attrs_);
Shapes InferStrategyIndividualMode(const Shapes &in_strategy) override;
};
using ScatterNdOpsInfoPtr = std::shared_ptr<ScatterNdOpsInfo>;

View File

@ -176,46 +176,6 @@ std::vector<StrategyPtr> ScatterOpsInfo::GenerateOpStrategies(int64_t stage_id)
return sp_vector;
}
// in_strategy: ((A, B, C, D), (), ()), Shapes: ((a, b, c, d), (e, f), (e, f, b, c, d))
// return: ((A, B, C, D), (1, 1), (1, 1, B, C, D))
// in_strategy: ((), (1, 1), (E, F, B, C, D)), Shapes: ((a, b, c, d), (e, f), (e, f, b, c, d))
// return: ((A, B, C, D), (1, 1), (E, F, B, C, D))
// in_strategy: ((), (), (E, F, B, C, D)), Shapes: ((a, b, c, d), (e, f), (e, f, b, c, d))
// return: throw exception
Shapes ScatterOpsInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
if (in_strategy.size() != 3) {
MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy must be 3, but got " << in_strategy.size();
}
if (in_strategy[1].empty() != in_strategy[2].empty()) {
MS_LOG(EXCEPTION)
<< name_
<< ": The in_strategy[1] and in_strategy[2] must be all empty or all non empty, but the in_strategy[1] is "
<< in_strategy[1] << ", the in_strategy[2] is " << in_strategy[2];
}
Shape x_strategy, indices_strategy, updates_strategy;
indices_strategy = Shape(inputs_shape_[1].size(), 1);
if (!in_strategy[0].empty()) {
updates_strategy = in_strategy[0];
(void)updates_strategy.erase(updates_strategy.begin());
(void)updates_strategy.insert(updates_strategy.begin(), inputs_shape_[1].size(), 1);
return Shapes({in_strategy[0], indices_strategy, updates_strategy});
}
if (!in_strategy[2].empty()) {
if (std::accumulate(in_strategy[1].begin(), in_strategy[1].end(), 1, std::multiplies<int64_t>()) != 1) {
MS_LOG(EXCEPTION) << name_ << ": The in_strategy[1] must be fill with 1, but got " << in_strategy[1];
}
x_strategy = in_strategy[2];
(void)x_strategy.erase(x_strategy.begin(),
x_strategy.begin() + static_cast<different_type>(inputs_shape_[1].size() - 1));
return Shapes({x_strategy, indices_strategy, in_strategy[2]});
}
MS_LOG(EXCEPTION) << name_ << ": The in_strategy[0], in_strategy[1] and in_strategy[2] are empty";
}
REGISTER(ScatterUpdateInfo);
REGISTER(ScatterMaxInfo);
REGISTER(ScatterMinInfo);

View File

@ -41,16 +41,12 @@ class ScatterOpsInfo : public OperatorInfo {
void ReComputeBatchSplitFlagList() override;
protected:
Status GetAttrs() override {
infer_strategy_mode_ = INDIVIDUAL_MODE;
return SUCCESS;
}
Status GetAttrs() override { return SUCCESS; }
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferMirrorOps() override { return SUCCESS; } // scatter_update is only used in eval/predict
Status InferForwardCommunication() override { return SUCCESS; }
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Shapes InferStrategyIndividualMode(const Shapes &in_strategy) override;
};
class ScatterUpdateInfo : public ScatterOpsInfo {

View File

@ -41,10 +41,7 @@ class SelectInfo : public OperatorInfo {
void ReComputeBatchSplitFlagList() override;
protected:
Status GetAttrs() override {
infer_strategy_mode_ = SAME_MODE;
return SUCCESS;
}
Status GetAttrs() override { return SUCCESS; }
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferDevMatrixShape() override;

View File

@ -55,7 +55,6 @@ Status UniformRealInfo::GetAttrs() {
MS_LOG(ERROR) << name_ << ": Seed2 must be greater or equal to zero, bug got " << seed2_;
return FAILED;
}
infer_strategy_mode_ = INDEPENDENT_MODE;
return SUCCESS;
}

View File

@ -54,7 +54,6 @@ Status UnsortedSegmentOpInfo::GetAttrs() {
MS_LOG(ERROR) << name_ << ": the number of segments should be non negative value.";
return FAILED;
}
infer_strategy_mode_ = INDIVIDUAL_MODE;
return SUCCESS;
}
@ -311,32 +310,6 @@ Status UnsortedSegmentMaxInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
return SUCCESS;
}
// in_strategy: ((A, B, C), ()), Shapes: ((a, b, c), (a, b)), return: ((A, B, C), (A, B))
// in_strategy: ((), (A, B)), Shapes: ((a, b, c), (a, b)), return: ((A, B, 1), (A, B))
Shapes UnsortedSegmentOpInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
if (in_strategy.size() != 2) {
MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy must be 3, but got " << in_strategy.size();
}
if (!in_strategy[0].empty()) {
return Shapes({in_strategy[0], Shape(in_strategy[0].begin(), in_strategy[0].begin() + inputs_shape_[1].size())});
}
if (!in_strategy[1].empty()) {
Shape tmp_strategy = in_strategy[1];
if (inputs_shape_[0].size() < inputs_shape_[1].size()) {
MS_LOG(EXCEPTION) << name_
<< ": The size of inputs_shape[0] can not smaller than the size of inputs_shape[1], but the "
"size of inputs_shape[0] is "
<< inputs_shape_[0].size() << ", the size of inputs_shape[1] is " << inputs_shape_[1].size();
}
size_t diff_len = inputs_shape_[0].size() - inputs_shape_[1].size();
(void)tmp_strategy.insert(tmp_strategy.end(), diff_len, 1);
return Shapes({tmp_strategy, in_strategy[1]});
}
MS_LOG(EXCEPTION) << name_ << ": The in_strategy[0] and in_strategy[1] are empty";
}
REGISTER(UnsortedSegmentSumInfo);
REGISTER(UnsortedSegmentProdInfo);
REGISTER(UnsortedSegmentMinInfo);

View File

@ -57,7 +57,6 @@ class UnsortedSegmentOpInfo : public OperatorInfo {
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status GetAttrs() override;
Shapes InferStrategyIndividualMode(const Shapes &in_strategy) override;
};
class UnsortedSegmentSumInfo : public UnsortedSegmentOpInfo {

View File

@ -31,14 +31,22 @@
namespace mindspore {
namespace parallel {
namespace {
using ExpectFunc = std::function<bool(const CNodePtr &)>;
}
static void GenerateDefaultStrategy(const ValueNodePtr &axes, const std::vector<AnfNodePtr> &nodes,
std::vector<std::vector<int64_t>> *default_strategy) {
const size_t device_num, std::vector<std::vector<int64_t>> *default_strategy) {
auto strategies = axes->value()->cast<ValueTuplePtr>()->value();
size_t i = 0;
for (auto &strategy : strategies) {
auto node = nodes[i];
if (strategy->isa<None>()) {
(void)default_strategy->emplace_back(Shape());
auto node_size = common::AnfAlgo::GetOutputInferShape(node, 0).size();
std::vector<int64_t> current_d_strategy(node_size, 1);
if (!current_d_strategy.empty()) {
current_d_strategy[0] = SizeToLong(device_num);
}
(void)default_strategy->emplace_back(std::move(current_d_strategy));
} else {
(void)default_strategy->emplace_back(GetValue<Shape>(strategy));
}
@ -46,26 +54,13 @@ static void GenerateDefaultStrategy(const ValueNodePtr &axes, const std::vector<
}
}
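The new branch above replaces the old empty default (Shape()) with a data-parallel default for inputs whose strategy is None. A minimal sketch of that rule follows; the name DefaultDataParallel and its inputs are illustrative assumptions, not MindSpore API.

#include <cstdint>
#include <iostream>
#include <vector>

// A missing strategy becomes data-parallel: all-ones of the input's rank,
// with the leading (batch) dim split across device_num.
std::vector<int64_t> DefaultDataParallel(size_t rank, size_t device_num) {
  std::vector<int64_t> strategy(rank, 1);
  if (!strategy.empty()) {
    strategy[0] = static_cast<int64_t>(device_num);  // split the batch dim
  }
  return strategy;
}

int main() {
  // A rank-3 input on 8 devices defaults to (8, 1, 1).
  for (int64_t d : DefaultDataParallel(3, 8)) std::cout << d << ' ';
  std::cout << '\n';
  return 0;
}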
// Generate strategies like ((), (), ..., ())
Shapes GenerateEmptyStrategies(const CNodePtr &cnode) {
auto shape_list = ExtractShape(cnode);
if (shape_list.empty()) {
MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << " failed to extract shape.";
}
return Shapes(shape_list[0].size(), Shape());
}
static bool CheckOneDimensionalIntTuple(const ValuePtr &value_ptr) {
if (!value_ptr->isa<ValueTuple>()) {
return false;
}
auto elements = value_ptr->cast<ValueTuplePtr>()->value();
for (auto &element : elements) {
if (!element->isa<Int64Imm>()) {
return false;
}
}
return true;
return std::all_of(elements.begin(), elements.end(),
[](const ValuePtr &element) { return element->isa<Int64Imm>(); });
}
static bool CheckLayout(const ValueNodePtr &axes, bool *need_default_strategy, size_t *axes_size) {
@ -83,12 +78,6 @@ static bool CheckLayout(const ValueNodePtr &axes, bool *need_default_strategy, s
return true;
}
static Shapes GenerateFullStrategy(const Shapes &current_strategy, const CNodePtr &cnode) {
OperatorInfoPtr op_info = CreateOperatorInfo(cnode);
MS_EXCEPTION_IF_NULL(op_info);
return op_info->GenerateFullStrategy(current_strategy);
}
static void GetInputNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *input_nodes) {
auto parameters = func_graph->parameters();
for (auto &parameter : parameters) {
@ -99,7 +88,7 @@ static void GetInputNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr
}
}
static bool CheckDeviceNum(const std::vector<std::vector<int64_t>> &strategies, const int64_t &device_num) {
static bool CheckDeviceNum(const std::vector<std::vector<int64_t>> &strategies, const int64_t device_num) {
for (size_t i = 0; i < strategies.size(); ++i) {
auto strategy = strategies[i];
int64_t required_num = 1;
@ -119,34 +108,6 @@ static bool CheckDeviceNum(const std::vector<std::vector<int64_t>> &strategies,
return true;
}
// Generate strategy for cnode by input_strategy.
// For the i-th input:
// 1. If it is specified in input_strategy, the strategy in input_strategy is used;
// 2. Otherwise, its strategy is assigned as ()
static Shapes GenerateDefaultStrategiesForCNode(const CNodePtr &cnode, const Shapes &input_strategy) {
auto current_inputs = cnode->inputs();
Shapes elements;
for (size_t i = 1; i < current_inputs.size(); ++i) {
auto current_input = current_inputs[i];
if (current_input->isa<ValueNode>()) {
auto current_value = current_input->cast<ValueNodePtr>()->value();
if (!current_value->isa<mindspore::tensor::Tensor>()) {
continue;
}
}
if (IsPrimitiveCNode(current_input, prim::kPrimTupleGetItem)) {
auto tuple_getitem_cnode = current_input->cast<CNodePtr>();
auto tuple_index = tuple_getitem_cnode->input(2);
auto value_node = tuple_index->cast<ValueNodePtr>();
auto index = GetValue<int64_t>(value_node->value());
elements.push_back(input_strategy[index]);
} else {
(void)elements.emplace_back(Shape());
}
}
return elements;
}
static ValueTuplePtr ShapesToValueTuplePtr(const Shapes &shapes) {
std::vector<ValuePtr> value_list;
(void)std::transform(shapes.begin(), shapes.end(), std::back_inserter(value_list),
@ -162,23 +123,22 @@ static Shapes ValueTuplePtrToShapes(const ValueTuplePtr &value_tuple_ptr) {
return shapes;
}
AnfNodeIndexSet FindAnfNodeIndexSetToInsertStrategy(const FuncGraphPtr &func_graph, const AnfNodeIndexSet &node_users) {
AnfNodeIndexSet FindAnfNodeIndexSetToInsertStrategy(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const ExpectFunc &filter_func) {
FuncGraphManagerPtr manager = func_graph->manager();
AnfNodeIndexSet ret_set;
std::queue<std::pair<AnfNodePtr, int>> bfs_list;
auto node_users = manager->node_users()[node];
std::queue<std::pair<AnfNodePtr, int>> bfs_queue;
(void)std::for_each(node_users.begin(), node_users.end(),
[&bfs_list](const std::pair<AnfNodePtr, int> &user) { bfs_list.push(user); });
while (!bfs_list.empty()) {
auto user = bfs_list.front();
bfs_list.pop();
CNodePtr cnode = user.first->cast<CNodePtr>();
// If the cnode is not a splittable operator, apply strategy to the next cnode
if (!IsSplittableOperator(GetPrimName(cnode)) || IsPrimitiveCNode(cnode, prim::kPrimVirtualDataset) ||
IsPrimitiveCNode(cnode, prim::kPrimCast) || IsPrimitiveCNode(cnode, prim::kPrimReshape)) {
[&bfs_queue](const std::pair<AnfNodePtr, int> &user) { bfs_queue.push(user); });
while (!bfs_queue.empty()) {
auto user = bfs_queue.front();
bfs_queue.pop();
auto cnode = user.first->cast<CNodePtr>();
if (!filter_func(cnode)) {
auto tmp_users = manager->node_users()[cnode];
(void)std::for_each(tmp_users.begin(), tmp_users.end(),
[&bfs_list](const std::pair<AnfNodePtr, int> &user) { bfs_list.push(user); });
[&bfs_queue](const std::pair<AnfNodePtr, int> &user) { bfs_queue.push(user); });
continue;
}
ret_set.insert(user);
@ -186,6 +146,25 @@ AnfNodeIndexSet FindAnfNodeIndexSetToInsertStrategy(const FuncGraphPtr &func_gra
return ret_set;
}
bool IsSettingStrategyByInsertIdentity(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::string &param_name) {
FuncGraphManagerPtr manager = func_graph->manager();
auto node_users = manager->node_users()[cnode];
for (const auto &user : node_users) {
auto user_node = user.first;
if (IsPrimitiveCNode(user_node, prim::kPrimIdentity)) {
auto attrs = GetCNodePrimitive(user_node)->attrs();
if (StrategyFound(attrs)) {
auto origin_strategies = ValueTuplePtrToShapes(attrs[parallel::IN_STRATEGY]->cast<ValueTuplePtr>());
MS_LOG(WARNING) << "For " << param_name << ", its strategy has been set to " << origin_strategies.at(0)
<< ", the relevant settings in input_strategy will be ignored";
return true;
}
}
}
return false;
}
// Create a new primitive for cnode and set in_strategy on it.
void SetStrategyToCNode(const CNodePtr &cnode, const Shapes &strategies) {
auto strategy = ShapesToValueTuplePtr(strategies);
@ -193,7 +172,7 @@ void SetStrategyToCNode(const CNodePtr &cnode, const Shapes &strategies) {
MS_EXCEPTION_IF_NULL(prim);
PrimitivePtr new_prim;
if (prim->isa<PrimitivePy>()) {
PrimitivePyPtr prim_py = prim->cast<PrimitivePyPtr>();
auto prim_py = prim->cast<PrimitivePyPtr>();
MS_EXCEPTION_IF_NULL(prim_py);
new_prim = std::make_shared<PrimitivePy>(*prim_py);
} else {
@ -205,13 +184,12 @@ void SetStrategyToCNode(const CNodePtr &cnode, const Shapes &strategies) {
ValuePtr new_prim_value = MakeValue(new_prim);
ValueNodePtr new_prim_value_node = NewValueNode(new_prim_value);
AnfNodePtr new_prim_anf_node = new_prim_value_node->cast<AnfNodePtr>();
auto new_prim_anf_node = new_prim_value_node->cast<AnfNodePtr>();
MS_EXCEPTION_IF_NULL(new_prim_anf_node);
cnode->set_input(0, new_prim_anf_node);
}
static std::set<CNodePtr> SetInputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &in_strategy,
const int64_t &device_num) {
void SetInputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &in_strategy, const int64_t device_num) {
auto in_strategy_tuple = in_strategy->cast<ValueNodePtr>();
bool need_default_strategy = false;
size_t in_strategy_size = 0;
@ -227,14 +205,14 @@ static std::set<CNodePtr> SetInputLayout(const FuncGraphPtr &func_graph, const A
}
std::vector<std::vector<int64_t>> input_strategy;
if (need_default_strategy) {
GenerateDefaultStrategy(in_strategy_tuple, input_nodes, &input_strategy);
GenerateDefaultStrategy(in_strategy_tuple, input_nodes, device_num, &input_strategy);
} else {
input_strategy = GetValue<std::vector<std::vector<int64_t>>>(in_strategy_tuple->value());
}
if (!CheckDeviceNum(input_strategy, device_num)) {
MS_LOG(EXCEPTION) << "check device number failed";
}
std::set<CNodePtr> concerned_nodes;
FuncGraphManagerPtr manager = func_graph->manager();
auto parameters = func_graph->parameters();
for (size_t i = 0; i < parameters.size(); ++i) {
@ -250,105 +228,76 @@ static std::set<CNodePtr> SetInputLayout(const FuncGraphPtr &func_graph, const A
<< " is not equal to in_strategy dimension: " << input_strategy[i].size() << " at index " << i;
}
AnfNodeIndexSet param_sub_set = manager->node_users()[parameter];
auto to_insert_nodes_set = FindAnfNodeIndexSetToInsertStrategy(func_graph, param_sub_set);
auto to_insert_nodes_set = FindAnfNodeIndexSetToInsertStrategy(
func_graph, parameter, [](const CNodePtr &cnode) { return IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem); });
if (to_insert_nodes_set.empty()) {
MS_LOG(EXCEPTION) << "For input: \"" << parameter->fullname_with_scope()
<< "\", failed to find node to insert strategy.";
}
for (auto &node : to_insert_nodes_set) {
CNodePtr param_cnode = node.first->cast<CNodePtr>();
auto param_attrs = GetCNodePrimitive(param_cnode)->attrs();
if (StrategyFound(param_attrs)) {
auto origin_strategies = ValueTuplePtrToShapes(param_attrs[parallel::IN_STRATEGY]->cast<ValueTuplePtr>());
MS_LOG(WARNING) << "For " << param_cnode->fullname_with_scope() << ", its in_strategy has been set to "
<< origin_strategies << ", the relevant settings in input_strategy will be ignored";
auto tuple_get_item_cnode = node.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_get_item_cnode);
if (IsSettingStrategyByInsertIdentity(func_graph, tuple_get_item_cnode, parameter->fullname_with_scope())) {
continue;
}
(void)concerned_nodes.insert(param_cnode);
// Set the strategy by inserting an identity op,
// e.g. TupleGetItem(parameter, index) -> Identity{in_strategy=[input_strategy[index]]}
auto identity_cnode = func_graph->NewCNode({NewValueNode(prim::kPrimIdentity), tuple_get_item_cnode});
auto tuple_get_item_cnode_abstract = tuple_get_item_cnode->abstract();
MS_EXCEPTION_IF_NULL(tuple_get_item_cnode_abstract);
identity_cnode->set_abstract(tuple_get_item_cnode_abstract->Clone());
manager->Replace(tuple_get_item_cnode, identity_cnode);
// Get corresponding param_layout
auto tuple_index = tuple_get_item_cnode->input(2);
auto value_node = tuple_index->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto index = GetValue<int64_t>(value_node->value());
Shapes current_strategies = {input_strategy[index]};
SetStrategyToCNode(identity_cnode, current_strategies);
}
}
for (auto &cnode : concerned_nodes) {
Shapes ret_strategy = GenerateDefaultStrategiesForCNode(cnode, input_strategy);
SetStrategyToCNode(cnode, ret_strategy);
}
return concerned_nodes;
}
static std::set<CNodePtr> SetParameterLayout(const FuncGraphPtr &root, const FuncGraphPtr &func_graph,
const std::set<CNodePtr> &input_concerned_node) {
void SetParameterLayout(const FuncGraphPtr &root, const FuncGraphPtr &func_graph) {
FuncGraphManagerPtr manager = func_graph->manager();
auto root_parameters = root->parameters();
std::set<CNodePtr> concerned_cnode;
for (auto param : root_parameters) {
for (const auto &param : root_parameters) {
auto parameter = param->cast<ParameterPtr>();
auto param_info = parameter->param_info();
if (param_info == nullptr || param_info->param_strategy().empty()) {
// param_strategy is not set, skip it.
continue;
}
auto param_strategy = parameter->param_info()->param_strategy();
auto param_name = parameter->param_info()->name();
auto param_strategy = param_info->param_strategy();
auto param_name = param_info->name();
AnfNodeIndexSet users = manager->node_users()[parameter];
auto to_insert_nodes_set = FindAnfNodeIndexSetToInsertStrategy(func_graph, users);
for (auto user : to_insert_nodes_set) {
CNodePtr target_cnode = user.first->cast<CNodePtr>();
Shapes current_strategies;
if (input_concerned_node.find(target_cnode) == input_concerned_node.end()) {
// If target_cnode is not involve inputs, insert an identity between Load and target_cnode,
// and setting layout into identity.
// e.g Load(param) -> identity{in_strategy} -> target_cnode
auto pre_cnode = target_cnode->input(user.second)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(pre_cnode);
if (IsPrimitiveCNode(pre_cnode, prim::kPrimCast)) {
pre_cnode = pre_cnode->inputs().at(kIndex1)->cast<CNodePtr>();
}
if (!IsPrimitiveCNode(pre_cnode, prim::kPrimLoad)) {
MS_LOG(EXCEPTION) << "The operator type of the " << user.second << "-th input in "
<< target_cnode->fullname_with_scope() << " must be 'Load', but got "
<< GetCNodePrimitive(pre_cnode)->ToString();
}
auto identity_cnode = func_graph->NewCNode({NewValueNode(prim::kPrimIdentity), pre_cnode});
auto pre_cnode_abstract = pre_cnode->abstract();
MS_EXCEPTION_IF_NULL(pre_cnode_abstract);
identity_cnode->set_abstract(pre_cnode_abstract->Clone());
manager->Replace(pre_cnode, identity_cnode);
target_cnode = identity_cnode;
current_strategies = {param_strategy};
} else {
// Setting layout into target_cnode directly.
PrimitivePtr prim = GetCNodePrimitive(target_cnode);
MS_EXCEPTION_IF_NULL(prim);
auto attrs = prim->attrs();
if (StrategyFound(attrs)) {
current_strategies = ValueTuplePtrToShapes(attrs[parallel::IN_STRATEGY]->cast<ValueTuplePtr>());
} else {
current_strategies = GenerateEmptyStrategies(target_cnode);
}
current_strategies[user.second - 1] = param_strategy;
(void)concerned_cnode.insert(target_cnode);
auto to_insert_nodes_set = FindAnfNodeIndexSetToInsertStrategy(
func_graph, parameter, [](const CNodePtr &cnode) { return IsPrimitiveCNode(cnode, prim::kPrimLoad); });
for (const auto &user : to_insert_nodes_set) {
auto load_cnode = user.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(load_cnode);
if (IsSettingStrategyByInsertIdentity(func_graph, load_cnode, param_name)) {
continue;
}
SetStrategyToCNode(target_cnode, current_strategies);
MS_LOG(DEBUG) << "The layout of \"" << param_name << "\" has been set to the " << user.second << "th of "
<< target_cnode->fullname_with_scope() << "'s in_strategy. Current strategies is "
<< current_strategies;
}
}
return concerned_cnode;
}
void CompleteConcernedCNodeStrategies(std::set<CNodePtr> concerned_cnode) {
for (auto cnode : concerned_cnode) {
PrimitivePtr prim = GetCNodePrimitive(cnode);
MS_EXCEPTION_IF_NULL(prim);
auto attrs = prim->attrs();
Shapes current_strategies = ValueTuplePtrToShapes(attrs[parallel::IN_STRATEGY]->cast<ValueTuplePtr>());
Shapes full_strategies = GenerateFullStrategy(current_strategies, cnode);
attrs[parallel::IN_STRATEGY] = ShapesToValueTuplePtr(full_strategies);
(void)prim->SetAttrs(attrs);
MS_LOG(INFO) << cnode->fullname_with_scope() << ": Completion strategies success. " << current_strategies << " -> "
<< full_strategies << "(origin_strategies -> completion_strategies)";
// Set param_layout by inserting an identity op, e.g. Load(param) -> Identity{in_strategy=[param_layout]}
auto identity_cnode = func_graph->NewCNode({NewValueNode(prim::kPrimIdentity), load_cnode});
auto load_cnode_abstract = load_cnode->abstract();
MS_EXCEPTION_IF_NULL(load_cnode_abstract);
identity_cnode->set_abstract(load_cnode_abstract->Clone());
manager->Replace(load_cnode, identity_cnode);
Shapes current_strategies = {param_strategy};
SetStrategyToCNode(identity_cnode, current_strategies);
MS_LOG(DEBUG) << "The layout of \"" << param_name << "\" has been set to "
<< identity_cnode->fullname_with_scope() << ". Current strategies is " << current_strategies;
}
}
}
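Taken together, SetInputLayout and SetParameterLayout now attach strategies to freshly inserted Identity nodes instead of mutating the consuming operators. The standalone sketch below models that pattern with stand-in types; Node, SetLayoutByIdentity, and the sample graph are illustrative assumptions, not MindSpore API.

#include <cstdint>
#include <iostream>
#include <memory>
#include <optional>
#include <string>
#include <vector>

struct Node {
  std::string op;                                   // e.g. "Load", "Identity"
  std::vector<std::shared_ptr<Node>> inputs;
  std::optional<std::vector<int64_t>> in_strategy;  // attached layout, if any
};
using NodePtr = std::shared_ptr<Node>;

// Returns the node that now carries the strategy, or nullptr when a previously
// inserted Identity already holds one (mirrors IsSettingStrategyByInsertIdentity).
NodePtr SetLayoutByIdentity(const NodePtr &producer, std::vector<NodePtr> *users,
                            const std::vector<int64_t> &strategy) {
  for (const auto &user : *users) {
    if (user->op == "Identity" && user->in_strategy) {
      std::cout << "strategy already set, ignoring the new layout\n";
      return nullptr;
    }
  }
  auto identity = std::make_shared<Node>();
  identity->op = "Identity";
  identity->inputs = {producer};
  identity->in_strategy = strategy;
  // Re-point the old users at the new Identity (stands in for manager->Replace).
  for (auto &user : *users) {
    for (auto &in : user->inputs) {
      if (in == producer) in = identity;
    }
  }
  users->push_back(identity);
  return identity;
}

int main() {
  auto load = std::make_shared<Node>();
  load->op = "Load";
  auto matmul = std::make_shared<Node>();
  matmul->op = "MatMul";
  matmul->inputs = {load};
  std::vector<NodePtr> users = {matmul};
  // Load(param) -> Identity{in_strategy=(8, 1)} -> MatMul
  SetLayoutByIdentity(load, &users, {8, 1});
  std::cout << matmul->inputs[0]->op << '\n';  // prints "Identity"
  return 0;
}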
static bool SetStrategyForShard(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
const int64_t &device_num) {
const int64_t device_num) {
constexpr size_t kShardFnIndex = 1;
constexpr size_t kShardInStrategyIndex = 2;
for (auto &node : all_nodes) {
@ -366,13 +315,8 @@ static bool SetStrategyForShard(const FuncGraphPtr &root, const std::vector<AnfN
if (HasNestedMetaFg(func_graph)) {
return false;
}
std::set<CNodePtr> concerned_cnode;
auto input_concerned_cnode = SetInputLayout(func_graph, in_strategy, device_num);
auto parameter_concerned_cnode = SetParameterLayout(root, func_graph, input_concerned_cnode);
(void)std::set_union(input_concerned_cnode.begin(), input_concerned_cnode.end(),
parameter_concerned_cnode.begin(), parameter_concerned_cnode.end(),
std::inserter(concerned_cnode, concerned_cnode.end()));
CompleteConcernedCNodeStrategies(concerned_cnode);
SetInputLayout(func_graph, in_strategy, device_num);
SetParameterLayout(root, func_graph);
return true;
}
}

View File

@ -199,6 +199,10 @@ void ExecuteActionForMindRT(const ResourcePtr &resource) {
void ModifyOutputNode(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
const auto &used_forward_nodes = func_graph->used_forward_nodes();
std::vector<AnfNodePtr> used_forward_nodes_sorted_list(used_forward_nodes.begin(), used_forward_nodes.end());
std::sort(
used_forward_nodes_sorted_list.begin(), used_forward_nodes_sorted_list.end(),
[](const AnfNodePtr &a, const AnfNodePtr &b) { return a->fullname_with_scope() < b->fullname_with_scope(); });
// Get original output node and abstract
auto original_output_node = func_graph->output();
@ -209,7 +213,7 @@ void ModifyOutputNode(const FuncGraphPtr &func_graph) {
// Create a new make tuple node to hold all forward used nodes.
abstract::AbstractBasePtrList added_abs_list;
std::vector<AnfNodePtr> added_node_list{NewValueNode(prim::kPrimMakeTuple)};
std::for_each(used_forward_nodes.begin(), used_forward_nodes.end(),
std::for_each(used_forward_nodes_sorted_list.begin(), used_forward_nodes_sorted_list.end(),
[&added_abs_list, &added_node_list](const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
added_node_list.push_back(node);

View File

@ -374,7 +374,7 @@ def train_feed(num_classes, expect_out):
model.train(2, dataset, dataset_sink_mode=False, callbacks=parallel_callback)
loss_value = np.array(parallel_callback.loss_list)
print(loss_value)
assert np.allclose(loss_value, expect_out, 0.0001, 0.0001)
assert np.allclose(loss_value, expect_out, 0.001, 0.001)
def test_train_feed_ascend():
@ -391,7 +391,7 @@ def test_train_feed_ascend():
dataset_strategy="data_parallel")
np.random.seed(42)
set_seed(42)
train_feed(num_classes=65536, expect_out=[11.32993, 10.7269535])
train_feed(num_classes=65536, expect_out=[11.275127, 8.720392])
def test_train_feed_gpu():
@ -406,6 +406,6 @@ def test_train_feed_gpu():
context.set_auto_parallel_context(gradients_mean=True, parallel_mode=ParallelMode.AUTO_PARALLEL,
search_mode="sharding_propagation", device_num=8,
dataset_strategy="data_parallel")
np.random.seed(42)
set_seed(42)
train_feed(num_classes=65536, expect_out=[53.35976, 54.689503])
np.random.seed(1)
set_seed(1)
train_feed(num_classes=65536, expect_out=[53.538628, 54.83031])

View File

@ -1,112 +0,0 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include <list>
#include <vector>
#include "common/common_test.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/ops_info/arithmetic_info.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/step_parallel.h"
namespace mindspore {
namespace parallel {
class AddInfo;
using AddInfoPtr = std::shared_ptr<AddInfo>;
AddInfoPtr add, add1;
class TestInferStrategyBroadcastMode : public UT::Common {
public:
TestInferStrategyBroadcastMode() {}
void SetUp();
void TearDown() {}
};
void TestInferStrategyBroadcastMode::SetUp() {
RankList dev_list;
for (int32_t i = 0; i < 34; i++) {
dev_list.push_back(i);
}
RankList stage_map;
stage_map.push_back(32);
stage_map.push_back(2);
int32_t local_dev = 0;
// create a new g_device_manager
g_device_manager = std::make_shared<DeviceManager>();
g_device_manager->Init(dev_list, local_dev, stage_map, "hccl");
mindspore::HashMap<std::string, ValuePtr> attr;
Shapes inputs_shape = {{32, 64, 96}, {32, 64, 96}};
Shapes outputs_shape = {{32, 64, 96}};
add = std::make_shared<AddInfo>("tensoradd_info", inputs_shape, outputs_shape, attr);
Shapes inputs_shape1 = {{16, 32, 48}, {32, 48}};
Shapes outputs_shape1 = {{16, 32, 48}};
add1 = std::make_shared<AddInfo>("tensoradd_info", inputs_shape1, outputs_shape1, attr);
}
/// Feature: infer strategy for broadcast mode
/// Description: the in strategy is {{2, 4, 4}, {}}, the input shapes are {{32, 64, 96}, {32, 64, 96}}
/// Expectation: the return strategy is {{2, 4, 4}, {2, 4, 4}}
TEST_F(TestInferStrategyBroadcastMode, GenerateFullStrategy1) {
Strategies in_strategy = {{2, 4, 4}, {}};
Strategies ret = add->GenerateFullStrategy(in_strategy);
Strategies expect = {{2, 4, 4}, {2, 4, 4}};
ASSERT_EQ(ret, expect);
}
/// Feature: infer strategy for broadcast mode
/// Description: the in strategy is {{}, {2, 4, 4}}, the input shapes are {{32, 64, 96}, {32, 64, 96}}
/// Expectation: the return strategy is {{2, 4, 4}, {2, 4, 4}}
TEST_F(TestInferStrategyBroadcastMode, GenerateFullStrategy2) {
Strategies in_strategy = {{}, {2, 4, 4}};
Strategies ret = add->GenerateFullStrategy(in_strategy);
Strategies expect = {{2, 4, 4}, {2, 4, 4}};
ASSERT_EQ(ret, expect);
}
/// Feature: infer strategy for broadcast mode
/// Description: the in strategy is {{2, 4, 4}, {}}, the input shapes are {{16, 32, 48}, {32, 48}}
/// Expectation: the return strategy is {{2, 4, 4}, {4, 4}}
TEST_F(TestInferStrategyBroadcastMode, GenerateFullStrategy3) {
Strategies in_strategy = {{2, 4, 4}, {}};
Strategies ret = add1->GenerateFullStrategy(in_strategy);
Strategies expect = {{2, 4, 4}, {4, 4}};
ASSERT_EQ(ret, expect);
}
/// Feature: infer strategy for broadcast mode
/// Description: the in strategy is {{}, {4, 4}}, the input shapes are {{16, 32, 48}, {32, 48}}
/// Expectation: the return strategy is {{1, 4, 4}, {4, 4}}
TEST_F(TestInferStrategyBroadcastMode, GenerateFullStrategy4) {
Strategies in_strategy = {{}, {4, 4}};
Strategies ret = add1->GenerateFullStrategy(in_strategy);
Strategies expect = {{1, 4, 4}, {4, 4}};
ASSERT_EQ(ret, expect);
}
} // namespace parallel
} // namespace mindspore
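The deleted cases above pin down broadcast-mode completion: the given strategy is aligned from the trailing axes, and axes that are broadcast (size 1) or absent in the other input fall back to 1. A minimal Python sketch of that rule, with a hypothetical helper name rather than the removed C++:

```python
def complete_broadcast(known, target_shape):
    # Align from the trailing axes, NumPy-broadcast style.
    out = []
    for i in range(1, len(target_shape) + 1):
        if i <= len(known) and target_shape[-i] != 1:
            out.append(known[-i])   # shared axis: reuse the given split
        else:
            out.append(1)           # broadcast or missing leading axis: no split
    return out[::-1]

assert complete_broadcast([2, 4, 4], [32, 64, 96]) == [2, 4, 4]   # GenerateFullStrategy1/2
assert complete_broadcast([2, 4, 4], [32, 48]) == [4, 4]          # GenerateFullStrategy3
assert complete_broadcast([4, 4], [16, 32, 48]) == [1, 4, 4]      # GenerateFullStrategy4
```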

View File

@ -1,86 +0,0 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include <list>
#include <vector>
#include "common/common_test.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/ops_info/gathernd_info.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/step_parallel.h"
namespace mindspore {
namespace parallel {
class GatherNdInfo;
using GatherNdInfoPtr = std::shared_ptr<GatherNdInfo>;
GatherNdInfoPtr gathernd;
class TestInferStrategyIndependentMode : public UT::Common {
public:
TestInferStrategyIndependentMode() {}
void SetUp();
void TearDown() {}
};
void TestInferStrategyIndependentMode::SetUp() {
RankList dev_list;
for (int32_t i = 0; i < 34; i++) {
dev_list.push_back(i);
}
RankList stage_map;
stage_map.push_back(32);
stage_map.push_back(2);
int32_t local_dev = 0;
// create a new g_device_manager
g_device_manager = std::make_shared<DeviceManager>();
g_device_manager->Init(dev_list, local_dev, stage_map, "hccl");
mindspore::HashMap<std::string, ValuePtr> attr;
Shapes inputs_shape = {{32, 64, 96}, {32, 64, 96}};
Shapes outputs_shape = {{32, 64, 96}};
gathernd = std::make_shared<GatherNdInfo>("gathernd_info", inputs_shape, outputs_shape, attr);
}
/// Feature: infer strategy for independent mode
/// Description: the in strategy is {{1, 1, 1}, {}}, the input shapes are {{32, 64, 96}, {32, 64, 96}}
/// Expectation: the return strategy is {{1, 1, 1}, {1, 1, 1}}
TEST_F(TestInferStrategyIndependentMode, GenerateFullStrategy1) {
Strategies in_strategy = {{1, 1, 1}, {}};
Strategies ret = gathernd->GenerateFullStrategy(in_strategy);
Strategies expect = {{1, 1, 1}, {1, 1, 1}};
ASSERT_EQ(ret, expect);
}
/// Feature: infer strategy for independent mode
/// Description: the in strategy is {{}, {2, 4, 1}}, the input shapes are {{32, 64, 96}, {32, 64, 96}}
/// Expectation: the return strategy is {{1, 1, 1}, {2, 4, 1}}
TEST_F(TestInferStrategyIndependentMode, GenerateFullStrategy2) {
Strategies in_strategy = {{}, {2, 4, 1}};
Strategies ret = gathernd->GenerateFullStrategy(in_strategy);
Strategies expect = {{1, 1, 1}, {2, 4, 1}};
ASSERT_EQ(ret, expect);
}
} // namespace parallel
} // namespace mindspore
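Independent mode, as exercised above, is the simplest rule: each input is completed on its own, and an empty strategy just means fully replicated. A sketch (helper name hypothetical):

```python
def complete_independent(strategies, shapes):
    # Inputs with no given strategy default to all-1s (replicated).
    return [list(s) if s else [1] * len(shape)
            for s, shape in zip(strategies, shapes)]

assert complete_independent([[1, 1, 1], []], [[32, 64, 96]] * 2) == [[1, 1, 1], [1, 1, 1]]
assert complete_independent([[], [2, 4, 1]], [[32, 64, 96]] * 2) == [[1, 1, 1], [2, 4, 1]]
```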

View File

@ -1,264 +0,0 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include <list>
#include <vector>
#include "common/common_test.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/ops_info/layer_norm_info.h"
#include "frontend/parallel/ops_info/batchnorm_info.h"
#include "frontend/parallel/ops_info/bias_add_info.h"
#include "frontend/parallel/ops_info/scatter_ops_info.h"
#include "frontend/parallel/ops_info/conv2d_info.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/step_parallel.h"
namespace mindspore {
namespace parallel {
class LayerNormInfo;
class BiasAddInfo;
class BatchNormInfo;
class ScatterUpdateInfo;
class Conv2DInfo;
using LayerNormInfoPtr = std::shared_ptr<LayerNormInfo>;
using BiasAddInfoPtr = std::shared_ptr<BiasAddInfo>;
using BatchNormInfoPtr = std::shared_ptr<BatchNormInfo>;
using ScatterUpdateInfoPtr = std::shared_ptr<ScatterUpdateInfo>;
using Conv2DInfoPtr = std::shared_ptr<Conv2DInfo>;
LayerNormInfoPtr layer_norm;
BiasAddInfoPtr bias_add;
BatchNormInfoPtr batch_norm;
ScatterUpdateInfoPtr scatter_update;
Conv2DInfoPtr conv2d;
class TestInferStrategyIndividualMode : public UT::Common {
public:
TestInferStrategyIndividualMode() {}
void SetUp();
void TearDown() {}
};
void TestInferStrategyIndividualMode::SetUp() {
RankList dev_list;
for (int32_t i = 0; i < 64; i++) {
dev_list.push_back(i);
}
RankList stage_map;
stage_map.push_back(64);
int32_t local_dev = 0;
// create a new g_device_manager
g_device_manager = std::make_shared<DeviceManager>();
g_device_manager->Init(dev_list, local_dev, stage_map, "hccl");
// layer_norm
ValuePtr begin_norm_axis = MakeValue(std::int64_t(3));
mindspore::HashMap<std::string, ValuePtr> attr_1 = {{"begin_norm_axis", begin_norm_axis}};
Shapes ln_inputs_shape = {{16, 32, 64, 96}, {32, 64, 96}, {32, 64, 96}};
Shapes ln_outputs_shape = {{16, 32, 64, 96}, {16, 32, 1, 1}, {16, 32, 1, 1}};
layer_norm = std::make_shared<LayerNormInfo>("layernorm_info", ln_inputs_shape, ln_outputs_shape, attr_1);
// bias_add
mindspore::HashMap<std::string, ValuePtr> attr_2;
Shapes ba_inputs_shape = {{64, 96}, {96}};
Shapes ba_outputs_shape = {{64, 96}};
bias_add = std::make_shared<BiasAddInfo>("biasadd_info", ba_inputs_shape, ba_outputs_shape, attr_2);
// batch_norm
ValuePtr is_training = MakeValue(true);
ValuePtr epsilon = MakeValue(std::float_t(1.0));
ValuePtr momentum = MakeValue(std::float_t(1.0));
ValuePtr format = MakeValue("NCHW");
mindspore::HashMap<std::string, ValuePtr> attr_3 = {{"is_training", is_training},
{"epsilon", epsilon},
{"momentum", momentum},
{"format", format}};
Shapes bn_inputs_shape = {{64, 96, 32, 16}, {96}, {96}, {96}, {96}};
Shapes bn_outputs_shape = {{64, 96, 32, 16}, {96}, {96}, {96}, {96}};
batch_norm = std::make_shared<BatchNormInfo>("batchnorm_info", bn_inputs_shape, bn_outputs_shape, attr_3);
// scatter_update
mindspore::HashMap<std::string, ValuePtr> attr_4;
Shapes su_inputs_shape = {{16, 32, 64, 96}, {128, 256}, {128, 256, 32, 64, 96}};
Shapes su_outputs_shape = {{16, 32, 64, 96}};
scatter_update = std::make_shared<ScatterUpdateInfo>("scatterupdate_info", su_inputs_shape, su_outputs_shape, attr_4);
// conv2d
ValuePtr out_channel = MakeValue(std::int64_t(10));
ValuePtr kernel_size = MakeValue(std::vector<int64_t>{4, 4});
ValuePtr mode = MakeValue(std::int64_t(1));
ValuePtr pad_mode = MakeValue(std::int64_t(1));
ValuePtr pad_list = MakeValue(std::vector<int64_t>{1, 1, 1, 1});
ValuePtr stride = MakeValue(std::vector<int64_t>{1, 1, 2, 2});
ValuePtr dilation = MakeValue(std::vector<int64_t>{1, 1, 1, 1});
ValuePtr group = MakeValue(std::int64_t(1));
mindspore::HashMap<std::string, ValuePtr> attr_5 = {{"out_channel", out_channel},
{"kernel_size", kernel_size},
{"pad_mode", pad_mode},
{"mode", mode},
{"pad_list", pad_list},
{"stride", stride},
{"dilation", dilation},
{"group", group},
{"format", format}};
Shapes conv_inputs_shape = {{128, 2, 16, 16}, {10, 2, 4, 4}};
Shapes conv_outputs_shape = {{128, 10, 8, 8}};
conv2d = std::make_shared<Conv2DInfo>("conv2d_info", conv_inputs_shape, conv_outputs_shape, attr_5);
}
/// Feature: infer strategy for layer_norm
/// Description: the in strategy is {{2, 4, 8, 1}, {}, {}}
/// Expectation: the return strategy is {{2, 4, 8, 1}, {4, 8, 1}, {4, 8, 1}}
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy1) {
Strategies in_strategy = {{2, 4, 8, 1}, {}, {}};
Strategies ret = layer_norm->GenerateFullStrategy(in_strategy);
Strategies expect = {{2, 4, 8, 1}, {4, 8, 1}, {4, 8, 1}};
ASSERT_EQ(ret, expect);
}
/// Feature: infer strategy for layer_norm
/// Description: the in strategy is {{}, {4, 8, 1}, {4, 8, 1}}
/// Expectation: the return strategy is {{1, 4, 8, 1}, {4, 8, 1}, {4, 8, 1}}
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy2) {
Strategies in_strategy = {{}, {4, 8, 1}, {4, 8, 1}};
Strategies ret = layer_norm->GenerateFullStrategy(in_strategy);
Strategies expect = {{1, 4, 8, 1}, {4, 8, 1}, {4, 8, 1}};
ASSERT_EQ(ret, expect);
}
/// Feature: infer strategy for layer_norm
/// Description: the in strategy is {{}, {4, 8, 1}, {}}
/// Expectation: throw exception
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy3) {
Strategies in_strategy = {{}, {4, 8, 1}, {}};
ASSERT_ANY_THROW(layer_norm->GenerateFullStrategy(in_strategy));
}
/// Feature: infer strategy for bias_add
/// Description: the in strategy is {{4, 8}, {}}
/// Expectation: the return strategy is {{4, 8}, {8}}
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy4) {
Strategies in_strategy = {{4, 8}, {}};
Strategies ret = bias_add->GenerateFullStrategy(in_strategy);
Strategies expect = {{4, 8}, {8}};
ASSERT_EQ(ret, expect);
}
/// Feature: infer strategy for bias_add
/// Description: the in strategy is {{}, {8}}
/// Expectation: the return strategy is {{1, 8}, {8}}
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy5) {
Strategies in_strategy = {{}, {8}};
Strategies ret = bias_add->GenerateFullStrategy(in_strategy);
Strategies expect = {{1, 8}, {8}};
ASSERT_EQ(ret, expect);
}
/// Feature: infer strategy for batch_norm
/// Description: the in strategy is {{2, 4, 8, 16}, {}, {}, {}, {}}
/// Expectation: the return strategy is {{2, 4, 8, 16}, {4}, {4}, {4}, {4}}
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy6) {
Strategies in_strategy = {{2, 4, 8, 16}, {}, {}, {}, {}};
Strategies ret = batch_norm->GenerateFullStrategy(in_strategy);
Strategies expect = {{2, 4, 8, 16}, {4}, {4}, {4}, {4}};
ASSERT_EQ(ret, expect);
}
/// Feature: infer strategy for batch_norm
/// Description: the in strategy is {{}, {4}, {4}, {4}, {4}}
/// Expectation: the return strategy is {{1, 4, 1, 1}, {4}, {4}, {4}, {4}}
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy7) {
Strategies in_strategy = {{}, {4}, {4}, {4}, {4}};
Strategies ret = batch_norm->GenerateFullStrategy(in_strategy);
Strategies expect = {{1, 4, 1, 1}, {4}, {4}, {4}, {4}};
ASSERT_EQ(ret, expect);
}
/// Feature: infer strategy for batch_norm
/// Description: the in strategy is {{}, {4}, {}, {}, {4}}
/// Expectation: throw exception
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy8) {
Strategies in_strategy = {{}, {4}, {}, {}, {4}};
ASSERT_ANY_THROW(batch_norm->GenerateFullStrategy(in_strategy));
}
/// Feature: infer strategy for scatter_update
/// Description: the in strategy is {{1, 4, 8, 1}, {}, {}}
/// Expectation: the return strategy is {{1, 4, 8, 1}, {1, 1}, {1, 1, 4, 8, 1}}
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy9) {
Strategies in_strategy = {{1, 4, 8, 1}, {}, {}};
Strategies ret = scatter_update->GenerateFullStrategy(in_strategy);
Strategies expect = {{1, 4, 8, 1}, {1, 1}, {1, 1, 4, 8, 1}};
ASSERT_EQ(ret, expect);
}
/// Feature: infer strategy for scatter_update
/// Description: the in strategy is {{}, {1, 1}, {1, 1, 4, 8, 1}}
/// Expectation: the return strategy is {{1, 4, 8, 1}, {1, 1}, {1, 1, 4, 8, 1}}
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy10) {
Strategies in_strategy = {{}, {1, 1}, {1, 1, 4, 8, 1}};
Strategies ret = scatter_update->GenerateFullStrategy(in_strategy);
Strategies expect = {{1, 4, 8, 1}, {1, 1}, {1, 1, 4, 8, 1}};
ASSERT_EQ(ret, expect);
}
/// Feature: infer strategy for scatter_update
/// Description: the in strategy is {{}, {1, 1}, {}}
/// Expectation: throw exception
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy11) {
Strategies in_strategy = {{}, {1, 1}, {}};
ASSERT_ANY_THROW(scatter_update->GenerateFullStrategy(in_strategy));
}
/// Feature: infer strategy for conv2d
/// Description: the in strategy is {{8, 2, 1, 1}, {}}
/// Expectation: the return strategy is {{8, 2, 1, 1}, {1, 2, 1, 1}}
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy12) {
Strategies in_strategy = {{8, 2, 1, 1}, {}};
Strategies ret = conv2d->GenerateFullStrategy(in_strategy);
Strategies expect = {{8, 2, 1, 1}, {1, 2, 1, 1}};
ASSERT_EQ(ret, expect);
}
/// Feature: infer strategy for conv2d
/// Description: the in strategy is {{}, {1, 2, 1, 1}}
/// Expectation: the return strategy is {{1, 2, 1, 1}, {1, 2, 1, 1}}
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy13) {
Strategies in_strategy = {{}, {1, 2, 1, 1}};
Strategies ret = conv2d->GenerateFullStrategy(in_strategy);
Strategies expect = {{1, 2, 1, 1}, {1, 2, 1, 1}};
ASSERT_EQ(ret, expect);
}
} // namespace parallel
} // namespace mindspore
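Individual mode uses a per-operator rule, and an underspecified combination throws (GenerateFullStrategy3/8/11 above). The BiasAdd cases are the easiest to restate: the bias split must mirror the input's channel axis, in either direction. A sketch of just that rule (the real operator also has to respect NCHW/NHWC formats; helper name hypothetical):

```python
def complete_bias_add(x_stra, b_stra, x_rank):
    if x_stra:                             # derive the bias split from the input
        return [list(x_stra), [x_stra[-1]]]
    filled = [1] * x_rank                  # derive the input split from the bias
    filled[-1] = b_stra[0]                 # only the channel axis inherits the split
    return [filled, list(b_stra)]

assert complete_bias_add([4, 8], [], 2) == [[4, 8], [8]]   # GenerateFullStrategy4
assert complete_bias_add([], [8], 2) == [[1, 8], [8]]      # GenerateFullStrategy5
```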

View File

@ -1,86 +0,0 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include <list>
#include <vector>
#include "common/common_test.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/ops_info/addn_info.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/step_parallel.h"
namespace mindspore {
namespace parallel {
class AddNInfo;
using AddNInfoPtr = std::shared_ptr<AddNInfo>;
AddNInfoPtr addn;
class TestInferStrategySameMode : public UT::Common {
public:
TestInferStrategySameMode() {}
void SetUp();
void TearDown() {}
};
void TestInferStrategySameMode::SetUp() {
RankList dev_list;
for (int32_t i = 0; i < 34; i++) {
dev_list.push_back(i);
}
RankList stage_map;
stage_map.push_back(32);
stage_map.push_back(2);
int32_t local_dev = 0;
// create a new g_device_manager
g_device_manager = std::make_shared<DeviceManager>();
g_device_manager->Init(dev_list, local_dev, stage_map, "hccl");
mindspore::HashMap<std::string, ValuePtr> attr;
Shapes inputs_shape = {{32, 64, 96}, {32, 64, 96}};
Shapes outputs_shape = {{32, 64, 96}};
addn = std::make_shared<AddNInfo>("addn_info", inputs_shape, outputs_shape, attr);
}
/// Feature: infer strategy for same mode
/// Description: the in strategy is {{2, 4, 4}, {}}
/// Expectation: the return strategy is {{2, 4, 4}, {2, 4, 4}}
TEST_F(TestInferStrategySameMode, GenerateFullStrategy1) {
Strategies in_strategy = {{2, 4, 4}, {}};
Strategies ret = addn->GenerateFullStrategy(in_strategy);
Strategies expect = {{2, 4, 4}, {2, 4, 4}};
ASSERT_EQ(ret, expect);
}
/// Feature: infer strategy for same mode
/// Description: the in strategy is {{}, {2, 4, 4}}
/// Expectation: the return strategy is {{2, 4, 4}, {2, 4, 4}}
TEST_F(TestInferStrategySameMode, GenerateFullStrategy2) {
Strategies in_strategy = {{}, {2, 4, 4}};
Strategies ret = addn->GenerateFullStrategy(in_strategy);
Strategies expect = {{2, 4, 4}, {2, 4, 4}};
ASSERT_EQ(ret, expect);
}
} // namespace parallel
} // namespace mindspore
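Same mode forces one layout across all inputs of AddN, so whichever strategy is given is copied to the rest. A sketch (helper name hypothetical):

```python
def complete_same(strategies):
    given = next(s for s in strategies if s)   # the single user-provided split
    return [list(given) for _ in strategies]

assert complete_same([[2, 4, 4], []]) == [[2, 4, 4], [2, 4, 4]]
assert complete_same([[], [2, 4, 4]]) == [[2, 4, 4], [2, 4, 4]]
```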

View File

@ -648,7 +648,7 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies1) {
// the parameter '0' indicates that the stageId = 0; there are 1024 devices in stage 0
ASSERT_EQ(matmul1->GenerateStrategies(0), Status::SUCCESS);
std::vector<std::shared_ptr<StrategyWithCost>> sc = matmul1->GetStrategyCost();
for (const auto& swc : sc) {
for (const auto &swc : sc) {
StrategyPtr sp = swc->strategy_ptr;
Cost cost = *(swc->cost_list[0]);
matmul1->InitForCostModel(sp, nullptr);
@ -659,82 +659,5 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies1) {
break;
}
}
/// Feature: infer strategy for inputs_shape: {{2, 4, 8, 16}, {2, 4, 16, 32}}, transpose_b=false
/// Description: the in strategy is {{2, 4, 8, 16}, {}}
/// Expectation: the return strategy is {{2, 4, 8, 16}, {2, 4, 16, 1}}
TEST_F(TestMatmulInfo, GenerateFullStrategy1) {
Strategies in_strategy = {{2, 4, 8, 16}, {}};
Strategies ret = matmul1->GenerateFullStrategy(in_strategy);
Strategies expect = {{2, 4, 8, 16}, {2, 4, 16, 1}};
ASSERT_EQ(ret, expect);
}
/// Feature: infer strategy for inputs_shape: {{2, 4, 8, 16}, {2, 4, 16, 32}}, transpose_b=false
/// Description: the in strategy is {{}, {2, 4, 8, 16}}
/// Expectation: the return strategy is {{2, 4, 1, 8}, {2, 4, 8, 16}}
TEST_F(TestMatmulInfo, GenerateFullStrategy2) {
Strategies in_strategy = {{}, {2, 4, 8, 16}};
Strategies ret = matmul1->GenerateFullStrategy(in_strategy);
Strategies expect = {{2, 4, 1, 8}, {2, 4, 8, 16}};
ASSERT_EQ(ret, expect);
}
/// Feature: infer strategy for inputs_shape: {{2, 4, 8, 16}, {32, 16}}, transpose_b=true
/// Description: the in strategy is {{2, 4, 8, 16}, {}}
/// Expectation: the return strategy is {{2, 4, 8, 16}, {1, 16}}
TEST_F(TestMatmulInfo, GenerateFullStrategy3) {
Strategies in_strategy = {{2, 4, 8, 16}, {}};
Strategies ret = matmul2->GenerateFullStrategy(in_strategy);
Strategies expect = {{2, 4, 8, 16}, {1, 16}};
ASSERT_EQ(ret, expect);
}
/// Feature: infer strategy for inputs_shape: {{2, 4, 8, 16}, {32, 16}}, transpose_b=true
/// Description: the in strategy is {{}, {8, 16}}
/// Expectation: the return strategy is {{1, 1, 1, 16}, {8, 16}}
TEST_F(TestMatmulInfo, GenerateFullStrategy4) {
Strategies in_strategy = {{}, {8, 16}};
Strategies ret = matmul2->GenerateFullStrategy(in_strategy);
Strategies expect = {{1, 1, 1, 16}, {8, 16}};
ASSERT_EQ(ret, expect);
}
/// Feature: infer strategy for inputs_shape: {{8, 16}, {2, 4, 32, 16}}, transpose_b=true
/// Description: the in strategy is {{8, 16}, {}}
/// Expectation: the return strategy is {{8, 16}, {1, 1, 1, 16}}
TEST_F(TestMatmulInfo, GenerateFullStrategy5) {
Strategies in_strategy = {{8, 16}, {}};
Strategies ret = matmul3->GenerateFullStrategy(in_strategy);
Strategies expect = {{8, 16}, {1, 1, 1, 16}};
ASSERT_EQ(ret, expect);
}
/// Feature: infer strategy for inputs_shape: {{8, 16}, {2, 4, 32, 16}}, transpose_b=true
/// Description: the in strategy is {{}, {2, 4, 8, 16}}
/// Expectation: the return strategy is {{1, 16}, {2, 4, 8, 16}}
TEST_F(TestMatmulInfo, GenerateFullStrategy6) {
Strategies in_strategy = {{}, {2, 4, 8, 16}};
Strategies ret = matmul3->GenerateFullStrategy(in_strategy);
Strategies expect = {{1, 16}, {2, 4, 8, 16}};
ASSERT_EQ(ret, expect);
}
/// Feature: infer strategy for inputs_shape: {{1024, 128}, {128, 256}}, transpose_b=false
/// Description: the in strategy is {{}, {}}
/// Expectation: the return strategy is {{1024, 1}, {1, 1}}
TEST_F(TestMatmulInfo, GenerateFullStrategy7) {
Strategies in_strategy = {{}, {}};
Strategies ret = matmul5->GenerateFullStrategy(in_strategy);
Strategies expect = {{1024, 1}, {1, 1}};
ASSERT_EQ(ret, expect);
}
} // namespace parallel
} // namespace mindspore
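The removed MatMul cases all follow one constraint: the contracted axis must carry the same split on both operands, while the unknown operand's free axis and any missing batch axes default to 1 (the fully-empty case in GenerateFullStrategy7 instead falls back to a generated default). A sketch of the B-from-A direction only, assuming transpose_a is false (helper name hypothetical):

```python
def infer_b_strategy(a_stra, b_rank, transpose_b=False):
    k = a_stra[-1]                          # contracted axis: same split on both operands
    n_batch = b_rank - 2
    padded = [1] * n_batch + a_stra[:-2]    # missing leading batch axes default to 1
    batch = padded[len(padded) - n_batch:] if n_batch > 0 else []
    return batch + ([1, k] if transpose_b else [k, 1])

assert infer_b_strategy([2, 4, 8, 16], 4) == [2, 4, 16, 1]               # matmul1
assert infer_b_strategy([2, 4, 8, 16], 2, transpose_b=True) == [1, 16]   # matmul2
assert infer_b_strategy([8, 16], 4, transpose_b=True) == [1, 1, 1, 16]   # matmul3
```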

View File

@ -205,27 +205,5 @@ TEST_F(TestOneHotInfo, CheckStrategy1) {
Status ret = onehot_info->Init(strategy, nullptr);
ASSERT_EQ(ret, FAILED);
}
/// Feature: infer strategy for inputs_shape: {{64}, {}, {}}
/// Description: the in strategy is {{8}, {}, {}}
/// Expectation: the return strategy is {{8, 1}, {}, {}}
TEST_F(TestOneHotInfo, GenerateFullStrategy1) {
Strategies in_strategy = {{8}, {}, {}};
Strategies ret = onehot_info->GenerateFullStrategy(in_strategy);
Strategies expect = {{8, 1}, {}, {}};
ASSERT_EQ(ret, expect);
}
/// Feature: infer strategy for inputs_shape: {{64}, {}, {}}
/// Description: the in strategy is {{}, {}, {}}
/// Expectation: the return strategy is {{8, 1}, {}, {}}
TEST_F(TestOneHotInfo, GenerateFullStrategy2) {
Strategies in_strategy = {{}, {}, {}};
Strategies ret = onehot_info->GenerateFullStrategy(in_strategy);
Strategies expect = {{8, 1}, {}, {}};
ASSERT_EQ(ret, expect);
}
} // namespace parallel
} // namespace mindspore
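OneHot grows the rank by one: the generated class axis is appended unsplit, which is what GenerateFullStrategy1 above checks ({8} becomes {8, 1}). With no strategy at all, a default is generated instead ({{8, 1}, {}, {}} here, presumably data-parallel over the stage's devices). A sketch of the append rule only (helper name hypothetical):

```python
def extend_onehot_strategy(x_stra):
    return list(x_stra) + [1]   # the new class axis stays unsplit

assert extend_onehot_strategy([8]) == [8, 1]
```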

View File

@ -272,49 +272,5 @@ TEST_F(TestPReLUInfo, AutoStrategy_2d1) {
ASSERT_EQ(stra1[0], 1);
}
}
/// Feature: infer strategy for inputs_shape: {{64, 4, 8, 16}, {4}}
/// Description: the in strategy is {{2, 4, 8, 16}, {}}
/// Expectation: the return strategy is {{2, 4, 8, 16}, {4}}
TEST_F(TestPReLUInfo, GenerateFullStrategy1) {
Strategies in_strategy = {{2, 4, 8, 16}, {}};
Strategies ret = prelu->GenerateFullStrategy(in_strategy);
Strategies expect = {{2, 4, 8, 16}, {4}};
ASSERT_EQ(ret, expect);
}
/// Feature: infer strategy for inputs_shape: {{64, 4, 8, 16}, {4}}
/// Description: the in strategy is {{}, {4}}
/// Expectation: the return strategy is {{1, 4, 1, 1}, {4}}
TEST_F(TestPReLUInfo, GenerateFullStrategy2) {
Strategies in_strategy = {{}, {4}};
Strategies ret = prelu->GenerateFullStrategy(in_strategy);
Strategies expect = {{1, 4, 1, 1}, {4}};
ASSERT_EQ(ret, expect);
}
/// Feature: infer strategy for inputs_shape: {{1024, 4}, {4}}
/// Description: the in strategy is {{8, 4}, {}}
/// Expectation: the return strategy is {{8, 4}, {4}}
TEST_F(TestPReLUInfo, GenerateFullStrategy3) {
Strategies in_strategy = {{8, 4}, {}};
Strategies ret = prelu_2d->GenerateFullStrategy(in_strategy);
Strategies expect = {{8, 4}, {4}};
ASSERT_EQ(ret, expect);
}
/// Feature: infer strategy for inputs_shape: {{1024, 4}, {4}}
/// Description: the in strategy is {{}, {4}}
/// Expectation: the return strategy is {{1, 4}, {4}}
TEST_F(TestPReLUInfo, GenerateFullStrategy4) {
Strategies in_strategy = {{}, {4}};
Strategies ret = prelu_2d->GenerateFullStrategy(in_strategy);
Strategies expect = {{1, 4}, {4}};
ASSERT_EQ(ret, expect);
}
} // namespace parallel
} // namespace mindspore
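PReLU's weight is per-channel, so its split mirrors input axis 1 in both directions, exactly as the four cases above encode. A closing sketch (helper name hypothetical):

```python
def complete_prelu(x_stra, w_stra, x_rank):
    if x_stra:                          # derive the weight split from the input
        return [list(x_stra), [x_stra[1]]]
    filled = [1] * x_rank               # derive the input split from the weight
    filled[1] = w_stra[0]               # only the channel axis inherits the split
    return [filled, list(w_stra)]

assert complete_prelu([2, 4, 8, 16], [], 4) == [[2, 4, 8, 16], [4]]   # GenerateFullStrategy1
assert complete_prelu([], [4], 4) == [[1, 4, 1, 1], [4]]              # GenerateFullStrategy2
assert complete_prelu([8, 4], [], 2) == [[8, 4], [4]]                 # GenerateFullStrategy3
assert complete_prelu([], [4], 2) == [[1, 4], [4]]                    # GenerateFullStrategy4
```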