add ut for infer strategy

parent b2ad845781
commit 115ef88b7c
@@ -337,11 +337,16 @@ std::vector<StrategyPtr> BatchNormInfo::GenerateOpStrategies(int64_t stage_id) {

// in_strategy: ((N, C, H, W), (), (), (), ()) return: ((N, C, H, W), (C), (C), (C), (C))
// in_strategy: ((), (C), (C), (C), (C)) return: ((1, C, 1, 1), (C), (C), (C), (C))
// in_strategy: ((), (C), (), (C), (C)) throw exception
Shapes BatchNormInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
  if (in_strategy.size() != 5) {
    MS_LOG(EXCEPTION) << name_ << ": The size of in strategy must be 5, but got " << in_strategy.size();
  }

  if ((in_strategy[2] != in_strategy[1]) || (in_strategy[3] != in_strategy[1]) || (in_strategy[4] != in_strategy[1])) {
    MS_LOG(EXCEPTION) << name_ << ": The last 4 strategy must be equal, but got " << in_strategy;
  }

  Shape channel_strategy;
  if (!in_strategy[0].empty()) {
    if (in_strategy[0].size() != 4 && in_strategy[0].size() != 2) {
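The comment block above documents the expansion rule for BatchNorm's individual mode. The standalone sketch below is illustrative only, not part of the change: the alias Dims and the helper name ExpandBatchNormStrategy are hypothetical, and the rank-2 (N, C) input case accepted by the real code is ignored here.

#include <cstdint>
#include <stdexcept>
#include <vector>

using Dims = std::vector<int64_t>;

// Restates the rule documented above for the NCHW case:
// ((N, C, H, W), (), (), (), ())  ->  ((N, C, H, W), (C), (C), (C), (C))
// ((), (C), (C), (C), (C))        ->  ((1, C, 1, 1), (C), (C), (C), (C))
std::vector<Dims> ExpandBatchNormStrategy(const std::vector<Dims> &in) {
  if (in.size() != 5) {
    throw std::invalid_argument("expect 5 input strategies");
  }
  if (!in[0].empty()) {          // input strategy given: derive the channel strategy
    Dims channel = {in[0][1]};   // C is the second dimension of NCHW
    return {in[0], channel, channel, channel, channel};
  }
  if (!in[1].empty()) {          // channel strategy given: rebuild the input strategy
    Dims input = {1, in[1][0], 1, 1};
    return {input, in[1], in[1], in[1], in[1]};
  }
  throw std::invalid_argument("all input strategies are empty");
}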
@@ -281,10 +281,10 @@ Shapes LayerNormInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
        << ": The size of in_strategy[0] must be equal to the size of inputs_shape[0], but the in_strategy[0] is"
        << in_strategy[0] << ", the inputs_shape[0] is " << input_shape_;
    }
    if (inputs_shape_.size() == gamma_shape_.size()) {
    if (input_shape_.size() == gamma_shape_.size()) {
      return Shapes({in_strategy[0], in_strategy[0], in_strategy[0]});
    } else {
      size_t diff_len = inputs_shape_.size() - gamma_shape_.size();
      size_t diff_len = input_shape_.size() - gamma_shape_.size();
      Shape gamma_strategy(in_strategy[0].begin() + diff_len, in_strategy[0].end());
      return Shapes({in_strategy[0], gamma_strategy, gamma_strategy});
    }
@@ -297,10 +297,10 @@ Shapes LayerNormInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
        << ": The size of in_strategy[1] must be equal to the size of inputs_shape[1], but the in_strategy[1] is"
        << in_strategy[1] << ", the inputs_shape[1] is " << gamma_shape_;
    }
    if (inputs_shape_.size() == gamma_shape_.size()) {
    if (input_shape_.size() == gamma_shape_.size()) {
      return Shapes({in_strategy[1], in_strategy[1], in_strategy[1]});
    } else {
      size_t diff_len = inputs_shape_.size() - gamma_shape_.size();
      size_t diff_len = input_shape_.size() - gamma_shape_.size();
      Shape tmp_strategy = in_strategy[1];
      (void)tmp_strategy.insert(tmp_strategy.begin(), diff_len, 1);
      return Shapes({tmp_strategy, in_strategy[1], in_strategy[1]});
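Both LayerNorm branches above apply the same rank-alignment idea: the gamma/beta strategy is the tail of the input strategy, and a missing input strategy is rebuilt by padding leading 1s. A minimal standalone sketch of that idea, with hypothetical names and plain std::vector aliases rather than the real Shape/Shapes types:

#include <cstdint>
#include <vector>

using Dims = std::vector<int64_t>;

// From an input strategy, take the trailing dims that line up with gamma's rank,
// e.g. input (2, 4, 8, 1) and a rank-3 gamma -> gamma strategy (4, 8, 1).
Dims GammaFromInput(const Dims &input_strategy, size_t gamma_rank) {
  size_t diff = input_strategy.size() - gamma_rank;
  return Dims(input_strategy.begin() + diff, input_strategy.end());
}

// From a gamma strategy, pad leading 1s up to the input's rank,
// e.g. gamma (4, 8, 1) and a rank-4 input -> input strategy (1, 4, 8, 1).
Dims InputFromGamma(const Dims &gamma_strategy, size_t input_rank) {
  Dims out(input_rank - gamma_strategy.size(), 1);
  out.insert(out.end(), gamma_strategy.begin(), gamma_strategy.end());
  return out;
}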
@@ -2126,15 +2126,6 @@ Shapes OperatorInfo::InferParamStrategy(const Shapes &default_strategy) {
  return default_strategy;
}

bool HasEmptyStrategy(const Shapes &in_strategy) {
  for (auto &ele : in_strategy) {
    if (ele.empty()) {
      return true;
    }
  }
  return false;
}

// in_strategy: ((A, B, C, D), ()), return: ((A, B, C, D), (A, B, C, D))
// in_strategy: ((), (A, B, C, D)), return: ((A, B, C, D), (A, B, C, D))
Shapes OperatorInfo::InferStrategySameMode(const Shapes &in_strategy) {
@@ -2212,9 +2203,10 @@ Shapes OperatorInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
  MS_LOG(EXCEPTION) << name_ << ": The in strategy is " << in_strategy << ", need to override this function ";
}

Shapes OperatorInfo::GenerateFullStrategy(const Shapes &in_strategy) {
Shapes OperatorInfo::GenerateFullStrategyBase(const Shapes &in_strategy) {
  // there is no empty in the in_strategy
  if (!HasEmptyStrategy(in_strategy)) {
  auto it = std::find_if(in_strategy.begin(), in_strategy.end(), [](const Shape &ele) { return ele.empty(); });
  if (it == in_strategy.end()) {
    MS_LOG(INFO) << name_ << ": There is no empty in the input strategy, return to itself: " << in_strategy;
    return in_strategy;
  }
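The hunk above replaces the removed HasEmptyStrategy helper with an inline std::find_if over the per-input strategies. A self-contained sketch of that idiom, using simple vector aliases instead of MindSpore's Shape/Shapes:

#include <algorithm>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;
using Shapes = std::vector<Shape>;

bool AllStrategiesNonEmpty(const Shapes &in_strategy) {
  // Mirrors the inlined check above: scan for the first empty element;
  // finding none means the strategy is already fully specified.
  auto it = std::find_if(in_strategy.begin(), in_strategy.end(),
                         [](const Shape &ele) { return ele.empty(); });
  return it == in_strategy.end();
}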
@@ -2230,10 +2222,6 @@ Shapes OperatorInfo::GenerateFullStrategy(const Shapes &in_strategy) {
  }

  // generate the full strategy from non empty strategy
  if (InferAttrs() != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed";
  }

  Shapes ret;
  switch (infer_strategy_mode_) {
    case SAME_MODE:
@@ -2253,13 +2241,22 @@ Shapes OperatorInfo::GenerateFullStrategy(const Shapes &in_strategy) {
      MS_LOG(EXCEPTION) << name_ << ": The invalid mode for infer strategy";
  }

  if (name_.find(ONEHOT_INFO) == std::string::npos && CheckStrategyBase(ret, inputs_shape_) != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": The origin strategy is " << in_strategy << ", and the return strategy is " << ret;
  }
  MS_LOG(INFO) << name_ << ": The origin strategy is " << in_strategy << ", and the return strategy is " << ret;
  return ret;
}

Shapes OperatorInfo::GenerateFullStrategy(const Shapes &in_strategy) {
  if (InferAttrs() != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed";
  }
  Shapes ret = GenerateFullStrategyBase(in_strategy);
  StrategyPtr strategy_ptr = NewStrategy(0, ret);
  if (CheckStrategy(strategy_ptr) != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": Invalid strategy";
  }
  return ret;
}

std::vector<ValuePtr> GetValueSequence(const ValuePtr &sequence) {
  MS_EXCEPTION_IF_NULL(sequence);
  std::vector<ValuePtr> ret;
@@ -238,6 +238,7 @@ class OperatorInfo {
  virtual Status CheckOutputStrategy(const StrategyPtr &out_strategy);
  virtual Shapes InferParamStrategy(const Shapes &default_strategy);
  virtual Shapes InferStrategyIndividualMode(const Shapes &in_strategy);
  Shapes GenerateFullStrategyBase(const Shapes &in_strategy);
  Shapes InferStrategySameMode(const Shapes &in_strategy);
  Shapes InferStrategyBroadcastMode(const Shapes &in_strategy);
  Shapes InferStrategyIndependentMode(const Shapes &in_strategy);
@@ -180,13 +180,22 @@ std::vector<StrategyPtr> ScatterUpdateInfo::GenerateOpStrategies(int64_t stage_i

// in_strategy: ((A, B, C, D), (), ()), Shapes: ((a, b, c, d), (e, f), (e, f, b, c, d))
// return: ((A, B, C, D), (1, 1), (1, 1, B, C, D))
// in_strategy: ((), (), (E, F, B, C, D)), Shapes: ((a, b, c, d), (e, f), (e, f, b, c, d))
// in_strategy: ((), (1, 1), (E, F, B, C, D)), Shapes: ((a, b, c, d), (e, f), (e, f, b, c, d))
// return: ((A, B, C, D), (1, 1), (E, F, B, C, D))
// in_strategy: ((), (), (E, F, B, C, D)), Shapes: ((a, b, c, d), (e, f), (e, f, b, c, d))
// return: throw exception
Shapes ScatterUpdateInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
  if (in_strategy.size() != 3) {
    MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy must be 3, but got " << in_strategy.size();
  }

  if (in_strategy[1].empty() != in_strategy[2].empty()) {
    MS_LOG(EXCEPTION)
      << name_
      << ": The in_strategy[1] and in_strategy[2] must be all empty or all non empty, but the in_strategy[1] is "
      << in_strategy[1] << ", the in_strategy[2] is " << in_strategy[2];
  }

  Shape x_strategy, indices_strategy, updates_strategy;
  indices_strategy = Shape(inputs_shape_[1].size(), 1);
  if (!in_strategy[0].empty()) {
@@ -197,19 +206,15 @@ Shapes ScatterUpdateInfo::InferStrategyIndividualMode(const Shapes &in_strategy)
  }

  if (!in_strategy[2].empty()) {
    if (std::accumulate(in_strategy[1].begin(), in_strategy[1].end(), 1, std::multiplies<int64_t>()) != 1) {
      MS_LOG(EXCEPTION) << name_ << ": The in_strategy[1] must be fill with 1, but got " << in_strategy[1];
    }
    x_strategy = in_strategy[2];
    (void)x_strategy.erase(x_strategy.begin(),
                           x_strategy.begin() + static_cast<different_type>(inputs_shape_[1].size() - 1));
    return Shapes({x_strategy, indices_strategy, in_strategy[2]});
  }

  if (!in_strategy[1].empty()) {
    if (std::accumulate(in_strategy[1].begin(), in_strategy[1].end(), 1, std::multiplies<int64_t>()) != 1) {
      MS_LOG(EXCEPTION) << name_ << ": The in_strategy[1] must be fill with 1, but got " << in_strategy[1];
    }
    return Shapes({Shape(inputs_shape_[0].size(), 1), in_strategy[1], Shape(inputs_shape_[2].size(), 1)});
  }

  MS_LOG(EXCEPTION) << name_ << ": The in_strategy[0], in_strategy[1] and in_strategy[2] are empty";
}

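The first mapping documented in the ScatterUpdate comments above builds the updates strategy from the input strategy: the indices dims are left unsplit and the scattered (first) dim of the input is dropped. A standalone illustration with a hypothetical helper name and plain vector aliases, not ScatterUpdateInfo's actual code:

#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// x strategy (A, B, C, D) with a rank-2 indices input yields
// indices strategy (1, 1) and updates strategy (1, 1, B, C, D).
Shape UpdatesFromInput(const Shape &x_strategy, size_t indices_rank) {
  Shape updates(indices_rank, 1);  // indices dims are not split
  updates.insert(updates.end(), x_strategy.begin() + 1, x_strategy.end());  // drop the scattered dim
  return updates;
}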
@@ -62,24 +62,24 @@ void TestInferStrategyIndependentMode::SetUp() {
}

/// Feature: infer strategy for independent mode
/// Description: the in strategy is {{2, 4, 4}, {}}, the in shapes is {{32, 64, 96}, {32, 64, 96}}
/// Expectation: the return strategy is {{2, 4, 4}, {1, 1, 1}}
/// Description: the in strategy is {{1, 1, 1}, {}}, the in shapes is {{32, 64, 96}, {32, 64, 96}}
/// Expectation: the return strategy is {{1, 1, 1}, {1, 1, 1}}
TEST_F(TestInferStrategyIndependentMode, GenerateFullStrategy1) {
  Strategies in_strategy = {{2, 4, 4}, {}};
  Strategies in_strategy = {{1, 1, 1}, {}};
  Strategies ret = gathernd->GenerateFullStrategy(in_strategy);

  Strategies expect = {{2, 4, 4}, {1, 1, 1}};
  Strategies expect = {{1, 1, 1}, {1, 1, 1}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for independent mode
/// Description: the in strategy is {{}, {2, 4, 4}}, the in shapes is {{32, 64, 96}, {32, 64, 96}}
/// Expectation: the return strategy is {{1, 1, 1}, {2, 4, 4}}
/// Description: the in strategy is {{}, {2, 4, 1}}, the in shapes is {{32, 64, 96}, {32, 64, 96}}
/// Expectation: the return strategy is {{1, 1, 1}, {2, 4, 1}}
TEST_F(TestInferStrategyIndependentMode, GenerateFullStrategy2) {
  Strategies in_strategy = {{}, {2, 4, 4}};
  Strategies in_strategy = {{}, {2, 4, 1}};
  Strategies ret = gathernd->GenerateFullStrategy(in_strategy);

  Strategies expect = {{1, 1, 1}, {2, 4, 4}};
  Strategies expect = {{1, 1, 1}, {2, 4, 1}};
  ASSERT_EQ(ret, expect);
}
} // namespace parallel
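The expectations in the tests above demonstrate the independent-mode fill rule: an empty per-input strategy is completed with 1s matching that input's rank, while non-empty strategies pass through unchanged. A standalone illustration (hypothetical helper name, plain vector aliases), not the InferStrategyIndependentMode implementation itself:

#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;
using Shapes = std::vector<Shape>;

// e.g. strategies ({}, (2, 4, 1)) with two rank-3 inputs -> ((1, 1, 1), (2, 4, 1)).
Shapes FillIndependent(const Shapes &in_strategy, const Shapes &inputs_shape) {
  Shapes full = in_strategy;
  for (size_t i = 0; i < full.size(); ++i) {
    if (full[i].empty()) {
      full[i] = Shape(inputs_shape[i].size(), 1);  // unsplit along every dimension
    }
  }
  return full;
}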
@@ -0,0 +1,264 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <string>
#include <list>
#include <vector>
#include "common/common_test.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/ops_info/layer_norm_info.h"
#include "frontend/parallel/ops_info/batchnorm_info.h"
#include "frontend/parallel/ops_info/bias_add_info.h"
#include "frontend/parallel/ops_info/scatter_update_info.h"
#include "frontend/parallel/ops_info/conv2d_info.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/step_parallel.h"

namespace mindspore {
namespace parallel {

class LayerNormInfo;
class BiasAddInfo;
class BatchNormInfo;
class ScatterUpdateInfo;
class Conv2DInfo;
using LayerNormInfoPtr = std::shared_ptr<LayerNormInfo>;
using BiasAddInfoPtr = std::shared_ptr<BiasAddInfo>;
using BatchNormInfoPtr = std::shared_ptr<BatchNormInfo>;
using ScatterUpdateInfoPtr = std::shared_ptr<ScatterUpdateInfo>;
using Conv2DInfoPtr = std::shared_ptr<Conv2DInfo>;
LayerNormInfoPtr layer_norm;
BiasAddInfoPtr bias_add;
BatchNormInfoPtr batch_norm;
ScatterUpdateInfoPtr scatter_update;
Conv2DInfoPtr conv2d;

class TestInferStrategyIndividualMode : public UT::Common {
 public:
  TestInferStrategyIndividualMode() {}
  void SetUp();
  void TearDown() {}
};

void TestInferStrategyIndividualMode::SetUp() {
  RankList dev_list;

  for (int32_t i = 0; i < 64; i++) {
    dev_list.push_back(i);
  }

  RankList stage_map;
  stage_map.push_back(64);

  int32_t local_dev = 0;

  // create a new g_device_manager
  g_device_manager = std::make_shared<DeviceManager>();
  g_device_manager->Init(dev_list, local_dev, stage_map, "hccl");

  // layer_norm
  ValuePtr begin_norm_axis = MakeValue(std::int64_t(3));
  mindspore::HashMap<std::string, ValuePtr> attr_1 = {{"begin_norm_axis", begin_norm_axis}};

  Shapes ln_inputs_shape = {{16, 32, 64, 96}, {32, 64, 96}, {32, 64, 96}};
  Shapes ln_outputs_shape = {{16, 32, 64, 96}, {16, 32, 1, 1}, {16, 32, 1, 1}};
  layer_norm = std::make_shared<LayerNormInfo>("layernorm_info", ln_inputs_shape, ln_outputs_shape, attr_1);

  // bias_add
  mindspore::HashMap<std::string, ValuePtr> attr_2;
  Shapes ba_inputs_shape = {{64, 96}, {96}};
  Shapes ba_outputs_shape = {{64, 96}};
  bias_add = std::make_shared<BiasAddInfo>("biasadd_info", ba_inputs_shape, ba_outputs_shape, attr_2);

  // batch_norm
  ValuePtr is_training = MakeValue(true);
  ValuePtr epsilon = MakeValue(std::float_t(1.0));
  ValuePtr momentum = MakeValue(std::float_t(1.0));
  ValuePtr format = MakeValue("NCHW");
  mindspore::HashMap<std::string, ValuePtr> attr_3 = {{"is_training", is_training},
                                                      {"epsilon", epsilon},
                                                      {"momentum", momentum},
                                                      {"format", format}};

  Shapes bn_inputs_shape = {{64, 96, 32, 16}, {96}, {96}, {96}, {96}};
  Shapes bn_outputs_shape = {{64, 96, 32, 16}, {96}, {96}, {96}, {96}};
  batch_norm = std::make_shared<BatchNormInfo>("batchnorm_info", bn_inputs_shape, bn_outputs_shape, attr_3);

  // scatter_update
  mindspore::HashMap<std::string, ValuePtr> attr_4;
  Shapes su_inputs_shape = {{16, 32, 64, 96}, {128, 256}, {128, 256, 32, 64, 96}};
  Shapes su_outputs_shape = {{16, 32, 64, 96}};
  scatter_update = std::make_shared<ScatterUpdateInfo>("scatterupdate_info", su_inputs_shape, su_outputs_shape, attr_4);

  // conv2d
  ValuePtr out_channel = MakeValue(std::int64_t(10));
  ValuePtr kernel_size = MakeValue(std::vector<int64_t>{4, 4});
  ValuePtr mode = MakeValue(std::int64_t(1));
  ValuePtr pad_mode = MakeValue(std::int64_t(1));
  ValuePtr pad_list = MakeValue(std::vector<int64_t>{1, 1, 1, 1});
  ValuePtr stride = MakeValue(std::vector<int64_t>{1, 1, 2, 2});
  ValuePtr dilation = MakeValue(std::vector<int64_t>{1, 1, 1, 1});
  ValuePtr group = MakeValue(std::int64_t(1));
  mindspore::HashMap<std::string, ValuePtr> attr_5 = {{"out_channel", out_channel},
                                                      {"kernel_size", kernel_size},
                                                      {"pad_mode", pad_mode},
                                                      {"mode", mode},
                                                      {"pad_list", pad_list},
                                                      {"stride", stride},
                                                      {"dilation", dilation},
                                                      {"group", group},
                                                      {"format", format}};

  Shapes conv_inputs_shape = {{128, 2, 16, 16}, {10, 2, 4, 4}};
  Shapes conv_outputs_shape = {{128, 10, 8, 8}};
  conv2d = std::make_shared<Conv2DInfo>("conv2d_info", conv_inputs_shape, conv_outputs_shape, attr_5);
}

/// Feature: infer strategy for layer_norm
/// Description: the in strategy is {{2, 4, 8, 1}, {}, {}}
/// Expectation: the return strategy is {{2, 4, 8, 1}, {4, 8, 1}, {4, 8, 1}}
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy1) {
  Strategies in_strategy = {{2, 4, 8, 1}, {}, {}};
  Strategies ret = layer_norm->GenerateFullStrategy(in_strategy);

  Strategies expect = {{2, 4, 8, 1}, {4, 8, 1}, {4, 8, 1}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for layer_norm
/// Description: the in strategy is {{}, {4, 8, 1}, {4, 8, 1}}
/// Expectation: the return strategy is {{1, 4, 8, 1}, {4, 8, 1}, {4, 8, 1}}
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy2) {
  Strategies in_strategy = {{}, {4, 8, 1}, {4, 8, 1}};
  Strategies ret = layer_norm->GenerateFullStrategy(in_strategy);

  Strategies expect = {{1, 4, 8, 1}, {4, 8, 1}, {4, 8, 1}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for layer_norm
/// Description: the in strategy is {{}, {4, 8, 1}, {}}
/// Expectation: throw exception
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy3) {
  Strategies in_strategy = {{}, {4, 8, 1}, {}};
  ASSERT_ANY_THROW(layer_norm->GenerateFullStrategy(in_strategy));
}

/// Feature: infer strategy for bias_add
/// Description: the in strategy is {{4, 8}, {}}
/// Expectation: the return strategy is {{4, 8}, {8}}
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy4) {
  Strategies in_strategy = {{4, 8}, {}};
  Strategies ret = bias_add->GenerateFullStrategy(in_strategy);

  Strategies expect = {{4, 8}, {8}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for bias_add
/// Description: the in strategy is {{}, {8}}
/// Expectation: the return strategy is {{1, 8}, {8}}
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy5) {
  Strategies in_strategy = {{}, {8}};
  Strategies ret = bias_add->GenerateFullStrategy(in_strategy);

  Strategies expect = {{1, 8}, {8}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for batch_norm
/// Description: the in strategy is {{2, 4, 8, 16}, {}, {}, {}, {}}
/// Expectation: the return strategy is {{2, 4, 8, 16}, {4}, {4}, {4}, {4}}
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy6) {
  Strategies in_strategy = {{2, 4, 8, 16}, {}, {}, {}, {}};
  Strategies ret = batch_norm->GenerateFullStrategy(in_strategy);

  Strategies expect = {{2, 4, 8, 16}, {4}, {4}, {4}, {4}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for batch_norm
/// Description: the in strategy is {{}, {4}, {4}, {4}, {4}}
/// Expectation: the return strategy is {{1, 4, 1, 1}, {4}, {4}, {4}, {4}}
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy7) {
  Strategies in_strategy = {{}, {4}, {4}, {4}, {4}};
  Strategies ret = batch_norm->GenerateFullStrategy(in_strategy);

  Strategies expect = {{1, 4, 1, 1}, {4}, {4}, {4}, {4}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for batch_norm
/// Description: the in strategy is {{}, {4}, {}, {}, {4}}
/// Expectation: throw exception
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy8) {
  Strategies in_strategy = {{}, {4}, {}, {}, {4}};
  ASSERT_ANY_THROW(batch_norm->GenerateFullStrategy(in_strategy));
}

/// Feature: infer strategy for scatter_update
/// Description: the in strategy is {{1, 4, 8, 1}, {}, {}}
/// Expectation: the return strategy is {{1, 4, 8, 1}, {1, 1}, {1, 1, 4, 8, 1}}
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy9) {
  Strategies in_strategy = {{1, 4, 8, 1}, {}, {}};
  Strategies ret = scatter_update->GenerateFullStrategy(in_strategy);

  Strategies expect = {{1, 4, 8, 1}, {1, 1}, {1, 1, 4, 8, 1}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for scatter_update
/// Description: the in strategy is {{}, {1, 1}, {1, 1, 4, 8, 1}}
/// Expectation: the return strategy is {{1, 4, 8, 1}, {1, 1}, {1, 1, 4, 8, 1}}
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy10) {
  Strategies in_strategy = {{}, {1, 1}, {1, 1, 4, 8, 1}};
  Strategies ret = scatter_update->GenerateFullStrategy(in_strategy);

  Strategies expect = {{1, 4, 8, 1}, {1, 1}, {1, 1, 4, 8, 1}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for scatter_update
/// Description: the in strategy is {{}, {1, 1}, {}}
/// Expectation: throw exception
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy11) {
  Strategies in_strategy = {{}, {1, 1}, {}};
  ASSERT_ANY_THROW(scatter_update->GenerateFullStrategy(in_strategy));
}

/// Feature: infer strategy for conv2d
/// Description: the in strategy is {{8, 2, 1, 1}, {}}
/// Expectation: the return strategy is {{8, 2, 1, 1}, {1, 2, 1, 1}}
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy12) {
  Strategies in_strategy = {{8, 2, 1, 1}, {}};
  Strategies ret = conv2d->GenerateFullStrategy(in_strategy);

  Strategies expect = {{8, 2, 1, 1}, {1, 2, 1, 1}};
  ASSERT_EQ(ret, expect);
}

/// Feature: infer strategy for conv2d
/// Description: the in strategy is {{}, {1, 2, 1, 1}}
/// Expectation: the return strategy is {{1, 2, 1, 1}, {1, 2, 1, 1}}
TEST_F(TestInferStrategyIndividualMode, GenerateFullStrategy13) {
  Strategies in_strategy = {{}, {1, 2, 1, 1}};
  Strategies ret = conv2d->GenerateFullStrategy(in_strategy);

  Strategies expect = {{1, 2, 1, 1}, {1, 2, 1, 1}};
  ASSERT_EQ(ret, expect);
}
} // namespace parallel
} // namespace mindspore
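For the conv2d cases, the two expectations above suggest that only the channel split carries over between the input and weight strategies. The helper below is hypothetical, inferred solely from those two test expectations and not from Conv2DInfo's actual inference logic, which covers more cases:

#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// input (N, C, H, W) strategy -> weight strategy (1, C-split, 1, 1),
// weight (O, I, kH, kW) strategy -> input strategy (1, I-split, 1, 1).
Shape WeightFromInput(const Shape &input_strategy) {
  return Shape{1, input_strategy[1], 1, 1};
}

Shape InputFromWeight(const Shape &weight_strategy) {
  return Shape{1, weight_strategy[1], 1, 1};
}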