forked from mindspore-Ecosystem/mindspore
!45631 Insert in_strategy between input and target op
Merge pull request !45631 from liuluobin/shard_in_stra_identity
commit 6733c6a529
@@ -40,10 +40,7 @@ class AddNInfo : public OperatorInfo {
   void ReComputeBatchSplitFlagList() override;

  protected:
-  Status GetAttrs() override {
-    infer_strategy_mode_ = SAME_MODE;
-    return SUCCESS;
-  }
+  Status GetAttrs() override { return SUCCESS; }
   Status CheckStrategy(const StrategyPtr &strategy) override;
   Status InferForwardCommunication() override { return SUCCESS; }
   Status InferDevMatrixShape() override;

@@ -40,10 +40,7 @@ class ArithmeticBase : public OperatorInfo {
   void ReComputeBatchSplitFlagList() override;

  protected:
-  Status GetAttrs() override {
-    infer_strategy_mode_ = BROADCAST_MODE;
-    return SUCCESS;
-  }
+  Status GetAttrs() override { return SUCCESS; }
   Status CheckStrategy(const StrategyPtr &strategy) override;
   Status InferForwardCommunication() override { return SUCCESS; }
   Status InferDevMatrixShape() override;

@@ -74,7 +74,6 @@ Status BatchNormInfo::GetAttrs() {
   MS_LOG(INFO) << name_ << ": The is_traing is " << is_training_ << ", epsilon is " << epsilon_ << ", momentum is "
                << momentum_ << ", data format is " << format_;

-  infer_strategy_mode_ = INDIVIDUAL_MODE;
   return SUCCESS;
 }

@@ -335,38 +334,6 @@ std::vector<StrategyPtr> BatchNormInfo::GenerateOpStrategies(int64_t stage_id) {
   return sp_vector;
 }

-// in_strategy: ((N, C, H, W), (), (), (), ()) return: ((N, C, H, W), (C), (C), (C), (C))
-// in_strategy: ((), (C), (C), (C), (C)) return: ((1, C, 1, 1), (C), (C), (C), (C))
-// in_strategy: ((), (C), (), (C), (C)) throw exception
-Shapes BatchNormInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
-  if (in_strategy.size() != 5) {
-    MS_LOG(EXCEPTION) << name_ << ": The size of in strategy must be 5, but got " << in_strategy.size();
-  }
-
-  if ((in_strategy[2] != in_strategy[1]) || (in_strategy[3] != in_strategy[1]) || (in_strategy[4] != in_strategy[1])) {
-    MS_LOG(EXCEPTION) << name_ << ": The last 4 strategy must be equal, but got " << in_strategy;
-  }
-
-  Shape channel_strategy;
-  if (!in_strategy[0].empty()) {
-    if (in_strategy[0].size() != 4 && in_strategy[0].size() != 2) {
-      MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy[0] must be 4 or 2, but got " << in_strategy[0].size();
-    }
-    channel_strategy = {in_strategy[0][1]};
-    return Shapes({in_strategy[0], channel_strategy, channel_strategy, channel_strategy, channel_strategy});
-  }
-
-  for (size_t i = 1; i < in_strategy.size(); ++i) {
-    if (!in_strategy[i].empty()) {
-      channel_strategy = in_strategy[i];
-      break;
-    }
-  }
-
-  Shape tmp_strategy(inputs_shape_[0].size(), 1);
-  tmp_strategy[1] = channel_strategy[0];
-  return Shapes({tmp_strategy, channel_strategy, channel_strategy, channel_strategy, channel_strategy});
-}
-
 REGISTER(BatchNormInfo);
 }  // namespace parallel
 }  // namespace mindspore
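For reference, the removed BatchNorm completion rule can be summarized in a small standalone sketch. This is not MindSpore code: plain std::vector stand-ins replace the Shape/Shapes types and the helper name CompleteBatchNormStrategy is hypothetical; it only mirrors the behavior documented in the deleted comments (derive the 1-D channel strategies from in_strategy[0][1], or rebuild the input strategy as (1, C, 1, ..., 1)).

#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;   // stand-in for mindspore::parallel::Shape
using Shapes = std::vector<Shape>;

// Complete a partial BatchNorm strategy: if only the input strategy is given,
// every scale/offset/mean/var input is sharded by the channel dimension; if
// only the 1-D channel strategies are given, the input strategy becomes
// (1, C, 1, ..., 1) with the same channel split.
Shapes CompleteBatchNormStrategy(const Shapes &in_strategy, size_t input_rank) {
  assert(in_strategy.size() == 5);
  if (!in_strategy[0].empty()) {
    Shape channel = {in_strategy[0][1]};
    return {in_strategy[0], channel, channel, channel, channel};
  }
  Shape channel = in_strategy[1];  // entries [1..4] must all be equal
  Shape input(input_rank, 1);
  input[1] = channel[0];
  return {input, channel, channel, channel, channel};
}

int main() {
  // ((N, C, H, W), (), (), (), ()) -> ((N, C, H, W), (C), (C), (C), (C))
  Shapes full = CompleteBatchNormStrategy({{2, 4, 1, 1}, {}, {}, {}, {}}, 4);
  assert(full[1] == Shape({4}));
  // ((), (C), (C), (C), (C)) -> ((1, C, 1, 1), (C), (C), (C), (C))
  full = CompleteBatchNormStrategy({{}, {4}, {4}, {4}, {4}}, 4);
  assert(full[0] == Shape({1, 4, 1, 1}));
  return 0;
}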
@@ -47,7 +47,6 @@ class BatchNormInfo : public OperatorInfo {
   Status InferTensorMap() override;
   void InferReplaceOps() override;
   Status InferAsLossDivisor() override;
-  Shapes InferStrategyIndividualMode(const Shapes &in_strategy) override;

  private:
   bool is_training_ = false;

@@ -102,30 +102,6 @@ std::vector<StrategyPtr> BiasAddInfo::GenerateOpStrategies(int64_t stage_id) {
   return sp_vector;
 }

-// in_strategy: ((N, C, H, W), ()), inputs shapes: ((n, c, h, w), (c)), return: ((N, C, H, W), (C))
-// in_strategy: ((), (C)), inputs shapes: ((n, c, h, w), (c)), return: ((1, C, 1, 1), (C))
-// in_strategy: ((N, C), ()), inputs shapes: ((n, c), (c)), return: ((N, C), (C))
-// in_strategy: ((), (C)), inputs shapes: ((n, c), (c)), return: ((1, C), (C))
-Shapes BiasAddInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
-  if (in_strategy.size() != 2) {
-    MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy must be 2, but got " << in_strategy.size();
-  }
-
-  if (!in_strategy[0].empty()) {
-    if (in_strategy[0].size() != 2 && in_strategy[0].size() != 4) {
-      MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy[0] must be 2 or 4, but got " << in_strategy[0].size();
-    }
-    return Shapes({in_strategy[0], {in_strategy[0][1]}});
-  }
-
-  if (in_strategy[1].size() != 1) {
-    MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy[1] must be 1, but got " << in_strategy[1].size();
-  }
-  Shape tmp(inputs_shape_[0].size(), 1);
-  tmp[1] = in_strategy[1][0];
-  return Shapes({tmp, in_strategy[1]});
-}
-
 REGISTER(BiasAddInfo);
 }  // namespace parallel
 }  // namespace mindspore

@@ -42,15 +42,11 @@ class BiasAddInfo : public OperatorInfo {
   void ReComputeBatchSplitFlagList() override;

  protected:
-  Status GetAttrs() override {
-    infer_strategy_mode_ = INDIVIDUAL_MODE;
-    return SUCCESS;
-  }
+  Status GetAttrs() override { return SUCCESS; }
   Status CheckStrategy(const StrategyPtr &strategy) override;
   Status InferForwardCommunication() override { return SUCCESS; }
   Status InferDevMatrixShape() override;
   Status InferTensorMap() override;
-  Shapes InferStrategyIndividualMode(const Shapes &in_strategy) override;
 };
 }  // namespace parallel
 }  // namespace mindspore

@@ -42,10 +42,7 @@ class BoundingBoxEncodeInfo : public OperatorInfo {
   Status CheckStrategy(const StrategyPtr &strategy) override;
   Status InferDevMatrixShape() override;
   Status InferTensorMap() override;
-  Status GetAttrs() override {
-    infer_strategy_mode_ = SAME_MODE;
-    return SUCCESS;
-  }
+  Status GetAttrs() override { return SUCCESS; }
   Status InferForwardCommunication() override { return SUCCESS; };

  private:

@@ -28,7 +28,6 @@ Status CdistInfo::GetAttrs() {
     MS_LOG(ERROR) << "Dimension of each input must be 2 or 3, but got dimension is " << input_dims_ << ".";
     return FAILED;
   }
-  infer_strategy_mode_ = INDIVIDUAL_MODE;
   return SUCCESS;
 }

@@ -117,29 +116,6 @@ void CdistInfo::ReComputeBatchSplitFlagList() {
   }
 }

-// in_strategy: ((B, P, 1), ()), inputs shape: ((b, p, m), (b, r, m)), return: ((B, P, 1), (B, 1, 1))
-// in_strategy: ((), (B, R, 1)), inputs shape: ((b, p, m), (b, r, m)), return: ((B, 1, 1), (B, R, 1))
-Shapes CdistInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
-  if (in_strategy.size() != 2) {
-    MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy must be 2, but got " << in_strategy.size();
-  }
-
-  Shape tmp_strategy;
-  if (!in_strategy[0].empty()) {
-    tmp_strategy = Shape(inputs_shape_[1].size(), 1);
-    tmp_strategy[0] = in_strategy[0][0];
-    return Shapes({in_strategy[0], tmp_strategy});
-  }
-
-  if (!in_strategy[1].empty()) {
-    tmp_strategy = Shape(inputs_shape_[0].size(), 1);
-    tmp_strategy[0] = in_strategy[1][0];
-    return Shapes({tmp_strategy, in_strategy[1]});
-  }
-
-  MS_LOG(EXCEPTION) << name_ << ": The in_strategy[0] and in_strategy[1] are empty";
-}
-
 REGISTER(CdistInfo);
 }  // namespace parallel
 }  // namespace mindspore

@@ -43,7 +43,6 @@ class CdistInfo : public OperatorInfo {
   Status InferDevMatrixShape() override;
   Status InferTensorMap() override;
   Status InferForwardCommunication() override { return SUCCESS; }
-  Shapes InferStrategyIndividualMode(const Shapes &in_strategy) override;

  private:
   size_t input_dims_ = 0;

@@ -56,7 +56,6 @@ Status ConcatInfo::GetAttrs() {
   }

   axis_ = LongToSize(axis);
-  infer_strategy_mode_ = SAME_MODE;
   return SUCCESS;
 }

@@ -134,7 +134,6 @@ Status Conv2DInfo::GetAttrsBase() {
   // group
   group_ = GetIntAttr(GROUP);

-  infer_strategy_mode_ = INDIVIDUAL_MODE;
   return SUCCESS;
 }

@@ -1039,68 +1038,6 @@ std::vector<StrategyPtr> Conv2DInfo::GenerateOpStrategies(int64_t stage_id) {
   return sp_vector;
 }

-// conv2d: ((N, C-in, H, W), (C-out, C-in, k1, k2))
-// in_strategy: ((N, C, H, W), ()), return: ((N, C, H, W), (1, C, 1, 1))
-// in_strategy: ((), (C0, C1, 1, 1)), return: ((1, C1, 1, 1), (C0, C1, 1, 1))
-Shapes Conv2DInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
-  if (in_strategy.size() != 2) {
-    MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy must be 2, but got " << in_strategy.size();
-  }
-
-  Shape tmp_strategy;
-  if (!in_strategy[0].empty()) {
-    if (in_strategy[0].size() != inputs_shape_[0].size()) {
-      MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy[0] must be " << inputs_shape_[0].size() << ", but got "
-                        << in_strategy[0].size();
-    }
-    tmp_strategy = Shape(inputs_shape_[1].size(), 1);
-    tmp_strategy[1] = in_strategy[0][1];
-    return Shapes({in_strategy[0], tmp_strategy});
-  }
-
-  if (!in_strategy[1].empty()) {
-    if (in_strategy[1].size() != inputs_shape_[1].size()) {
-      MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy[1] must be " << inputs_shape_[1].size() << ", but got "
-                        << in_strategy[1].size();
-    }
-    tmp_strategy = Shape(inputs_shape_[0].size(), 1);
-    tmp_strategy[1] = in_strategy[1][1];
-    return Shapes({tmp_strategy, in_strategy[1]});
-  }
-
-  MS_LOG(EXCEPTION) << name_ << ": The in_strategy[0] and in_strategy[1] are empty";
-}
-
-// conv2d-transpose: ((N, C-out, H, W), (C-out, C-in, k1, k2))
-// in_strategy: ((N, C, H, W), ()), return: ((N, C, H, W), (C, 1, 1, 1))
-// in_strategy: ((), (C0, C1, 1, 1)), return: ((1, C0, 1, 1), (C0, C1, 1, 1))
-Shapes Conv2DBackpropInputInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
-  if (in_strategy.size() != 2) {
-    MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy must be 2, but got " << in_strategy.size();
-  }
-
-  Shape tmp_strategy;
-  if (!in_strategy[0].empty()) {
-    if (in_strategy[0].size() != 4) {
-      MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy[0] must be 4, but got " << in_strategy[0].size();
-    }
-    tmp_strategy = Shape(inputs_shape_[1].size(), 1);
-    tmp_strategy[0] = in_strategy[0][1];
-    return Shapes({in_strategy[0], tmp_strategy});
-  }
-
-  if (!in_strategy[1].empty()) {
-    if (in_strategy[1].size() != 4) {
-      MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy[1] must be 4, but got " << in_strategy[1].size();
-    }
-    tmp_strategy = Shape(inputs_shape_[0].size(), 1);
-    tmp_strategy[1] = in_strategy[1][0];
-    return Shapes({tmp_strategy, in_strategy[1]});
-  }
-
-  MS_LOG(EXCEPTION) << name_ << ": The in_strategy[0] and in_strategy[1] are empty";
-}
-
 Status Conv2DBackpropInputInfo::GetOutShape() {
   if (input_value_.size() != 3) {
     MS_LOG(ERROR) << name_ << ": The size of input value must be 3, but got " << input_value_.size();

@@ -146,7 +146,6 @@ class Conv2DInfo : public OperatorInfo {
   virtual int64_t ComputeOverlapBottomSizeByRankBias(int64_t rank_bias);
   virtual int64_t ComputeOverlapLeftSizeByRankBias(int64_t rank_bias);
   virtual int64_t ComputeOverlapRightSizeByRankBias(int64_t rank_bias);
-  Shapes InferStrategyIndividualMode(const Shapes &in_strategy) override;

  private:
   Status CheckHWStrategyValidMode(int64_t h_strategy, int64_t w_strategy);

@@ -179,7 +178,6 @@ class Conv2DBackpropInputInfo : public Conv2DInfo {
   int64_t ComputeOverlapBottomSizeByRankBias(int64_t rank_bias) override;
   int64_t ComputeOverlapLeftSizeByRankBias(int64_t rank_bias) override;
   int64_t ComputeOverlapRightSizeByRankBias(int64_t rank_bias) override;
-  Shapes InferStrategyIndividualMode(const Shapes &in_strategy) override;

  private:
   Shape out_shape_;

@@ -51,7 +51,6 @@ Status GammaInfo::GetAttrs() {
     MS_LOG(ERROR) << name_ << ": Seed2 must be greater or equal to zero, bug got " << seed2_;
     return FAILED;
   }
-  infer_strategy_mode_ = INDEPENDENT_MODE;
   return SUCCESS;
 }

@@ -167,7 +167,6 @@ Status GatherInfo::GetAttrs() {
     dynamic_shape_indices_ = true;
   }
 #endif
-  infer_strategy_mode_ = INDEPENDENT_MODE;
   return SUCCESS;
 }

@@ -54,7 +54,6 @@ Status GatherDInfo::GetAttrs() {

   dim_ = LongToSize(dim);
   MS_LOG(INFO) << name_ << ": The dim is " << dim_;
-  infer_strategy_mode_ = SAME_MODE;
   return SUCCESS;
 }

@@ -41,10 +41,7 @@ class GatherNdInfo : public OperatorInfo {
   void ReComputeBatchSplitFlagList() override;

  protected:
-  Status GetAttrs() override {
-    infer_strategy_mode_ = INDEPENDENT_MODE;
-    return SUCCESS;
-  }
+  Status GetAttrs() override { return SUCCESS; }
   Status CheckStrategy(const StrategyPtr &strategy) override;
   Status InferForwardCommunication() override { return SUCCESS; }
   Status InferDevMatrixShape() override;

@@ -38,10 +38,7 @@ class InplaceOpBase : public OperatorInfo {
   void ReComputeBatchSplitFlagList() override;

  protected:
-  Status GetAttrs() override {
-    infer_strategy_mode_ = SAME_MODE;
-    return SUCCESS;
-  }
+  Status GetAttrs() override { return SUCCESS; }
   Status CheckStrategy(const StrategyPtr &strategy) override;
   Status InferDevMatrixShape() override;
   Status InferTensorMap() override;

@@ -36,10 +36,7 @@ class IOUInfo : public OperatorInfo {
   Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); }

  protected:
-  Status GetAttrs() override {
-    infer_strategy_mode_ = INDEPENDENT_MODE;
-    return SUCCESS;
-  }
+  Status GetAttrs() override { return SUCCESS; }
   Status CheckStrategy(const StrategyPtr &strategy) override;
   Status InferDevMatrixShape() override;
   Status InferTensorMap() override;

@@ -25,7 +25,6 @@ namespace mindspore {
 namespace parallel {
 Status KLDivLossInfo::GetAttrs() {
   reduction_ = GetStringAttr(REDUCTION);
-  infer_strategy_mode_ = SAME_MODE;
   return SUCCESS;
 }

@@ -59,7 +59,6 @@ Status LayerNormInfo::GetAttrs() {
     axis = axis + dim;
   }
   begin_norm_axis_ = LongToSize(axis);
-  infer_strategy_mode_ = INDIVIDUAL_MODE;
   return SUCCESS;
 }

@@ -249,67 +248,6 @@ Status LayerNormInfo::InitShapes() {
   return SUCCESS;
 }

-// in_strategy: ((A, B, 1, 1), (), ()), begin_norm_axis_: 2, Shapes: ((a, b, c, d), (b, c, d), (b, c, d))
-// return: ((A, B, 1, 1), (B, 1, 1), (B, 1, 1))
-// in_strategy: ((), (B, 1, 1), (B, 1, 1)), begin_norm_axis_: 2, Shapes: ((a, b, c, d), (b, c, d), (b, c, d))
-// return: ((1, B, 1, 1), (B, 1, 1), (B, 1, 1))
-Shapes LayerNormInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
-  if (in_strategy.size() != 3) {
-    MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy must be 3, but got " << in_strategy.size();
-  }
-
-  if (in_strategy[1] != in_strategy[2]) {
-    MS_LOG(EXCEPTION) << name_ << ": The in_strategy[1] must be equal to the in_strategy[2], but the in_strategy[1] is"
-                      << in_strategy[1] << ", the in_strategy[2] is " << in_strategy[2];
-  }
-
-  if (gamma_shape_ != beta_shape_) {
-    MS_LOG(EXCEPTION) << name_ << ": The gamma's shape must be equal to the beta's shape, but the gamma's shape is"
-                      << gamma_shape_ << ", the beta's shape is " << beta_shape_;
-  }
-
-  if (input_shape_.size() < gamma_shape_.size()) {
-    MS_LOG(EXCEPTION)
-      << name_ << ": The input's shape size cannot smaller than gamma's shape size, but the input's shape size is"
-      << input_shape_.size() << ", the gamma's shape size is " << gamma_shape_.size();
-  }
-
-  if (!in_strategy[0].empty()) {
-    if (in_strategy[0].size() != input_shape_.size()) {
-      MS_LOG(EXCEPTION)
-        << name_
-        << ": The size of in_strategy[0] must be equal to the size of inputs_shape[0], but the in_strategy[0] is"
-        << in_strategy[0] << ", the inputs_shape[0] is " << input_shape_;
-    }
-    if (input_shape_.size() == gamma_shape_.size()) {
-      return Shapes({in_strategy[0], in_strategy[0], in_strategy[0]});
-    } else {
-      size_t diff_len = input_shape_.size() - gamma_shape_.size();
-      Shape gamma_strategy(in_strategy[0].begin() + diff_len, in_strategy[0].end());
-      return Shapes({in_strategy[0], gamma_strategy, gamma_strategy});
-    }
-  }
-
-  if (!in_strategy[1].empty()) {
-    if (in_strategy[1].size() != gamma_shape_.size()) {
-      MS_LOG(EXCEPTION)
-        << name_
-        << ": The size of in_strategy[1] must be equal to the size of inputs_shape[1], but the in_strategy[1] is"
-        << in_strategy[1] << ", the inputs_shape[1] is " << gamma_shape_;
-    }
-    if (input_shape_.size() == gamma_shape_.size()) {
-      return Shapes({in_strategy[1], in_strategy[1], in_strategy[1]});
-    } else {
-      size_t diff_len = input_shape_.size() - gamma_shape_.size();
-      Shape tmp_strategy = in_strategy[1];
-      (void)tmp_strategy.insert(tmp_strategy.begin(), diff_len, 1);
-      return Shapes({tmp_strategy, in_strategy[1], in_strategy[1]});
-    }
-  }
-
-  MS_LOG(EXCEPTION) << name_ << ": The in_strategy[0], in_strategy[1] and in_strategy[2] are empty";
-}
-
 REGISTER(LayerNormInfo);
 }  // namespace parallel
 }  // namespace mindspore
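The deleted LayerNorm variant differs from the BatchNorm one above in how the sides are related: gamma/beta align with the input's trailing dimensions (as determined by begin_norm_axis_), so one side's strategy is obtained by trimming or left-padding the other. A minimal sketch under the same assumptions as before (plain std::vector types, hypothetical helper name):

#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;
using Shapes = std::vector<Shape>;

// Given either the input strategy or the gamma/beta strategy, derive the other
// by aligning on the trailing dimensions (gamma's rank <= input's rank).
Shapes CompleteLayerNormStrategy(const Shapes &in_strategy, size_t input_rank, size_t gamma_rank) {
  size_t diff_len = input_rank - gamma_rank;
  if (!in_strategy[0].empty()) {
    Shape gamma(in_strategy[0].begin() + diff_len, in_strategy[0].end());
    return {in_strategy[0], gamma, gamma};
  }
  Shape input = in_strategy[1];
  input.insert(input.begin(), diff_len, 1);  // pad the leading dims with 1
  return {input, in_strategy[1], in_strategy[1]};
}

int main() {
  // ((A, B, 1, 1), (), ()) with gamma rank 3 -> ((A, B, 1, 1), (B, 1, 1), (B, 1, 1))
  Shapes full = CompleteLayerNormStrategy({{2, 4, 1, 1}, {}, {}}, 4, 3);
  assert(full[1] == Shape({4, 1, 1}));
  // ((), (B, 1, 1), (B, 1, 1)) -> ((1, B, 1, 1), (B, 1, 1), (B, 1, 1))
  full = CompleteLayerNormStrategy({{}, {4, 1, 1}, {4, 1, 1}}, 4, 3);
  assert(full[0] == Shape({1, 4, 1, 1}));
  return 0;
}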
@@ -57,7 +57,6 @@ class LayerNormInfo : public OperatorInfo {
   Status CreateInputTensorMap(size_t input_index);
   Status GenerateGammaAndBetaStrategies(const std::vector<StrategyPtr> &sp_vector);
   Status InitShapes();
-  Shapes InferStrategyIndividualMode(const Shapes &in_strategy) override;

  private:
   size_t begin_norm_axis_;

@@ -65,7 +65,6 @@ Status SoftmaxCrossEntropyWithLogitsInfo::GetAttrs() {
     MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong.";
     return FAILED;
   }
-  infer_strategy_mode_ = SAME_MODE;
   return SUCCESS;
 }

@@ -125,7 +125,6 @@ Status MatMulBase::GetAttrs() {
                      << mat_a_dimension_ << ", the dim of mat_b is " << mat_b_dimension_;
   }

-  infer_strategy_mode_ = INDIVIDUAL_MODE;
   return SUCCESS;
 }

@@ -514,54 +513,6 @@ std::shared_ptr<Strategies> BatchMatMulInfo::GenerateBatchStrategies() {

 Status MatMulBase::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }

-Shapes MatMulBase::InferStrategyIndividualMode(const Shapes &in_strategy) {
-  // if the transpose_b is false:
-  // in_strategy: ((A, B, C, D), ()), inputs shape: ((a, b, c, d), (a, b, d, e)), return: ((A, B, C, D), (A, B, D, 1))
-  // in_strategy: ((), (A, B, D, E)), inputs shape: ((a, b, c, d), (a, b, d, e)), return: ((A, B, 1, D), (A, B, D, E))
-  // if the transpose_b is true:
-  // in_strategy: ((A, B, C, D), ()), inputs shape: ((a, b, c, d), (a, b, e, d)), return: ((A, B, C, D), (A, B, 1, D))
-  // in_strategy: ((), (A, B, E, D)), inputs shape: ((a, b, c, d), (a, b, e, d)), return: ((A, B, 1, D), (A, B, E, D))
-  if (in_strategy.size() != 2) {
-    MS_LOG(EXCEPTION) << name_ << ": The size of in strategy must be 2, but got " << in_strategy.size();
-  }
-
-  if (in_strategy[0].empty() && in_strategy[1].empty()) {
-    MS_LOG(EXCEPTION) << name_ << ": The in strategy is empty";
-  }
-
-  if (!in_strategy[0].empty() && !in_strategy[1].empty()) {
-    return in_strategy;
-  }
-
-  if (!in_strategy[0].empty() && in_strategy[0].size() != inputs_shape_[0].size()) {
-    MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy[0] is " << in_strategy[0].size()
-                      << ", but the size of the inputs_shape_[0] is " << inputs_shape_[0].size();
-  }
-
-  if (!in_strategy[1].empty() && in_strategy[1].size() != inputs_shape_[1].size()) {
-    MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy[1] is " << in_strategy[1].size()
-                      << ", but the size of the inputs_shape_[1] is " << inputs_shape_[1].size();
-  }
-
-  Shapes ret_strategy = InferStrategyBroadcastMode(in_strategy);
-  if (transpose_b_) {
-    if (in_strategy[0].empty()) {
-      ret_strategy[0][ret_strategy[0].size() - 2] = 1;
-    } else {
-      ret_strategy[1][ret_strategy[1].size() - 2] = 1;
-    }
-  } else {
-    if (in_strategy[0].empty()) {
-      ret_strategy[0][ret_strategy[0].size() - 2] = 1;
-      ret_strategy[0][ret_strategy[0].size() - 1] = ret_strategy[1][ret_strategy[1].size() - 2];
-    } else {
-      ret_strategy[1][ret_strategy[1].size() - 2] = ret_strategy[0][ret_strategy[0].size() - 1];
-      ret_strategy[1][ret_strategy[1].size() - 1] = 1;
-    }
-  }
-  return ret_strategy;
-}
-
 // PCL matmul
 ReplaceGraphPtr MatMul::replace_graph(const CNodePtr &cnode) {
   if (!candidate_flag_) {
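The removed MatMul rule was the only one that consulted an operator attribute (transpose_b_): the contracted dimension keeps its split while the derived operand's free dimension is set to 1. A simplified sketch for the equal-rank, no-broadcast case (hypothetical helper name, plain std::vector types; the real code additionally reused InferStrategyBroadcastMode for batch dimensions):

#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;
using Shapes = std::vector<Shape>;

// Derive the missing MatMul operand strategy from the given one. The
// contracted dimension d shares its split; the derived free dimension gets 1.
Shapes CompleteMatMulStrategy(const Shapes &in_strategy, bool transpose_b) {
  Shape a = in_strategy[0], b = in_strategy[1];
  if (!a.empty() && !b.empty()) return in_strategy;
  if (b.empty()) {
    b = a;                      // copy the batch dims
    size_t n = b.size();
    if (transpose_b) {          // b is (..., e, d): contracted dim already last
      b[n - 2] = 1;
    } else {                    // b is (..., d, e)
      b[n - 2] = a[n - 1];      // contracted dim d
      b[n - 1] = 1;
    }
  } else {
    a = b;
    size_t n = a.size();
    a[n - 2] = 1;                                   // free dim c
    a[n - 1] = transpose_b ? b[n - 1] : b[n - 2];   // contracted dim d
  }
  return {a, b};
}

int main() {
  // transpose_b = false: ((A, B, C, D), ()) -> ((A, B, C, D), (A, B, D, 1))
  Shapes full = CompleteMatMulStrategy({{2, 2, 4, 8}, {}}, false);
  assert(full[1] == Shape({2, 2, 8, 1}));
  // transpose_b = true: ((), (A, B, E, D)) -> ((A, B, 1, D), (A, B, E, D))
  full = CompleteMatMulStrategy({{}, {2, 2, 4, 8}}, true);
  assert(full[0] == Shape({2, 2, 1, 8}));
  return 0;
}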
@@ -48,7 +48,6 @@ class MatMulBase : public OperatorInfo {
   Status InferTensorMap() override;
   Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout);
   Status GetAttrs() override;
-  Shapes InferStrategyIndividualMode(const Shapes &in_strategy) override;

   bool candidate_flag_ = false;
   bool transpose_a_ = false;

@@ -47,7 +47,6 @@ Status OneHotInfo::GetAttrs() {
     MS_LOG(ERROR) << name_ << ": Axis " << axis_ << " is out of range[-1, 1].";
     return FAILED;
   }
-  infer_strategy_mode_ = INDIVIDUAL_MODE;
   return SUCCESS;
 }

@@ -248,24 +247,6 @@ std::shared_ptr<Strategies> OneHotInfo::GenerateBatchStrategies() {
   return std::make_shared<Strategies>(strategy_v);
 }

-Shapes OneHotInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
-  if (in_strategy.size() != 3) {
-    MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy must be 3, but got " << in_strategy.size();
-  }
-
-  if (in_strategy[0].empty()) {
-    Shape strategy = {stage_device_size_, 1};
-    Shape empty_strategy;
-    return Shapes({strategy, empty_strategy, empty_strategy});
-  }
-
-  if (in_strategy[0].size() != 1) {
-    MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy[0] must be 1, but got " << in_strategy[0].size();
-  }
-
-  return Shapes({{in_strategy[0][0], 1}, {}, {}});
-}
-
 REGISTER(OneHotInfo);
 }  // namespace parallel
 }  // namespace mindspore

@@ -49,7 +49,6 @@ class OneHotInfo : public OperatorInfo {
   Status InferDevMatrixShape() override;
   Status InferTensorMap() override;
   Status ExtractInputInfo();
-  Shapes InferStrategyIndividualMode(const Shapes &in_strategy) override;

  private:
   Status ComputeReplaceGraph(const CNodePtr &cnode);

@@ -2150,137 +2150,6 @@ float OperatorInfo::GetFloatAttr(const std::string &attr_name) {
   return attr_iter->second->cast<FP32ImmPtr>()->value();
 }

-// in_strategy: ((A, B, C, D), ()), return: ((A, B, C, D), (A, B, C, D))
-// in_strategy: ((), (A, B, C, D)), return: ((A, B, C, D), (A, B, C, D))
-Shapes OperatorInfo::InferStrategySameMode(const Shapes &in_strategy) const {
-  Shape value;
-  for (auto &ele : in_strategy) {
-    if (!ele.empty()) {
-      value = ele;
-      break;
-    }
-  }
-
-  return Shapes(in_strategy.size(), value);
-}
-
-// in_strategy: ((A, B, C, D), ()), inputs shape: ((a, b, c, d), (e, f, g)), return: ((A, B, C, D), (1, 1, 1))
-// in_strategy: ((), (E, F, G)), inputs shape: ((a, b, c, d), (e, f, g)), return: ((1, 1, 1, 1), (E, F, G))
-Shapes OperatorInfo::InferStrategyIndependentMode(const Shapes &in_strategy) {
-  if (in_strategy.size() != inputs_shape_.size()) {
-    MS_LOG(EXCEPTION)
-      << name_ << ": The size of strategies must be equal to the size of inputs shape, but the size of strategies is "
-      << in_strategy.size() << ", and the size of inputs shape is " << inputs_shape_.size();
-  }
-
-  Shapes ret;
-  for (size_t i = 0; i < in_strategy.size(); ++i) {
-    if (in_strategy[i].empty()) {
-      (void)ret.emplace_back(Shape(inputs_shape_[i].size(), 1));
-      continue;
-    }
-    (void)ret.emplace_back(in_strategy[i]);
-  }
-  return ret;
-}
-
-// in_strategy: ((A, B, C, D), ()), inputs shape: ((a, b, c, d), (a, b, c, d)), return: ((A, B, C, D), (A, B, C, D))
-// in_strategy: ((A, B, C, D), ()), inputs shape: ((a, b, c, d), (b, c, d)), return: ((A, B, C, D), (B, C, D))
-// in_strategy: ((), (B, C, D)), inputs shape: ((a, b, c, d), (b, c, d)), return: ((1, B, C, D), (B, C, D))
-// in_strategy: ((A, B, C, D), ()), inputs shape: ((a, b, c, d), (1, c, d)), return: ((A, B, C, D), (1, C, D))
-Shapes OperatorInfo::InferStrategyBroadcastMode(const Shapes &in_strategy) {
-  Shapes ret = InferStrategySameMode(in_strategy);
-  if (ret.size() != inputs_shape_.size()) {
-    MS_LOG(EXCEPTION)
-      << name_ << ": The size of strategies must be equal to the size of inputs shape, but the size of strategies is "
-      << ret.size() << ", and the size of inputs shape is " << inputs_shape_.size();
-  }
-
-  // handle the broadcast
-  // alignment length
-  for (size_t i = 0; i < ret.size(); ++i) {
-    size_t strategy_size = ret[i].size();
-    size_t shape_size = inputs_shape_[i].size();
-    size_t diff_len = strategy_size > shape_size ? strategy_size - shape_size : shape_size - strategy_size;
-    if (strategy_size > shape_size) {
-      // strategy is (A, B, C, D), and shape is (c, d) -> updated strategy is (C, D)
-      (void)ret[i].erase(ret[i].begin(), ret[i].begin() + static_cast<different_type>(diff_len));
-    } else if (strategy_size < shape_size) {
-      // strategy is (C, D), and shape is (a, b, c, d) -> updated strategy is (1, 1, C, D)
-      (void)ret[i].insert(ret[i].begin(), diff_len, 1);
-    }
-  }
-
-  // handle the 1 shape value
-  for (size_t i = 0; i < ret.size(); ++i) {
-    // strategy is (A, B, C, D), and shape is (1, b, c, d) -> updated strategy is (1, B, C, D)
-    for (size_t j = 0; j < ret[i].size(); ++j) {
-      if (inputs_shape_[i][j] == 1) {
-        ret[i][j] = 1;
-      }
-    }
-  }
-  return ret;
-}
-
-Shapes OperatorInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
-  MS_LOG(EXCEPTION) << name_ << ": The in strategy is " << in_strategy << ", need to override this function ";
-}
-
-Shapes OperatorInfo::GenerateFullStrategyBase(const Shapes &in_strategy) {
-  // there is no empty in the in_strategy
-  auto it = std::find_if(in_strategy.begin(), in_strategy.end(), [](const Shape &ele) { return ele.empty(); });
-  if (it == in_strategy.end()) {
-    MS_LOG(INFO) << name_ << ": There is no empty in the input strategy, return to itself: " << in_strategy;
-    return in_strategy;
-  }
-
-  // the in_strategy are all empty, generate the data parallel strategy
-  auto item = std::find_if(in_strategy.begin(), in_strategy.end(), [](const Shape &ele) { return !ele.empty(); });
-  if (item == in_strategy.end()) {
-    std::shared_ptr<Strategies> dp_strategy_ptr = GenerateBatchStrategies();
-    MS_EXCEPTION_IF_NULL(dp_strategy_ptr);
-    MS_LOG(INFO) << name_
-                 << ": The in strategy are all empty, generate the data parallel strategy: " << *dp_strategy_ptr;
-    return *dp_strategy_ptr;
-  }
-
-  // generate the full strategy from non empty strategy
-  Shapes ret;
-  switch (infer_strategy_mode_) {
-    case SAME_MODE:
-      ret = InferStrategySameMode(in_strategy);
-      break;
-    case BROADCAST_MODE:
-      ret = InferStrategyBroadcastMode(in_strategy);
-      break;
-    case INDEPENDENT_MODE:
-      ret = InferStrategyIndependentMode(in_strategy);
-      break;
-    case INDIVIDUAL_MODE:
-      ret = InferStrategyIndividualMode(in_strategy);
-      break;
-    case INVALID_MODE:
-    default:
-      MS_LOG(EXCEPTION) << name_ << ": The invalid mode for infer strategy";
-  }
-
-  MS_LOG(INFO) << name_ << ": The origin strategy is " << in_strategy << ", and the return strategy is " << ret;
-  return ret;
-}
-
-Shapes OperatorInfo::GenerateFullStrategy(const Shapes &in_strategy) {
-  if (InferAttrs() != SUCCESS) {
-    MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed";
-  }
-  Shapes ret = GenerateFullStrategyBase(in_strategy);
-  StrategyPtr strategy_ptr = NewStrategy(0, ret);
-  if (CheckStrategy(strategy_ptr) != SUCCESS) {
-    MS_LOG(EXCEPTION) << name_ << ": Invalid strategy";
-  }
-  return ret;
-}
-
 std::vector<ValuePtr> GetValueSequence(const ValuePtr &sequence) {
   MS_EXCEPTION_IF_NULL(sequence);
   std::vector<ValuePtr> ret;
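The four completion modes deleted here are small shape transformations. A standalone sketch of the three generic ones (SAME_MODE, INDEPENDENT_MODE, BROADCAST_MODE), under the same assumptions as the earlier sketches — plain std::vector types and hypothetical names; INDIVIDUAL_MODE was per-operator and is illustrated next to the operators above:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;
using Shapes = std::vector<Shape>;

// SAME_MODE: every input reuses the first non-empty strategy.
Shapes InferSame(const Shapes &in_strategy) {
  auto it = std::find_if(in_strategy.begin(), in_strategy.end(),
                         [](const Shape &s) { return !s.empty(); });
  if (it == in_strategy.end()) return in_strategy;  // caller guards this
  return Shapes(in_strategy.size(), *it);
}

// INDEPENDENT_MODE: an empty entry becomes an all-1 (unsharded) strategy.
Shapes InferIndependent(const Shapes &in_strategy, const Shapes &inputs_shape) {
  Shapes ret;
  for (size_t i = 0; i < in_strategy.size(); ++i) {
    ret.push_back(in_strategy[i].empty() ? Shape(inputs_shape[i].size(), 1) : in_strategy[i]);
  }
  return ret;
}

// BROADCAST_MODE: start from SAME_MODE, then align ranks (trim or left-pad
// with 1) and force a split of 1 on broadcast dimensions of size 1.
Shapes InferBroadcast(const Shapes &in_strategy, const Shapes &inputs_shape) {
  Shapes ret = InferSame(in_strategy);
  for (size_t i = 0; i < ret.size(); ++i) {
    size_t ss = ret[i].size(), hs = inputs_shape[i].size();
    if (ss > hs) {
      ret[i].erase(ret[i].begin(), ret[i].begin() + static_cast<long>(ss - hs));
    } else if (ss < hs) {
      ret[i].insert(ret[i].begin(), hs - ss, 1);
    }
    for (size_t j = 0; j < ret[i].size(); ++j) {
      if (inputs_shape[i][j] == 1) ret[i][j] = 1;
    }
  }
  return ret;
}

int main() {
  // broadcast: ((A, B, C, D), ()) with shapes ((a, b, c, d), (1, c, d)) -> ((A, B, C, D), (1, C, D))
  Shapes r = InferBroadcast({{2, 2, 4, 4}, {}}, {{8, 8, 16, 16}, {1, 16, 16}});
  assert(r[1] == Shape({1, 4, 4}));
  // independent: ((), (E, F, G)) with shapes ((a, b, c, d), (e, f, g)) -> ((1, 1, 1, 1), (E, F, G))
  r = InferIndependent({{}, {2, 2, 2}}, {{8, 8, 8, 8}, {4, 4, 4}});
  assert(r[0] == Shape({1, 1, 1, 1}));
  return 0;
}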
@@ -103,7 +103,6 @@ class OperatorInfo {
   const OperatorCostPtr &operator_cost() const { return operator_cost_; }
   void set_cost(const OperatorCostPtr &cost) { operator_cost_ = cost; }
   virtual Status SetCostUnderStrategy(const StrategyPtr &strategy) = 0;
-  Shapes GenerateFullStrategy(const Shapes &in_strategy);

   virtual std::shared_ptr<Strategies> GenerateBatchStrategies();
   virtual void ReComputeBatchSplitFlagList();

@@ -226,7 +225,6 @@ class OperatorInfo {
   // needed by rec_parser
   std::string type_;
   bool is_last_node_ = false;
-  InferStrategyMode infer_strategy_mode_ = INVALID_MODE;
   virtual Status CheckStrategy(const StrategyPtr &strategy) = 0;
   virtual Status InferTensorMap() = 0;
   virtual Status InferForwardCommunication() = 0;

@@ -236,11 +234,6 @@ class OperatorInfo {
   virtual Status InferTensorInfo();
   virtual void InferReplaceOps() {}
   virtual Status CheckOutputStrategy(const StrategyPtr &out_strategy);
-  virtual Shapes InferStrategyIndividualMode(const Shapes &in_strategy);
-  Shapes GenerateFullStrategyBase(const Shapes &in_strategy);
-  Shapes InferStrategySameMode(const Shapes &in_strategy) const;
-  Shapes InferStrategyBroadcastMode(const Shapes &in_strategy);
-  Shapes InferStrategyIndependentMode(const Shapes &in_strategy);
   Status CheckStrategyByVector(const Shapes &strategy, const Shapes &inputs_shape);
   Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape);
   void SetRepeatedCalcDevMatrix();

@@ -55,7 +55,6 @@ Status StackInfo::GetAttrs() {
     axis = axis + dim;
   }
   axis_ = LongToSize(axis);
-  infer_strategy_mode_ = SAME_MODE;
   return SUCCESS;
 }

@@ -89,7 +89,6 @@ Status PReLUInfo::GetAttrs() {
                   << outputs_shape_.size() << " is wrong.";
     return FAILED;
   }
-  infer_strategy_mode_ = INDIVIDUAL_MODE;
   return SUCCESS;
 }

@@ -109,34 +108,6 @@ std::vector<StrategyPtr> PReLUInfo::GenerateOpStrategies(int64_t stage_id) {

 Status PReLUInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }

-// in_strategy: ((A, B, C), ()), Shapes: ((a, b, c), (b)), return: ((A, B, C), (B))
-// in_strategy: ((A, B, C), ()), Shapes: ((a, b, c), (1)), return: ((A, B, C), (1))
-// in_strategy: ((), (B)), Shapes: ((a, b, c), (b)), return: ((1, B, 1), (B))
-Shapes PReLUInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
-  if (in_strategy.size() != 2) {
-    MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy must be 3, but got " << in_strategy.size();
-  }
-
-  if (!in_strategy[0].empty()) {
-    if (in_strategy[0].size() < 2) {
-      MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy[0] must be larger than 1, but got "
-                        << in_strategy[0].size();
-    }
-    if (inputs_shape_[1][0] > 1) {
-      return Shapes({in_strategy[0], {in_strategy[0][1]}});
-    } else {
-      return Shapes({in_strategy[0], {1}});
-    }
-  }
-
-  if (!in_strategy[1].empty()) {
-    Shape tmp(inputs_shape_[0].size(), 1);
-    tmp[1] = in_strategy[1][0];
-    return Shapes({tmp, in_strategy[1]});
-  }
-  MS_LOG(EXCEPTION) << name_ << ": The in_strategy[0] and in_strategy[1] are empty";
-}
-
 REGISTER(PReLUInfo);
 }  // namespace parallel
 }  // namespace mindspore

@@ -47,7 +47,6 @@ class PReLUInfo : public OperatorInfo {
   Status InferDevMatrixShape() override;
   Status InferTensorMap() override;
   Status GetAttrs() override;
-  Shapes InferStrategyIndividualMode(const Shapes &in_strategy) override;

  private:
   Dimensions input_strategy_;

@@ -29,7 +29,6 @@ Status RandomChoiceWithMaskInfo::GetAttrs() {
   if (attrs_.find(SEED2) != attrs_.end()) {
     seed2_ = GetValue<int64_t>(attrs_[SEED2]);
   }
-  infer_strategy_mode_ = INDEPENDENT_MODE;
   return SUCCESS;
 }

@@ -34,7 +34,6 @@ Status ROIAlignInfo::GetAttrs() {
     }
     (void)roi_align_attrs.emplace_back(std::make_pair(attr_key, attrs_[attr_key]));
   }
-  infer_strategy_mode_ = INDEPENDENT_MODE;
   return SUCCESS;
 }

@@ -169,27 +169,6 @@ void ScatterNdOpsInfo::ReComputeBatchSplitFlagList() {
   }
 }

-// in_strategy: ((A, B, C, D), (), ()), Shapes: ((a, b, c, d), (e, f), (e, f, c, d))
-// return: ((A, B, C, D), (1, 1), (1, 1, C, D))
-// return: throw exception
-Shapes ScatterNdOpsInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
-  if (in_strategy.size() != SIZE_THREE) {
-    MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy must be 3, but got " << in_strategy.size();
-  }
-
-  if (in_strategy[INDEX_ZERO].empty()) {
-    MS_LOG(EXCEPTION) << name_ << ": The in_strategy[0] must be not empty, you should set the first input strategy.";
-  }
-
-  Shape indices_strategy, updates_strategy;
-  indices_strategy = Shape(inputs_shape_[INDEX_ONE].size(), 1);
-  updates_strategy = Shape(inputs_shape_[INDEX_TWO].size(), 1);
-  for (size_t i = 0; i < in_strategy[0].size() - gather_dims_size_; ++i) {
-    updates_strategy[indices_strategy.size() - 1 + i] = in_strategy[INDEX_ZERO][i + gather_dims_size_];
-  }
-  return Shapes({in_strategy[INDEX_ZERO], indices_strategy, updates_strategy});
-}
-
 ReplaceGraphPtr ScatterNdOpsInfo::replace_graph(const CNodePtr &cnode) {
   if (ComputeReplaceGraph(cnode) != SUCCESS) {
     MS_LOG(EXCEPTION) << name_ << " replace graph failed";

@@ -58,7 +58,6 @@ class ScatterNdOpsInfo : public OperatorInfo {
   int64_t slice_size_ = 0;
   size_t gather_dims_size_ = 1;
   GenerateGraph gen_g_ = GenerateGraph(attrs_);
-  Shapes InferStrategyIndividualMode(const Shapes &in_strategy) override;
 };

 using ScatterNdOpsInfoPtr = std::shared_ptr<ScatterNdOpsInfo>;

@@ -176,46 +176,6 @@ std::vector<StrategyPtr> ScatterOpsInfo::GenerateOpStrategies(int64_t stage_id)
   return sp_vector;
 }

-// in_strategy: ((A, B, C, D), (), ()), Shapes: ((a, b, c, d), (e, f), (e, f, b, c, d))
-// return: ((A, B, C, D), (1, 1), (1, 1, B, C, D))
-// in_strategy: ((), (1, 1), (E, F, B, C, D)), Shapes: ((a, b, c, d), (e, f), (e, f, b, c, d))
-// return: ((A, B, C, D), (1, 1), (E, F, B, C, D))
-// in_strategy: ((), (), (E, F, B, C, D)), Shapes: ((a, b, c, d), (e, f), (e, f, b, c, d))
-// return: throw exception
-Shapes ScatterOpsInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
-  if (in_strategy.size() != 3) {
-    MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy must be 3, but got " << in_strategy.size();
-  }
-
-  if (in_strategy[1].empty() != in_strategy[2].empty()) {
-    MS_LOG(EXCEPTION)
-      << name_
-      << ": The in_strategy[1] and in_strategy[2] must be all empty or all non empty, but the in_strategy[1] is "
-      << in_strategy[1] << ", the in_strategy[2] is " << in_strategy[2];
-  }
-
-  Shape x_strategy, indices_strategy, updates_strategy;
-  indices_strategy = Shape(inputs_shape_[1].size(), 1);
-  if (!in_strategy[0].empty()) {
-    updates_strategy = in_strategy[0];
-    (void)updates_strategy.erase(updates_strategy.begin());
-    (void)updates_strategy.insert(updates_strategy.begin(), inputs_shape_[1].size(), 1);
-    return Shapes({in_strategy[0], indices_strategy, updates_strategy});
-  }
-
-  if (!in_strategy[2].empty()) {
-    if (std::accumulate(in_strategy[1].begin(), in_strategy[1].end(), 1, std::multiplies<int64_t>()) != 1) {
-      MS_LOG(EXCEPTION) << name_ << ": The in_strategy[1] must be fill with 1, but got " << in_strategy[1];
-    }
-    x_strategy = in_strategy[2];
-    (void)x_strategy.erase(x_strategy.begin(),
-                           x_strategy.begin() + static_cast<different_type>(inputs_shape_[1].size() - 1));
-    return Shapes({x_strategy, indices_strategy, in_strategy[2]});
-  }
-
-  MS_LOG(EXCEPTION) << name_ << ": The in_strategy[0], in_strategy[1] and in_strategy[2] are empty";
-}
-
 REGISTER(ScatterUpdateInfo);
 REGISTER(ScatterMaxInfo);
 REGISTER(ScatterMinInfo);

@@ -41,16 +41,12 @@ class ScatterOpsInfo : public OperatorInfo {
   void ReComputeBatchSplitFlagList() override;

  protected:
-  Status GetAttrs() override {
-    infer_strategy_mode_ = INDIVIDUAL_MODE;
-    return SUCCESS;
-  }
+  Status GetAttrs() override { return SUCCESS; }
   Status CheckStrategy(const StrategyPtr &strategy) override;
   Status InferMirrorOps() override { return SUCCESS; }  // the scatter_update only use in eval/predict
   Status InferForwardCommunication() override { return SUCCESS; }
   Status InferDevMatrixShape() override;
   Status InferTensorMap() override;
-  Shapes InferStrategyIndividualMode(const Shapes &in_strategy) override;
 };

 class ScatterUpdateInfo : public ScatterOpsInfo {

@@ -41,10 +41,7 @@ class SelectInfo : public OperatorInfo {
   void ReComputeBatchSplitFlagList() override;

  protected:
-  Status GetAttrs() override {
-    infer_strategy_mode_ = SAME_MODE;
-    return SUCCESS;
-  }
+  Status GetAttrs() override { return SUCCESS; }
   Status CheckStrategy(const StrategyPtr &strategy) override;
   Status InferForwardCommunication() override { return SUCCESS; }
   Status InferDevMatrixShape() override;

@@ -55,7 +55,6 @@ Status UniformRealInfo::GetAttrs() {
     MS_LOG(ERROR) << name_ << ": Seed2 must be greater or equal to zero, bug got " << seed2_;
     return FAILED;
   }
-  infer_strategy_mode_ = INDEPENDENT_MODE;
   return SUCCESS;
 }

@@ -54,7 +54,6 @@ Status UnsortedSegmentOpInfo::GetAttrs() {
     MS_LOG(ERROR) << name_ << ": the number of segments should be non negative value.";
     return FAILED;
   }
-  infer_strategy_mode_ = INDIVIDUAL_MODE;
   return SUCCESS;
 }

@@ -311,32 +310,6 @@ Status UnsortedSegmentMaxInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
   return SUCCESS;
 }

-// in_strategy: ((A, B, C), ()), Shapes: ((a, b, c), (a, b)), return: ((A, B, C), (A, B))
-// in_strategy: ((), (A, B)), Shapes: ((a, b, c), (a, b)), return: ((A, B, 1), (A, B))
-Shapes UnsortedSegmentOpInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
-  if (in_strategy.size() != 2) {
-    MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy must be 3, but got " << in_strategy.size();
-  }
-
-  if (!in_strategy[0].empty()) {
-    return Shapes({in_strategy[0], Shape(in_strategy[0].begin(), in_strategy[0].begin() + inputs_shape_[1].size())});
-  }
-
-  if (!in_strategy[1].empty()) {
-    Shape tmp_strategy = in_strategy[1];
-    if (inputs_shape_[0].size() < inputs_shape_[1].size()) {
-      MS_LOG(EXCEPTION) << name_
-                        << ": The size of inputs_shape[0] can not smaller than the size of inputs_shape[1], but the "
-                           "size of inputs_shape[0] is "
-                        << inputs_shape_[0].size() << ", the size of inputs_shape[1] is " << inputs_shape_[1].size();
-    }
-    size_t diff_len = inputs_shape_[0].size() - inputs_shape_[1].size();
-    (void)tmp_strategy.insert(tmp_strategy.end(), diff_len, 1);
-    return Shapes({tmp_strategy, in_strategy[1]});
-  }
-  MS_LOG(EXCEPTION) << name_ << ": The in_strategy[0] and in_strategy[1] are empty";
-}
-
 REGISTER(UnsortedSegmentSumInfo);
 REGISTER(UnsortedSegmentProdInfo);
 REGISTER(UnsortedSegmentMinInfo);

@@ -57,7 +57,6 @@ class UnsortedSegmentOpInfo : public OperatorInfo {
   Status InferDevMatrixShape() override;
   Status InferTensorMap() override;
   Status GetAttrs() override;
-  Shapes InferStrategyIndividualMode(const Shapes &in_strategy) override;
 };

 class UnsortedSegmentSumInfo : public UnsortedSegmentOpInfo {

@@ -31,14 +31,22 @@

 namespace mindspore {
 namespace parallel {
+namespace {
+using ExpectFunc = std::function<bool(const CNodePtr &)>;
+}
 static void GenerateDefaultStrategy(const ValueNodePtr &axes, const std::vector<AnfNodePtr> &nodes,
-                                    std::vector<std::vector<int64_t>> *default_strategy) {
+                                    const size_t device_num, std::vector<std::vector<int64_t>> *default_strategy) {
   auto strategies = axes->value()->cast<ValueTuplePtr>()->value();
   size_t i = 0;
   for (auto &strategy : strategies) {
     auto node = nodes[i];
     if (strategy->isa<None>()) {
-      (void)default_strategy->emplace_back(Shape());
+      auto node_size = common::AnfAlgo::GetOutputInferShape(node, 0).size();
+      std::vector<int64_t> current_d_strategy(node_size, 1);
+      if (!current_d_strategy.empty()) {
+        current_d_strategy[0] = SizeToLong(device_num);
+      }
+      (void)default_strategy->emplace_back(std::move(current_d_strategy));
     } else {
      (void)default_strategy->emplace_back(GetValue<Shape>(strategy));
    }
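The signature change above also changes the default for a None entry in in_strategy: instead of an empty strategy (previously completed later inside the operator), the input now directly receives a data-parallel strategy. A short behavioral sketch (hypothetical helper name, plain std::vector types):

#include <cassert>
#include <cstdint>
#include <vector>

// When an entry of in_strategy is None, the input gets all 1s except the
// first axis, which takes the full device count.
std::vector<int64_t> DefaultStrategyForInput(size_t rank, size_t device_num) {
  std::vector<int64_t> strategy(rank, 1);
  if (!strategy.empty()) {
    strategy[0] = static_cast<int64_t>(device_num);
  }
  return strategy;
}

int main() {
  // A rank-4 input on 8 devices defaults to (8, 1, 1, 1).
  assert(DefaultStrategyForInput(4, 8) == (std::vector<int64_t>{8, 1, 1, 1}));
  // A scalar input (rank 0) keeps an empty strategy.
  assert(DefaultStrategyForInput(0, 8).empty());
  return 0;
}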
@@ -46,26 +54,13 @@ static void GenerateDefaultStrategy(const ValueNodePtr &axes, const std::vector<
   }
 }

-// Generate strategies like ((), (), ..., ())
-Shapes GenerateEmptyStrategies(const CNodePtr &cnode) {
-  auto shape_list = ExtractShape(cnode);
-  if (shape_list.empty()) {
-    MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << " failed to extract shape.";
-  }
-  return Shapes(shape_list[0].size(), Shape());
-}
-
 static bool CheckOneDimensionalIntTuple(const ValuePtr &value_ptr) {
   if (!value_ptr->isa<ValueTuple>()) {
     return false;
   }
   auto elements = value_ptr->cast<ValueTuplePtr>()->value();
-  for (auto &element : elements) {
-    if (!element->isa<Int64Imm>()) {
-      return false;
-    }
-  }
-  return true;
+  return std::all_of(elements.begin(), elements.end(),
+                     [](const ValuePtr &element) { return element->isa<Int64Imm>(); });
 }

 static bool CheckLayout(const ValueNodePtr &axes, bool *need_default_strategy, size_t *axes_size) {

@@ -83,12 +78,6 @@ static bool CheckLayout(const ValueNodePtr &axes, bool *need_default_strategy, s
   return true;
 }

-static Shapes GenerateFullStrategy(const Shapes &current_strategy, const CNodePtr &cnode) {
-  OperatorInfoPtr op_info = CreateOperatorInfo(cnode);
-  MS_EXCEPTION_IF_NULL(op_info);
-  return op_info->GenerateFullStrategy(current_strategy);
-}
-
 static void GetInputNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *input_nodes) {
   auto parameters = func_graph->parameters();
   for (auto &parameter : parameters) {

@@ -99,7 +88,7 @@ static void GetInputNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr
   }
 }

-static bool CheckDeviceNum(const std::vector<std::vector<int64_t>> &strategies, const int64_t &device_num) {
+static bool CheckDeviceNum(const std::vector<std::vector<int64_t>> &strategies, const int64_t device_num) {
   for (size_t i = 0; i < strategies.size(); ++i) {
     auto strategy = strategies[i];
     int64_t required_num = 1;

@@ -119,34 +108,6 @@ static bool CheckDeviceNum(const std::vector<std::vector<int64_t>> &strategies,
   return true;
 }

-// Generate strategy for cnode by input_strategy.
-// For the i-th input:
-// 1. If it is specified in input_strategy, the strategy in input_strategy is used;
-// 2. Otherwise, its strategy is assigned as ()
-static Shapes GenerateDefaultStrategiesForCNode(const CNodePtr &cnode, const Shapes &input_strategy) {
-  auto current_inputs = cnode->inputs();
-  Shapes elements;
-  for (size_t i = 1; i < current_inputs.size(); ++i) {
-    auto current_input = current_inputs[i];
-    if (current_input->isa<ValueNode>()) {
-      auto current_value = current_input->cast<ValueNodePtr>()->value();
-      if (!current_value->isa<mindspore::tensor::Tensor>()) {
-        continue;
-      }
-    }
-    if (IsPrimitiveCNode(current_input, prim::kPrimTupleGetItem)) {
-      auto tuple_getitem_cnode = current_input->cast<CNodePtr>();
-      auto tuple_index = tuple_getitem_cnode->input(2);
-      auto value_node = tuple_index->cast<ValueNodePtr>();
-      auto index = GetValue<int64_t>(value_node->value());
-      elements.push_back(input_strategy[index]);
-    } else {
-      (void)elements.emplace_back(Shape());
-    }
-  }
-  return elements;
-}
-
 static ValueTuplePtr ShapesToValueTuplePtr(const Shapes &shapes) {
   std::vector<ValuePtr> value_list;
   (void)std::transform(shapes.begin(), shapes.end(), std::back_inserter(value_list),

@@ -162,23 +123,22 @@ static Shapes ValueTuplePtrToShapes(const ValueTuplePtr &value_tuple_ptr) {
   return shapes;
 }

-AnfNodeIndexSet FindAnfNodeIndexSetToInsertStrategy(const FuncGraphPtr &func_graph, const AnfNodeIndexSet &node_users) {
+AnfNodeIndexSet FindAnfNodeIndexSetToInsertStrategy(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                    const ExpectFunc &filter_func) {
   FuncGraphManagerPtr manager = func_graph->manager();
   AnfNodeIndexSet ret_set;
-  std::queue<std::pair<AnfNodePtr, int>> bfs_list;
+  auto node_users = manager->node_users()[node];
+  std::queue<std::pair<AnfNodePtr, int>> bfs_queuq;
   (void)std::for_each(node_users.begin(), node_users.end(),
-                      [&bfs_list](const std::pair<AnfNodePtr, int> &user) { bfs_list.push(user); });
-
-  while (!bfs_list.empty()) {
-    auto user = bfs_list.front();
-    bfs_list.pop();
-    CNodePtr cnode = user.first->cast<CNodePtr>();
-    // If the cnode is not a splittable operator, apply strategy to the next cnode
-    if (!IsSplittableOperator(GetPrimName(cnode)) || IsPrimitiveCNode(cnode, prim::kPrimVirtualDataset) ||
-        IsPrimitiveCNode(cnode, prim::kPrimCast) || IsPrimitiveCNode(cnode, prim::kPrimReshape)) {
+                      [&bfs_queuq](const std::pair<AnfNodePtr, int> &user) { bfs_queuq.push(user); });
+  while (!bfs_queuq.empty()) {
+    auto user = bfs_queuq.front();
+    bfs_queuq.pop();
+    auto cnode = user.first->cast<CNodePtr>();
+    if (!filter_func(cnode)) {
       auto tmp_users = manager->node_users()[cnode];
       (void)std::for_each(tmp_users.begin(), tmp_users.end(),
-                          [&bfs_list](const std::pair<AnfNodePtr, int> &user) { bfs_list.push(user); });
+                          [&bfs_queuq](const std::pair<AnfNodePtr, int> &user) { bfs_queuq.push(user); });
       continue;
     }
     ret_set.insert(user);
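The rewrite above generalizes the traversal: instead of a hard-coded "not splittable / VirtualDataset / Cast / Reshape" test, the caller passes an ExpectFunc that decides where to stop, and the function now looks up node_users itself from the node. A toy sketch of that BFS with strings in place of AnfNodePtr (all names here are hypothetical, not the ANF API):

#include <cassert>
#include <functional>
#include <map>
#include <queue>
#include <set>
#include <string>
#include <vector>

// Toy user-graph: node name -> names of nodes that consume its output.
using Users = std::map<std::string, std::vector<std::string>>;

// BFS from `node` through its users; collect the first users for which
// `filter` holds, and keep walking through users that do not match.
std::set<std::string> FindNodesToInsertStrategy(const Users &users, const std::string &node,
                                                const std::function<bool(const std::string &)> &filter) {
  std::set<std::string> ret;
  std::queue<std::string> bfs;
  for (const auto &u : users.at(node)) bfs.push(u);
  while (!bfs.empty()) {
    std::string cur = bfs.front();
    bfs.pop();
    if (!filter(cur)) {  // pass-through node: continue to its users
      auto it = users.find(cur);
      if (it != users.end()) {
        for (const auto &u : it->second) bfs.push(u);
      }
      continue;
    }
    ret.insert(cur);
  }
  return ret;
}

int main() {
  // param -> Cast -> Load; the Cast is walked through, the Load is collected.
  Users users = {{"param", {"Cast"}}, {"Cast", {"Load"}}, {"Load", {}}};
  auto found = FindNodesToInsertStrategy(users, "param",
                                         [](const std::string &n) { return n == "Load"; });
  assert(found.count("Load") == 1);
  return 0;
}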
@ -186,6 +146,25 @@ AnfNodeIndexSet FindAnfNodeIndexSetToInsertStrategy(const FuncGraphPtr &func_gra
|
|||
return ret_set;
|
||||
}
|
||||
|
||||
bool IsSettingStrategyByInsertIdentity(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
const std::string ¶m_name) {
|
||||
FuncGraphManagerPtr manager = func_graph->manager();
|
||||
auto node_users = manager->node_users()[cnode];
|
||||
for (const auto &user : node_users) {
|
||||
auto user_node = user.first;
|
||||
if (IsPrimitiveCNode(user_node, prim::kPrimIdentity)) {
|
||||
auto attrs = GetCNodePrimitive(user_node)->attrs();
|
||||
if (StrategyFound(attrs)) {
|
||||
auto origin_strategies = ValueTuplePtrToShapes(attrs[parallel::IN_STRATEGY]->cast<ValueTuplePtr>());
|
||||
MS_LOG(WARNING) << "For " << param_name << ", its strategy has been set to " << origin_strategies.at(0)
|
||||
<< ", the relevant settings in input_strategy will be ignored";
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// New a primitive for cnode and set in_strategy to it.
|
||||
void SetStrategyToCNode(const CNodePtr &cnode, const Shapes &strategies) {
|
||||
auto strategy = ShapesToValueTuplePtr(strategies);
|
||||
|
@ -193,7 +172,7 @@ void SetStrategyToCNode(const CNodePtr &cnode, const Shapes &strategies) {
|
|||
MS_EXCEPTION_IF_NULL(prim);
|
||||
PrimitivePtr new_prim;
|
||||
if (prim->isa<PrimitivePy>()) {
|
||||
PrimitivePyPtr prim_py = prim->cast<PrimitivePyPtr>();
|
||||
auto prim_py = prim->cast<PrimitivePyPtr>();
|
||||
MS_EXCEPTION_IF_NULL(prim_py);
|
||||
new_prim = std::make_shared<PrimitivePy>(*prim_py);
|
||||
} else {
|
||||
|
@ -205,13 +184,12 @@ void SetStrategyToCNode(const CNodePtr &cnode, const Shapes &strategies) {
|
|||
|
||||
ValuePtr new_prim_value = MakeValue(new_prim);
|
||||
ValueNodePtr new_prim_value_node = NewValueNode(new_prim_value);
|
||||
AnfNodePtr new_prim_anf_node = new_prim_value_node->cast<AnfNodePtr>();
|
||||
auto new_prim_anf_node = new_prim_value_node->cast<AnfNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(new_prim_anf_node);
|
||||
cnode->set_input(0, new_prim_anf_node);
|
||||
}
|
||||
|
||||
static std::set<CNodePtr> SetInputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &in_strategy,
|
||||
const int64_t &device_num) {
|
||||
void SetInputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &in_strategy, const int64_t device_num) {
|
||||
auto in_strategy_tuple = in_strategy->cast<ValueNodePtr>();
|
||||
bool need_default_strategy = false;
|
||||
size_t in_strategy_size = 0;
|
||||
|
@ -227,14 +205,14 @@ static std::set<CNodePtr> SetInputLayout(const FuncGraphPtr &func_graph, const A
|
|||
}
|
||||
std::vector<std::vector<int64_t>> input_strategy;
|
||||
if (need_default_strategy) {
|
||||
GenerateDefaultStrategy(in_strategy_tuple, input_nodes, &input_strategy);
|
||||
GenerateDefaultStrategy(in_strategy_tuple, input_nodes, device_num, &input_strategy);
|
||||
} else {
|
||||
input_strategy = GetValue<std::vector<std::vector<int64_t>>>(in_strategy_tuple->value());
|
||||
}
|
||||
if (!CheckDeviceNum(input_strategy, device_num)) {
|
||||
MS_LOG(EXCEPTION) << "check device number failed";
|
||||
}
|
||||
std::set<CNodePtr> concerned_nodes;
|
||||
|
||||
FuncGraphManagerPtr manager = func_graph->manager();
|
||||
auto parameters = func_graph->parameters();
|
||||
for (size_t i = 0; i < parameters.size(); ++i) {
|
||||
|
@ -250,105 +228,76 @@ static std::set<CNodePtr> SetInputLayout(const FuncGraphPtr &func_graph, const A
|
|||
<< " is not equal to in_strategy dimension: " << input_strategy[i].size() << " at index " << i;
|
||||
}
|
||||
AnfNodeIndexSet param_sub_set = manager->node_users()[parameter];
|
||||
auto to_insert_nodes_set = FindAnfNodeIndexSetToInsertStrategy(func_graph, param_sub_set);
|
||||
auto to_insert_nodes_set = FindAnfNodeIndexSetToInsertStrategy(
|
||||
func_graph, parameter, [](const CNodePtr &cnode) { return IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem); });
|
||||
if (to_insert_nodes_set.empty()) {
|
||||
MS_LOG(EXCEPTION) << "For input: \"" << parameter->fullname_with_scope()
|
||||
<< "\", failed to find node to insert strategy.";
|
||||
}
|
||||
for (auto &node : to_insert_nodes_set) {
|
||||
CNodePtr param_cnode = node.first->cast<CNodePtr>();
|
||||
auto param_attrs = GetCNodePrimitive(param_cnode)->attrs();
|
||||
if (StrategyFound(param_attrs)) {
|
||||
auto origin_strategies = ValueTuplePtrToShapes(param_attrs[parallel::IN_STRATEGY]->cast<ValueTuplePtr>());
|
||||
MS_LOG(WARNING) << "For " << param_cnode->fullname_with_scope() << ", its in_strategy has been set to "
|
||||
<< origin_strategies << ", the relevant settings in input_strategy will be ignored";
|
||||
auto tuple_get_item_cnode = node.first->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_get_item_cnode);
|
||||
if (IsSettingStrategyByInsertIdentity(func_graph, tuple_get_item_cnode, parameter->fullname_with_scope())) {
|
||||
continue;
|
||||
}
|
||||
(void)concerned_nodes.insert(param_cnode);
|
||||
|
||||
// Setting strategy by insert identity.
|
||||
// e.g TupleGetItem(parameter, index) -> identity{in_strategy=[input_strategy[index]}
|
||||
auto identity_cnode = func_graph->NewCNode({NewValueNode(prim::kPrimIdentity), tuple_get_item_cnode});
|
||||
auto tuple_get_item_cnode_abstract = tuple_get_item_cnode->abstract();
|
||||
MS_EXCEPTION_IF_NULL(tuple_get_item_cnode_abstract);
|
||||
identity_cnode->set_abstract(tuple_get_item_cnode_abstract->Clone());
|
||||
manager->Replace(tuple_get_item_cnode, identity_cnode);
|
||||
|
||||
// Get corresponding param_layout
|
||||
auto tuple_index = tuple_get_item_cnode->input(2);
|
||||
auto value_node = tuple_index->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
auto index = GetValue<int64_t>(value_node->value());
|
||||
Shapes current_strategies = {input_strategy[index]};
|
||||
SetStrategyToCNode(identity_cnode, current_strategies);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &cnode : concerned_nodes) {
|
||||
Shapes ret_strategy = GenerateDefaultStrategiesForCNode(cnode, input_strategy);
|
||||
SetStrategyToCNode(cnode, ret_strategy);
|
||||
}
|
||||
return concerned_nodes;
|
||||
}
|
||||
|
||||
static std::set<CNodePtr> SetParameterLayout(const FuncGraphPtr &root, const FuncGraphPtr &func_graph,
                                             const std::set<CNodePtr> &input_concerned_node) {
void SetParameterLayout(const FuncGraphPtr &root, const FuncGraphPtr &func_graph) {
  FuncGraphManagerPtr manager = func_graph->manager();
  auto root_parameters = root->parameters();
  std::set<CNodePtr> concerned_cnode;
  for (auto param : root_parameters) {
  for (const auto &param : root_parameters) {
    auto parameter = param->cast<ParameterPtr>();
    auto param_info = parameter->param_info();
    if (param_info == nullptr || param_info->param_strategy().empty()) {
      // Do not set param_strategy, skip it.
      continue;
    }
    auto param_strategy = parameter->param_info()->param_strategy();
    auto param_name = parameter->param_info()->name();
    auto param_strategy = param_info->param_strategy();
    auto param_name = param_info->name();
    AnfNodeIndexSet users = manager->node_users()[parameter];
    auto to_insert_nodes_set = FindAnfNodeIndexSetToInsertStrategy(func_graph, users);
    for (auto user : to_insert_nodes_set) {
      CNodePtr target_cnode = user.first->cast<CNodePtr>();
      Shapes current_strategies;
      if (input_concerned_node.find(target_cnode) == input_concerned_node.end()) {
        // If target_cnode does not involve the inputs, insert an identity between Load and target_cnode,
        // and set the layout on the identity,
        // e.g. Load(param) -> identity{in_strategy} -> target_cnode
        auto pre_cnode = target_cnode->input(user.second)->cast<CNodePtr>();
        MS_EXCEPTION_IF_NULL(pre_cnode);
        if (IsPrimitiveCNode(pre_cnode, prim::kPrimCast)) {
          pre_cnode = pre_cnode->inputs().at(kIndex1)->cast<CNodePtr>();
        }
        if (!IsPrimitiveCNode(pre_cnode, prim::kPrimLoad)) {
          MS_LOG(EXCEPTION) << "The operator type of the " << user.second << "-th input in "
                            << target_cnode->fullname_with_scope() << " must be 'Load', but got "
                            << GetCNodePrimitive(pre_cnode)->ToString();
        }
        auto identity_cnode = func_graph->NewCNode({NewValueNode(prim::kPrimIdentity), pre_cnode});
        auto pre_cnode_abstract = pre_cnode->abstract();
        MS_EXCEPTION_IF_NULL(pre_cnode_abstract);
        identity_cnode->set_abstract(pre_cnode_abstract->Clone());
        manager->Replace(pre_cnode, identity_cnode);
        target_cnode = identity_cnode;
        current_strategies = {param_strategy};
      } else {
        // Setting the layout into target_cnode directly.
        PrimitivePtr prim = GetCNodePrimitive(target_cnode);
        MS_EXCEPTION_IF_NULL(prim);
        auto attrs = prim->attrs();
        if (StrategyFound(attrs)) {
          current_strategies = ValueTuplePtrToShapes(attrs[parallel::IN_STRATEGY]->cast<ValueTuplePtr>());
        } else {
          current_strategies = GenerateEmptyStrategies(target_cnode);
        }
        current_strategies[user.second - 1] = param_strategy;
        (void)concerned_cnode.insert(target_cnode);
    auto to_insert_nodes_set = FindAnfNodeIndexSetToInsertStrategy(
      func_graph, parameter, [](const CNodePtr &cnode) { return IsPrimitiveCNode(cnode, prim::kPrimLoad); });
    for (const auto &user : to_insert_nodes_set) {
      auto load_cnode = user.first->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(load_cnode);
      if (IsSettingStrategyByInsertIdentity(func_graph, load_cnode, param_name)) {
        continue;
      }
      SetStrategyToCNode(target_cnode, current_strategies);
      MS_LOG(DEBUG) << "The layout of \"" << param_name << "\" has been set to the " << user.second << "-th input of "
                    << target_cnode->fullname_with_scope() << "'s in_strategy. Current strategies is "
                    << current_strategies;
    }
  }
  return concerned_cnode;
}

void CompleteConcernedCNodeStrategies(std::set<CNodePtr> concerned_cnode) {
  for (auto cnode : concerned_cnode) {
    PrimitivePtr prim = GetCNodePrimitive(cnode);
    MS_EXCEPTION_IF_NULL(prim);
    auto attrs = prim->attrs();
    Shapes current_strategies = ValueTuplePtrToShapes(attrs[parallel::IN_STRATEGY]->cast<ValueTuplePtr>());
    Shapes full_strategies = GenerateFullStrategy(current_strategies, cnode);
    attrs[parallel::IN_STRATEGY] = ShapesToValueTuplePtr(full_strategies);
    (void)prim->SetAttrs(attrs);
    MS_LOG(INFO) << cnode->fullname_with_scope() << ": Completion strategies success. " << current_strategies << " -> "
                 << full_strategies << "(origin_strategies -> completion_strategies)";
      // Setting param_layout by inserting an identity, e.g. Load(param) -> identity{in_strategy=[param_layout]}
      auto identity_cnode = func_graph->NewCNode({NewValueNode(prim::kPrimIdentity), load_cnode});
      auto load_cnode_abstract = load_cnode->abstract();
      MS_EXCEPTION_IF_NULL(load_cnode_abstract);
      identity_cnode->set_abstract(load_cnode_abstract->Clone());
      manager->Replace(load_cnode, identity_cnode);
      Shapes current_strategies = {param_strategy};
      SetStrategyToCNode(identity_cnode, current_strategies);
      MS_LOG(DEBUG) << "The layout of \"" << param_name << "\" has been set to "
                    << identity_cnode->fullname_with_scope() << ". Current strategies is " << current_strategies;
    }
  }
}

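Conceptually, the completion step reads a possibly partial in_strategy off the primitive's attributes, fills it out, and writes it back. A rough standalone Python approximation, assuming only the simplest fallback (an unspecified input gets an all-ones strategy of its own rank); the real `GenerateFullStrategy` is per-operator and smarter, as the unit tests further down show:

```python
def complete_strategies(in_strategy, input_shapes):
    """Fill every empty per-input strategy with all ones (no split)."""
    return [
        stra if stra else [1] * len(shape)
        for stra, shape in zip(in_strategy, input_shapes)
    ]

# ((2, 4, 4), ()) with shapes ((32, 64, 96), (32, 64, 96)) -> ((2, 4, 4), (1, 1, 1))
print(complete_strategies([[2, 4, 4], []], [[32, 64, 96], [32, 64, 96]]))
```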
static bool SetStrategyForShard(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
                                const int64_t &device_num) {
                                const int64_t device_num) {
  constexpr size_t kShardFnIndex = 1;
  constexpr size_t kShardInStrategyIndex = 2;
  for (auto &node : all_nodes) {

@ -366,13 +315,8 @@ static bool SetStrategyForShard(const FuncGraphPtr &root, const std::vector<AnfN
      if (HasNestedMetaFg(func_graph)) {
        return false;
      }
      std::set<CNodePtr> concerned_cnode;
      auto input_concerned_cnode = SetInputLayout(func_graph, in_strategy, device_num);
      auto parameter_concerned_cnode = SetParameterLayout(root, func_graph, input_concerned_cnode);
      (void)std::set_union(input_concerned_cnode.begin(), input_concerned_cnode.end(),
                           parameter_concerned_cnode.begin(), parameter_concerned_cnode.end(),
                           std::inserter(concerned_cnode, concerned_cnode.end()));
      CompleteConcernedCNodeStrategies(concerned_cnode);
      SetInputLayout(func_graph, in_strategy, device_num);
      SetParameterLayout(root, func_graph);
      return true;
    }
  }

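The reworked `SetStrategyForShard` collects the CNodes touched by `SetInputLayout` and `SetParameterLayout`, unions the two sets, and completes their strategies in one pass. A hedged Python outline of that ordering (the callables stand in for their C++ namesakes):

```python
def set_strategy_for_shard(set_input_layout, set_parameter_layout, complete):
    """Sketch of the new pass ordering; each callable mirrors its C++ namesake."""
    input_nodes = set_input_layout()        # returns the set of concerned CNodes
    param_nodes = set_parameter_layout()    # returns the set of concerned CNodes
    for node in input_nodes | param_nodes:  # the std::set_union equivalent
        complete(node)                      # CompleteConcernedCNodeStrategies

set_strategy_for_shard(lambda: {"matmul"}, lambda: {"matmul", "add"}, print)
```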
@ -199,6 +199,10 @@ void ExecuteActionForMindRT(const ResourcePtr &resource) {
void ModifyOutputNode(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  const auto &used_forward_nodes = func_graph->used_forward_nodes();
  std::vector<AnfNodePtr> used_forward_nodes_sorted_list(used_forward_nodes.begin(), used_forward_nodes.end());
  std::sort(
    used_forward_nodes_sorted_list.begin(), used_forward_nodes_sorted_list.end(),
    [](const AnfNodePtr &a, const AnfNodePtr &b) { return a->fullname_with_scope() < b->fullname_with_scope(); });

  // Get original output node and abstract
  auto original_output_node = func_graph->output();

@ -209,7 +213,7 @@ void ModifyOutputNode(const FuncGraphPtr &func_graph) {
  // Create a new make tuple node to hold all forward used nodes.
  abstract::AbstractBasePtrList added_abs_list;
  std::vector<AnfNodePtr> added_node_list{NewValueNode(prim::kPrimMakeTuple)};
  std::for_each(used_forward_nodes.begin(), used_forward_nodes.end(),
  std::for_each(used_forward_nodes_sorted_list.begin(), used_forward_nodes_sorted_list.end(),
                [&added_abs_list, &added_node_list](const AnfNodePtr &node) {
                  MS_EXCEPTION_IF_NULL(node);
                  added_node_list.push_back(node);

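`used_forward_nodes` is an unordered collection, so iterating it directly makes the appended outputs non-deterministic; sorting by `fullname_with_scope()` fixes the order across runs. The same idea in plain Python:

```python
# Sorting by a stable, unique key gives a reproducible traversal of a set.
used_forward_nodes = {"Default/network/Add-op7", "Default/network/MatMul-op3"}
for name in sorted(used_forward_nodes):  # analogous to sorting by fullname_with_scope()
    print(name)
```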
@ -374,7 +374,7 @@ def train_feed(num_classes, expect_out):
    model.train(2, dataset, dataset_sink_mode=False, callbacks=parallel_callback)
    loss_value = np.array(parallel_callback.loss_list)
    print(loss_value)
    assert np.allclose(loss_value, expect_out, 0.0001, 0.0001)
    assert np.allclose(loss_value, expect_out, 0.001, 0.001)


def test_train_feed_ascend():

@ -391,7 +391,7 @@ def test_train_feed_ascend():
                                      dataset_strategy="data_parallel")
    np.random.seed(42)
    set_seed(42)
    train_feed(num_classes=65536, expect_out=[11.32993, 10.7269535])
    train_feed(num_classes=65536, expect_out=[11.275127, 8.720392])


def test_train_feed_gpu():

@ -406,6 +406,6 @@ def test_train_feed_gpu():
    context.set_auto_parallel_context(gradients_mean=True, parallel_mode=ParallelMode.AUTO_PARALLEL,
                                      search_mode="sharding_propagation", device_num=8,
                                      dataset_strategy="data_parallel")
    np.random.seed(42)
    set_seed(42)
    train_feed(num_classes=65536, expect_out=[53.35976, 54.689503])
    np.random.seed(1)
    set_seed(1)
    train_feed(num_classes=65536, expect_out=[53.538628, 54.83031])

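The tests switch from `np.random.seed` to MindSpore's `set_seed`, which also seeds the framework's own weight initialization and random operators, so the expected loss values change accordingly. A typical seeding preamble (assuming the usual import path; the test file's actual imports are not shown in this hunk):

```python
import numpy as np
from mindspore import set_seed

set_seed(42)        # seeds MindSpore parameter init and random ops
np.random.seed(42)  # seeds NumPy-side data generation, if any
```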
@ -1,112 +0,0 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <string>
#include <list>
#include <vector>
#include "common/common_test.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/ops_info/arithmetic_info.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/step_parallel.h"

namespace mindspore {
namespace parallel {

class AddInfo;
using AddInfoPtr = std::shared_ptr<AddInfo>;
AddInfoPtr add, add1;

class TestInferStrategyBroadcastMode : public UT::Common {
 public:
  TestInferStrategyBroadcastMode() {}
  void SetUp();
  void TearDown() {}
};

void TestInferStrategyBroadcastMode::SetUp() {
  RankList dev_list;

  for (int32_t i = 0; i < 34; i++) {
    dev_list.push_back(i);
  }

  RankList stage_map;
  stage_map.push_back(32);
  stage_map.push_back(2);

  int32_t local_dev = 0;

  // create a new g_device_manager
  g_device_manager = std::make_shared<DeviceManager>();
  g_device_manager->Init(dev_list, local_dev, stage_map, "hccl");

  mindspore::HashMap<std::string, ValuePtr> attr;

  Shapes inputs_shape = {{32, 64, 96}, {32, 64, 96}};
  Shapes outputs_shape = {{32, 64, 96}};
  add = std::make_shared<AddInfo>("tensoradd_info", inputs_shape, outputs_shape, attr);

  Shapes inputs_shape1 = {{16, 32, 48}, {32, 48}};
  Shapes outputs_shape1 = {{16, 32, 48}};
  add1 = std::make_shared<AddInfo>("tensoradd_info", inputs_shape1, outputs_shape1, attr);
}

/// Feature: infer strategy for broadcast mode
/// Description: the in strategy is {{2, 4, 4}, {}}, the in shapes is {{32, 64, 96}, {32, 64, 96}}
/// Expectation: the return strategy is {{2, 4, 4}, {2, 4, 4}}
TEST_F(TestInferStrategyBroadcastMode, GenerateFullStrategy1) {
  Strategies in_strategy = {{2, 4, 4}, {}};
  Strategies ret = add->GenerateFullStrategy(in_strategy);

  Strategies expect = {{2, 4, 4}, {2, 4, 4}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for broadcast mode
/// Description: the in strategy is {{}, {2, 4, 4}}, the in shapes is {{32, 64, 96}, {32, 64, 96}}
/// Expectation: the return strategy is {{2, 4, 4}, {2, 4, 4}}
TEST_F(TestInferStrategyBroadcastMode, GenerateFullStrategy2) {
  Strategies in_strategy = {{}, {2, 4, 4}};
  Strategies ret = add->GenerateFullStrategy(in_strategy);

  Strategies expect = {{2, 4, 4}, {2, 4, 4}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for broadcast mode
/// Description: the in strategy is {{2, 4, 4}, {}}, the in shapes is {{16, 32, 48}, {32, 48}}
/// Expectation: the return strategy is {{2, 4, 4}, {4, 4}}
TEST_F(TestInferStrategyBroadcastMode, GenerateFullStrategy3) {
  Strategies in_strategy = {{2, 4, 4}, {}};
  Strategies ret = add1->GenerateFullStrategy(in_strategy);

  Strategies expect = {{2, 4, 4}, {4, 4}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for broadcast mode
/// Description: the in strategy is {{}, {4, 4}}, the in shapes is {{16, 32, 48}, {32, 48}}
/// Expectation: the return strategy is {{1, 4, 4}, {4, 4}}
TEST_F(TestInferStrategyBroadcastMode, GenerateFullStrategy4) {
  Strategies in_strategy = {{}, {4, 4}};
  Strategies ret = add1->GenerateFullStrategy(in_strategy);

  Strategies expect = {{1, 4, 4}, {4, 4}};
  ASSERT_EQ(ret, expect);
}
} // namespace parallel
} // namespace mindspore

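The deleted UT pinned down the broadcast-mode rule: propagate the known strategy to the other input, aligned on trailing dimensions and padded with 1 in front. A small Python model reproducing the four expectations above (a sketch of the rule, not the removed C++; it ignores size-1 broadcast axes, which these tests do not cover):

```python
def infer_broadcast(in_strategy, shapes):
    """Propagate the one known strategy to the other input, trailing-aligned."""
    known = 0 if in_strategy[0] else 1
    src, dst_shape = in_strategy[known], shapes[1 - known]
    if len(dst_shape) <= len(src):
        inferred = src[len(src) - len(dst_shape):]       # keep trailing dims
    else:
        inferred = [1] * (len(dst_shape) - len(src)) + src  # pad leading dims with 1
    out = list(in_strategy)
    out[1 - known] = inferred
    return out

assert infer_broadcast([[2, 4, 4], []], [[16, 32, 48], [32, 48]]) == [[2, 4, 4], [4, 4]]
assert infer_broadcast([[], [4, 4]], [[16, 32, 48], [32, 48]]) == [[1, 4, 4], [4, 4]]
```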
@ -1,86 +0,0 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <string>
#include <list>
#include <vector>
#include "common/common_test.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/ops_info/gathernd_info.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/step_parallel.h"

namespace mindspore {
namespace parallel {

class GatherNdInfo;
using GatherNdInfoPtr = std::shared_ptr<GatherNdInfo>;
GatherNdInfoPtr gathernd;

class TestInferStrategyIndependentMode : public UT::Common {
 public:
  TestInferStrategyIndependentMode() {}
  void SetUp();
  void TearDown() {}
};

void TestInferStrategyIndependentMode::SetUp() {
  RankList dev_list;

  for (int32_t i = 0; i < 34; i++) {
    dev_list.push_back(i);
  }

  RankList stage_map;
  stage_map.push_back(32);
  stage_map.push_back(2);

  int32_t local_dev = 0;

  // create a new g_device_manager
  g_device_manager = std::make_shared<DeviceManager>();
  g_device_manager->Init(dev_list, local_dev, stage_map, "hccl");

  mindspore::HashMap<std::string, ValuePtr> attr;

  Shapes inputs_shape = {{32, 64, 96}, {32, 64, 96}};
  Shapes outputs_shape = {{32, 64, 96}};
  gathernd = std::make_shared<GatherNdInfo>("gathernd_info", inputs_shape, outputs_shape, attr);
}

/// Feature: infer strategy for independent mode
/// Description: the in strategy is {{1, 1, 1}, {}}, the in shapes is {{32, 64, 96}, {32, 64, 96}}
/// Expectation: the return strategy is {{1, 1, 1}, {1, 1, 1}}
TEST_F(TestInferStrategyIndependentMode, GenerateFullStrategy1) {
  Strategies in_strategy = {{1, 1, 1}, {}};
  Strategies ret = gathernd->GenerateFullStrategy(in_strategy);

  Strategies expect = {{1, 1, 1}, {1, 1, 1}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for independent mode
/// Description: the in strategy is {{}, {2, 4, 1}}, the in shapes is {{32, 64, 96}, {32, 64, 96}}
/// Expectation: the return strategy is {{1, 1, 1}, {2, 4, 1}}
TEST_F(TestInferStrategyIndependentMode, GenerateFullStrategy2) {
  Strategies in_strategy = {{}, {2, 4, 1}};
  Strategies ret = gathernd->GenerateFullStrategy(in_strategy);

  Strategies expect = {{1, 1, 1}, {2, 4, 1}};
  ASSERT_EQ(ret, expect);
}
} // namespace parallel
} // namespace mindspore

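Independent mode infers each input on its own: an unspecified input simply defaults to no split (all ones), regardless of what the other inputs use. A sketch matching the two GatherNd expectations:

```python
def infer_independent(in_strategy, shapes):
    """Each unspecified input defaults to all ones; inputs do not constrain each other."""
    return [stra if stra else [1] * len(shape)
            for stra, shape in zip(in_strategy, shapes)]

assert infer_independent([[], [2, 4, 1]], [[32, 64, 96], [32, 64, 96]]) == [[1, 1, 1], [2, 4, 1]]
assert infer_independent([[1, 1, 1], []], [[32, 64, 96], [32, 64, 96]]) == [[1, 1, 1], [1, 1, 1]]
```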
@ -1,264 +0,0 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <string>
#include <list>
#include <vector>
#include "common/common_test.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/ops_info/layer_norm_info.h"
#include "frontend/parallel/ops_info/batchnorm_info.h"
#include "frontend/parallel/ops_info/bias_add_info.h"
#include "frontend/parallel/ops_info/scatter_ops_info.h"
#include "frontend/parallel/ops_info/conv2d_info.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/step_parallel.h"

namespace mindspore {
namespace parallel {

class LayerNormInfo;
class BiasAddInfo;
class BatchNormInfo;
class ScatterUpdateInfo;
class Conv2DInfo;
using LayerNormInfoPtr = std::shared_ptr<LayerNormInfo>;
using BiasAddInfoPtr = std::shared_ptr<BiasAddInfo>;
using BatchNormInfoPtr = std::shared_ptr<BatchNormInfo>;
using ScatterUpdateInfoPtr = std::shared_ptr<ScatterUpdateInfo>;
using Conv2DInfoPtr = std::shared_ptr<Conv2DInfo>;
LayerNormInfoPtr layer_norm;
BiasAddInfoPtr bias_add;
BatchNormInfoPtr batch_norm;
ScatterUpdateInfoPtr scatter_update;
Conv2DInfoPtr conv2d;

class TestInferStrategyIndividualMode : public UT::Common {
 public:
  TestInferStrategyIndividualMode() {}
  void SetUp();
  void TearDown() {}
};

void TestInferStrategyIndividualMode::SetUp() {
  RankList dev_list;

  for (int32_t i = 0; i < 64; i++) {
    dev_list.push_back(i);
  }

  RankList stage_map;
  stage_map.push_back(64);

  int32_t local_dev = 0;

  // create a new g_device_manager
  g_device_manager = std::make_shared<DeviceManager>();
  g_device_manager->Init(dev_list, local_dev, stage_map, "hccl");

  // layer_norm
  ValuePtr begin_norm_axis = MakeValue(std::int64_t(3));
  mindspore::HashMap<std::string, ValuePtr> attr_1 = {{"begin_norm_axis", begin_norm_axis}};

  Shapes ln_inputs_shape = {{16, 32, 64, 96}, {32, 64, 96}, {32, 64, 96}};
  Shapes ln_outputs_shape = {{16, 32, 64, 96}, {16, 32, 1, 1}, {16, 32, 1, 1}};
  layer_norm = std::make_shared<LayerNormInfo>("layernorm_info", ln_inputs_shape, ln_outputs_shape, attr_1);

  // bias_add
  mindspore::HashMap<std::string, ValuePtr> attr_2;
  Shapes ba_inputs_shape = {{64, 96}, {96}};
  Shapes ba_outputs_shape = {{64, 96}};
  bias_add = std::make_shared<BiasAddInfo>("biasadd_info", ba_inputs_shape, ba_outputs_shape, attr_2);

  // batch_norm
  ValuePtr is_training = MakeValue(true);
  ValuePtr epsilon = MakeValue(std::float_t(1.0));
  ValuePtr momentum = MakeValue(std::float_t(1.0));
  ValuePtr format = MakeValue("NCHW");
  mindspore::HashMap<std::string, ValuePtr> attr_3 = {{"is_training", is_training},
                                                      {"epsilon", epsilon},
                                                      {"momentum", momentum},
                                                      {"format", format}};

  Shapes bn_inputs_shape = {{64, 96, 32, 16}, {96}, {96}, {96}, {96}};
  Shapes bn_outputs_shape = {{64, 96, 32, 16}, {96}, {96}, {96}, {96}};
  batch_norm = std::make_shared<BatchNormInfo>("batchnorm_info", bn_inputs_shape, bn_outputs_shape, attr_3);

  // scatter_update
  mindspore::HashMap<std::string, ValuePtr> attr_4;
  Shapes su_inputs_shape = {{16, 32, 64, 96}, {128, 256}, {128, 256, 32, 64, 96}};
  Shapes su_outputs_shape = {{16, 32, 64, 96}};
  scatter_update = std::make_shared<ScatterUpdateInfo>("scatterupdate_info", su_inputs_shape, su_outputs_shape, attr_4);

  // conv2d
  ValuePtr out_channel = MakeValue(std::int64_t(10));
  ValuePtr kernel_size = MakeValue(std::vector<int64_t>{4, 4});
  ValuePtr mode = MakeValue(std::int64_t(1));
  ValuePtr pad_mode = MakeValue(std::int64_t(1));
  ValuePtr pad_list = MakeValue(std::vector<int64_t>{1, 1, 1, 1});
  ValuePtr stride = MakeValue(std::vector<int64_t>{1, 1, 2, 2});
  ValuePtr dilation = MakeValue(std::vector<int64_t>{1, 1, 1, 1});
  ValuePtr group = MakeValue(std::int64_t(1));
  mindspore::HashMap<std::string, ValuePtr> attr_5 = {{"out_channel", out_channel},
                                                      {"kernel_size", kernel_size},
                                                      {"pad_mode", pad_mode},
                                                      {"mode", mode},
                                                      {"pad_list", pad_list},
                                                      {"stride", stride},
                                                      {"dilation", dilation},
                                                      {"group", group},
                                                      {"format", format}};

  Shapes conv_inputs_shape = {{128, 2, 16, 16}, {10, 2, 4, 4}};
  Shapes conv_outputs_shape = {{128, 10, 8, 8}};
  conv2d = std::make_shared<Conv2DInfo>("conv2d_info", conv_inputs_shape, conv_outputs_shape, attr_5);
}

/// Feature: infer strategy for layer_norm
/// Description: the in strategy is {{2, 4, 8, 1}, {}, {}}
/// Expectation: the return strategy is {{2, 4, 8, 1}, {4, 8, 1}, {4, 8, 1}}
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy1) {
  Strategies in_strategy = {{2, 4, 8, 1}, {}, {}};
  Strategies ret = layer_norm->GenerateFullStrategy(in_strategy);

  Strategies expect = {{2, 4, 8, 1}, {4, 8, 1}, {4, 8, 1}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for layer_norm
/// Description: the in strategy is {{}, {4, 8, 1}, {4, 8, 1}}
/// Expectation: the return strategy is {{1, 4, 8, 1}, {4, 8, 1}, {4, 8, 1}}
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy2) {
  Strategies in_strategy = {{}, {4, 8, 1}, {4, 8, 1}};
  Strategies ret = layer_norm->GenerateFullStrategy(in_strategy);

  Strategies expect = {{1, 4, 8, 1}, {4, 8, 1}, {4, 8, 1}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for layer_norm
/// Description: the in strategy is {{}, {4, 8, 1}, {}}
/// Expectation: throw exception
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy3) {
  Strategies in_strategy = {{}, {4, 8, 1}, {}};
  ASSERT_ANY_THROW(layer_norm->GenerateFullStrategy(in_strategy));
}

/// Feature: infer strategy for bias_add
/// Description: the in strategy is {{4, 8}, {}}
/// Expectation: the return strategy is {{4, 8}, {8}}
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy4) {
  Strategies in_strategy = {{4, 8}, {}};
  Strategies ret = bias_add->GenerateFullStrategy(in_strategy);

  Strategies expect = {{4, 8}, {8}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for bias_add
/// Description: the in strategy is {{}, {8}}
/// Expectation: the return strategy is {{1, 8}, {8}}
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy5) {
  Strategies in_strategy = {{}, {8}};
  Strategies ret = bias_add->GenerateFullStrategy(in_strategy);

  Strategies expect = {{1, 8}, {8}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for batch_norm
/// Description: the in strategy is {{2, 4, 8, 16}, {}, {}, {}, {}}
/// Expectation: the return strategy is {{2, 4, 8, 16}, {4}, {4}, {4}, {4}}
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy6) {
  Strategies in_strategy = {{2, 4, 8, 16}, {}, {}, {}, {}};
  Strategies ret = batch_norm->GenerateFullStrategy(in_strategy);

  Strategies expect = {{2, 4, 8, 16}, {4}, {4}, {4}, {4}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for batch_norm
/// Description: the in strategy is {{}, {4}, {4}, {4}, {4}}
/// Expectation: the return strategy is {{1, 4, 1, 1}, {4}, {4}, {4}, {4}}
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy7) {
  Strategies in_strategy = {{}, {4}, {4}, {4}, {4}};
  Strategies ret = batch_norm->GenerateFullStrategy(in_strategy);

  Strategies expect = {{1, 4, 1, 1}, {4}, {4}, {4}, {4}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for batch_norm
/// Description: the in strategy is {{}, {4}, {}, {}, {4}}
/// Expectation: throw exception
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy8) {
  Strategies in_strategy = {{}, {4}, {}, {}, {4}};
  ASSERT_ANY_THROW(batch_norm->GenerateFullStrategy(in_strategy));
}

/// Feature: infer strategy for scatter_update
/// Description: the in strategy is {{1, 4, 8, 1}, {}, {}}
/// Expectation: the return strategy is {{1, 4, 8, 1}, {1, 1}, {1, 1, 4, 8, 1}}
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy9) {
  Strategies in_strategy = {{1, 4, 8, 1}, {}, {}};
  Strategies ret = scatter_update->GenerateFullStrategy(in_strategy);

  Strategies expect = {{1, 4, 8, 1}, {1, 1}, {1, 1, 4, 8, 1}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for scatter_update
/// Description: the in strategy is {{}, {1, 1}, {1, 1, 4, 8, 1}}
/// Expectation: the return strategy is {{1, 4, 8, 1}, {1, 1}, {1, 1, 4, 8, 1}}
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy10) {
  Strategies in_strategy = {{}, {1, 1}, {1, 1, 4, 8, 1}};
  Strategies ret = scatter_update->GenerateFullStrategy(in_strategy);

  Strategies expect = {{1, 4, 8, 1}, {1, 1}, {1, 1, 4, 8, 1}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for scatter_update
/// Description: the in strategy is {{}, {1, 1}, {}}
/// Expectation: throw exception
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy11) {
  Strategies in_strategy = {{}, {1, 1}, {}};
  ASSERT_ANY_THROW(scatter_update->GenerateFullStrategy(in_strategy));
}

/// Feature: infer strategy for conv2d
/// Description: the in strategy is {{8, 2, 1, 1}, {}}
/// Expectation: the return strategy is {{8, 2, 1, 1}, {1, 2, 1, 1}}
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy12) {
  Strategies in_strategy = {{8, 2, 1, 1}, {}};
  Strategies ret = conv2d->GenerateFullStrategy(in_strategy);

  Strategies expect = {{8, 2, 1, 1}, {1, 2, 1, 1}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for conv2d
/// Description: the in strategy is {{}, {1, 2, 1, 1}}
/// Expectation: the return strategy is {{1, 2, 1, 1}, {1, 2, 1, 1}}
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy13) {
  Strategies in_strategy = {{}, {1, 2, 1, 1}};
  Strategies ret = conv2d->GenerateFullStrategy(in_strategy);

  Strategies expect = {{1, 2, 1, 1}, {1, 2, 1, 1}};
  ASSERT_EQ(ret, expect);
}
} // namespace parallel
} // namespace mindspore

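Each operator in this file has its own ("individual") completion rule. Taking the BatchNorm cases as the worked example, here is a Python sketch that reproduces GenerateFullStrategy6/7/8, derived only from the test expectations above: the four 1-D parameters must share the channel split of the 4-D input's dim 1, and inconsistent partial specifications are rejected.

```python
def infer_batch_norm(in_strategy):
    """BatchNorm sketch: inputs 1..4 share the channel split of input 0's dim 1."""
    first, rest = in_strategy[0], in_strategy[1:]
    if len(set(map(tuple, rest))) > 1:
        raise ValueError("the last 4 strategies must be equal")
    if first:
        ch = first[1]                       # channel split comes from dim 1
        return [first] + [[ch]] * 4
    ch = next(s for s in rest if s)[0]      # assumes at least one is specified
    return [[1, ch, 1, 1]] + [[ch]] * 4     # NCHW: split only the channel dim

assert infer_batch_norm([[2, 4, 8, 16], [], [], [], []]) == [[2, 4, 8, 16], [4], [4], [4], [4]]
assert infer_batch_norm([[], [4], [4], [4], [4]]) == [[1, 4, 1, 1], [4], [4], [4], [4]]
```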
@ -1,86 +0,0 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <string>
#include <list>
#include <vector>
#include "common/common_test.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/ops_info/addn_info.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/step_parallel.h"

namespace mindspore {
namespace parallel {

class AddNInfo;
using AddNInfoPtr = std::shared_ptr<AddNInfo>;
AddNInfoPtr addn;

class TestInferStrategySameMode : public UT::Common {
 public:
  TestInferStrategySameMode() {}
  void SetUp();
  void TearDown() {}
};

void TestInferStrategySameMode::SetUp() {
  RankList dev_list;

  for (int32_t i = 0; i < 34; i++) {
    dev_list.push_back(i);
  }

  RankList stage_map;
  stage_map.push_back(32);
  stage_map.push_back(2);

  int32_t local_dev = 0;

  // create a new g_device_manager
  g_device_manager = std::make_shared<DeviceManager>();
  g_device_manager->Init(dev_list, local_dev, stage_map, "hccl");

  mindspore::HashMap<std::string, ValuePtr> attr;

  Shapes inputs_shape = {{32, 64, 96}, {32, 64, 96}};
  Shapes outputs_shape = {{32, 64, 96}};
  addn = std::make_shared<AddNInfo>("addn_info", inputs_shape, outputs_shape, attr);
}

/// Feature: infer strategy for same mode
/// Description: the in strategy is {{2, 4, 4}, {}}
/// Expectation: the return strategy is {{2, 4, 4}, {2, 4, 4}}
TEST_F(TestInferStrategySameMode, GenerateFullStrategy1) {
  Strategies in_strategy = {{2, 4, 4}, {}};
  Strategies ret = addn->GenerateFullStrategy(in_strategy);

  Strategies expect = {{2, 4, 4}, {2, 4, 4}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for same mode
/// Description: the in strategy is {{}, {2, 4, 4}}
/// Expectation: the return strategy is {{2, 4, 4}, {2, 4, 4}}
TEST_F(TestInferStrategySameMode, GenerateFullStrategy2) {
  Strategies in_strategy = {{}, {2, 4, 4}};
  Strategies ret = addn->GenerateFullStrategy(in_strategy);

  Strategies expect = {{2, 4, 4}, {2, 4, 4}};
  ASSERT_EQ(ret, expect);
}
} // namespace parallel
} // namespace mindspore

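Same mode is the simplest rule: every input of AddN must be split identically, so the one specified strategy is copied to all inputs. A sketch matching both expectations:

```python
def infer_same(in_strategy):
    """All inputs share one strategy; copy the known one everywhere."""
    known = next(s for s in in_strategy if s)
    return [list(known) for _ in in_strategy]

assert infer_same([[2, 4, 4], []]) == [[2, 4, 4], [2, 4, 4]]
assert infer_same([[], [2, 4, 4]]) == [[2, 4, 4], [2, 4, 4]]
```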
@ -648,7 +648,7 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies1) {
  // the parameter '0' indicates that the stageId = 0, there are 1024 devices in the stage 0
  ASSERT_EQ(matmul1->GenerateStrategies(0), Status::SUCCESS);
  std::vector<std::shared_ptr<StrategyWithCost>> sc = matmul1->GetStrategyCost();
  for (const auto& swc : sc) {
  for (const auto &swc : sc) {
    StrategyPtr sp = swc->strategy_ptr;
    Cost cost = *(swc->cost_list[0]);
    matmul1->InitForCostModel(sp, nullptr);

@ -659,82 +659,5 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies1) {
    break;
  }
}

/// Feature: infer strategy for inputs_shape: {{2, 4, 8, 16}, {2, 4, 16, 32}}, transpose_b=false
/// Description: the in strategy is {{2, 4, 8, 16}, {}}
/// Expectation: the return strategy is {{2, 4, 8, 16}, {2, 4, 16, 1}}
TEST_F(TestMatmulInfo, GenerateFullStrategy1) {
  Strategies in_strategy = {{2, 4, 8, 16}, {}};
  Strategies ret = matmul1->GenerateFullStrategy(in_strategy);

  Strategies expect = {{2, 4, 8, 16}, {2, 4, 16, 1}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for inputs_shape: {{2, 4, 8, 16}, {2, 4, 16, 32}}, transpose_b=false
/// Description: the in strategy is {{}, {2, 4, 8, 16}}
/// Expectation: the return strategy is {{2, 4, 1, 8}, {2, 4, 8, 16}}
TEST_F(TestMatmulInfo, GenerateFullStrategy2) {
  Strategies in_strategy = {{}, {2, 4, 8, 16}};
  Strategies ret = matmul1->GenerateFullStrategy(in_strategy);

  Strategies expect = {{2, 4, 1, 8}, {2, 4, 8, 16}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for inputs_shape: {{2, 4, 8, 16}, {32, 16}}, transpose_b=true
/// Description: the in strategy is {{2, 4, 8, 16}, {}}
/// Expectation: the return strategy is {{2, 4, 8, 16}, {1, 16}}
TEST_F(TestMatmulInfo, GenerateFullStrategy3) {
  Strategies in_strategy = {{2, 4, 8, 16}, {}};
  Strategies ret = matmul2->GenerateFullStrategy(in_strategy);

  Strategies expect = {{2, 4, 8, 16}, {1, 16}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for inputs_shape: {{2, 4, 8, 16}, {32, 16}}, transpose_b=true
/// Description: the in strategy is {{}, {8, 16}}
/// Expectation: the return strategy is {{1, 1, 1, 16}, {8, 16}}
TEST_F(TestMatmulInfo, GenerateFullStrategy4) {
  Strategies in_strategy = {{}, {8, 16}};
  Strategies ret = matmul2->GenerateFullStrategy(in_strategy);

  Strategies expect = {{1, 1, 1, 16}, {8, 16}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for inputs_shape: {{8, 16}, {2, 4, 32, 16}}, transpose_b=true
/// Description: the in strategy is {{8, 16}, {}}
/// Expectation: the return strategy is {{8, 16}, {1, 1, 1, 16}}
TEST_F(TestMatmulInfo, GenerateFullStrategy5) {
  Strategies in_strategy = {{8, 16}, {}};
  Strategies ret = matmul3->GenerateFullStrategy(in_strategy);

  Strategies expect = {{8, 16}, {1, 1, 1, 16}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for inputs_shape: {{8, 16}, {2, 4, 32, 16}}, transpose_b=true
/// Description: the in strategy is {{}, {2, 4, 8, 16}}
/// Expectation: the return strategy is {{1, 16}, {2, 4, 8, 16}}
TEST_F(TestMatmulInfo, GenerateFullStrategy6) {
  Strategies in_strategy = {{}, {2, 4, 8, 16}};
  Strategies ret = matmul3->GenerateFullStrategy(in_strategy);

  Strategies expect = {{1, 16}, {2, 4, 8, 16}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for inputs_shape: {{1024, 128}, {128, 256}}, transpose_b=false
/// Description: the in strategy is {{}, {}}
/// Expectation: the return strategy is {{1024, 1}, {1, 1}}
TEST_F(TestMatmulInfo, GenerateFullStrategy7) {
  Strategies in_strategy = {{}, {}};
  Strategies ret = matmul5->GenerateFullStrategy(in_strategy);

  Strategies expect = {{1024, 1}, {1, 1}};
  ASSERT_EQ(ret, expect);
}
} // namespace parallel
} // namespace mindspore

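The removed MatMul cases encode a completion rule for the 4-D, transpose_b=false operator: batch dims are shared, the contraction dim k is copied from the known operand, and the remaining free dim (n or m) defaults to 1. A sketch covering GenerateFullStrategy1/2 only; the transposed and broadcast-batch variants above follow the same idea with the axes permuted:

```python
def infer_matmul(in_strategy):
    """4-D MatMul, transpose_b=False: share batch dims and the contraction dim k,
    default the unknown free dim to 1."""
    a, b = in_strategy
    if a:  # a = (b0, b1, m, k)  ->  b = (b0, b1, k, 1)
        return [a, [a[0], a[1], a[3], 1]]
    return [[b[0], b[1], 1, b[2]], b]  # b = (b0, b1, k, n)  ->  a = (b0, b1, 1, k)

assert infer_matmul([[2, 4, 8, 16], []]) == [[2, 4, 8, 16], [2, 4, 16, 1]]
assert infer_matmul([[], [2, 4, 8, 16]]) == [[2, 4, 1, 8], [2, 4, 8, 16]]
```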
@ -205,27 +205,5 @@ TEST_F(TestOneHotInfo, CheckStrategy1) {
  Status ret = onehot_info->Init(strategy, nullptr);
  ASSERT_EQ(ret, FAILED);
}

/// Feature: infer strategy for inputs_shape: {{64}, {}, {}}
/// Description: the in strategy is {{8}, {}, {}}
/// Expectation: the return strategy is {{8, 1}, {}, {}}
TEST_F(TestOneHotInfo, GenerateFullStrategy1) {
  Strategies in_strategy = {{8}, {}, {}};
  Strategies ret = onehot_info->GenerateFullStrategy(in_strategy);

  Strategies expect = {{8, 1}, {}, {}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for inputs_shape: {{64}, {}, {}}
/// Description: the in strategy is {{}, {}, {}}
/// Expectation: the return strategy is {{8, 1}, {}, {}}
TEST_F(TestOneHotInfo, GenerateFullStrategy2) {
  Strategies in_strategy = {{}, {}, {}};
  Strategies ret = onehot_info->GenerateFullStrategy(in_strategy);

  Strategies expect = {{8, 1}, {}, {}};
  ASSERT_EQ(ret, expect);
}
} // namespace parallel
} // namespace mindspore

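OneHot's completed strategy gains one extra axis for the generated depth dimension, and the two scalar inputs stay empty; with nothing specified, the batch is split across all devices of the stage. A sketch reproducing the two expectations (the device count of 8 is an assumption read off the expected {8, 1}; the fixture's setup is not shown in this hunk):

```python
def infer_one_hot(in_strategy, dev_num=8):
    """OneHot sketch: append strategy 1 for the new depth axis; scalars stay empty.
    With nothing specified, default to splitting the batch across all devices."""
    idx = in_strategy[0] if in_strategy[0] else [dev_num]
    return [idx + [1], [], []]

assert infer_one_hot([[8], [], []]) == [[8, 1], [], []]
assert infer_one_hot([[], [], []]) == [[8, 1], [], []]
```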
@ -272,49 +272,5 @@ TEST_F(TestPReLUInfo, AutoStrategy_2d1) {
    ASSERT_EQ(stra1[0], 1);
  }
}

/// Feature: infer strategy for inputs_shape: {{64, 4, 8, 16}, {4}}
/// Description: the in strategy is {{2, 4, 8, 16}, {}}
/// Expectation: the return strategy is {{2, 4, 8, 16}, {4}}
TEST_F(TestPReLUInfo, GenerateFullStrategy1) {
  Strategies in_strategy = {{2, 4, 8, 16}, {}};
  Strategies ret = prelu->GenerateFullStrategy(in_strategy);

  Strategies expect = {{2, 4, 8, 16}, {4}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for inputs_shape: {{64, 4, 8, 16}, {4}}
/// Description: the in strategy is {{}, {4}}
/// Expectation: the return strategy is {{1, 4, 1, 1}, {4}}
TEST_F(TestPReLUInfo, GenerateFullStrategy2) {
  Strategies in_strategy = {{}, {4}};
  Strategies ret = prelu->GenerateFullStrategy(in_strategy);

  Strategies expect = {{1, 4, 1, 1}, {4}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for inputs_shape: {{1024, 4}, {4}}
/// Description: the in strategy is {{8, 4}, {}}
/// Expectation: the return strategy is {{8, 4}, {4}}
TEST_F(TestPReLUInfo, GenerateFullStrategy3) {
  Strategies in_strategy = {{8, 4}, {}};
  Strategies ret = prelu_2d->GenerateFullStrategy(in_strategy);

  Strategies expect = {{8, 4}, {4}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for inputs_shape: {{1024, 4}, {4}}
/// Description: the in strategy is {{}, {4}}
/// Expectation: the return strategy is {{1, 4}, {4}}
TEST_F(TestPReLUInfo, GenerateFullStrategy4) {
  Strategies in_strategy = {{}, {4}};
  Strategies ret = prelu_2d->GenerateFullStrategy(in_strategy);

  Strategies expect = {{1, 4}, {4}};
  ASSERT_EQ(ret, expect);
}
} // namespace parallel
} // namespace mindspore

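PReLU follows the channel-alignment pattern: the weight's strategy must equal the split of the input's channel dim (dim 1), and an unspecified input defaults to splitting only that channel dim. A sketch matching the four expectations above (rank passed explicitly to cover both the 4-D and 2-D fixtures):

```python
def infer_prelu(in_strategy, rank):
    """PReLU sketch: the weight is split like the channel dim (dim 1) of the input."""
    x, w = in_strategy
    if x:
        return [x, [x[1]]]      # weight inherits the channel split
    full = [1] * rank
    full[1] = w[0]              # split only the channel dim of the input
    return [full, w]

assert infer_prelu([[2, 4, 8, 16], []], rank=4) == [[2, 4, 8, 16], [4]]
assert infer_prelu([[], [4]], rank=4) == [[1, 4, 1, 1], [4]]
assert infer_prelu([[8, 4], []], rank=2) == [[8, 4], [4]]
assert infer_prelu([[], [4]], rank=2) == [[1, 4], [4]]
```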