!36056 generate data parallel strategy for conv2d when the auto parallel search mode is dynamic programming

Merge pull request !36056 from yangzhenzhang/gen-data-parallel-strategy-for-conv2d-auto-parallel
i-robot 2022-06-17 01:26:58 +00:00 committed by Gitee
commit 57b33197b2
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 37 additions and 1 deletion


@@ -27,6 +27,7 @@
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "include/common/utils/parallel_context.h"
#include "pipeline/jit/resource.h"
namespace mindspore {
@@ -928,6 +929,19 @@ void Conv2DInfo::ReComputeBatchSplitFlagList() {
Status Conv2DInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
std::vector<StrategyPtr> Conv2DInfo::GenerateOpStrategies(int64_t stage_id) {
std::vector<StrategyPtr> sp_vector;
auto parallel_context = ParallelContext::GetInstance();
MS_EXCEPTION_IF_NULL(parallel_context);
auto parallel_mode = parallel_context->parallel_mode();
auto search_mode = parallel_context->strategy_search_mode();
// generate data parallel strategy when the search mode is not sharding propagation
if (parallel_mode == parallel::kAutoParallel && search_mode != parallel::kShardingPropagation) {
Strategys strategy = {{stage_device_size_, 1, 1, 1}, {1, 1, 1, 1}};
StrategyPtr data_parallel_sp = std::make_shared<Strategy>(stage_id, strategy);
sp_vector.push_back(data_parallel_sp);
return sp_vector;
}
// to generate the strategy for (N, C1, H, W, C2), the k1/k2 can not be split
Shapes splittable_input = {{1, 1, 1, 1, 1}};
Shape tmp_shape = inputs_shape_[0];
@@ -937,7 +951,6 @@ std::vector<StrategyPtr> Conv2DInfo::GenerateOpStrategies(int64_t stage_id) {
tmp_shape.push_back(inputs_shape_[1][1]); // the tmp shape is (N, C-out, H, W, C-in)
}
Shapes tmp_inputs_shape = {tmp_shape};
std::vector<StrategyPtr> sp_vector;
if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_input, &sp_vector) != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Generate strategies failed";
}
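
For reference, a minimal Python sketch of the strategy that the new branch above emits (a hypothetical helper, not part of the patch): the Conv2D input (N, C-in, H, W) is split only along the batch axis across all devices of the stage, while the weight (C-out, C-in, k1, k2) is fully replicated.

    def data_parallel_conv2d_strategy(stage_device_num):
        # Input (N, C-in, H, W): shard only the batch (N) axis across the stage.
        input_strategy = (stage_device_num, 1, 1, 1)
        # Weight (C-out, C-in, k1, k2): replicated, no axis is split.
        weight_strategy = (1, 1, 1, 1)
        return (input_strategy, weight_strategy)

    # With an 8-device stage: ((8, 1, 1, 1), (1, 1, 1, 1)),
    # matching the {{stage_device_size_, 1, 1, 1}, {1, 1, 1, 1}} literal above.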


@@ -106,3 +106,26 @@ def test_sharding_propagation_1x1x1x8():
elif re.search("Conv2DTranspose", k) is not None:
assert v == [[1, 1, 1, 8], [1, 1, 1, 1]]
context.reset_auto_parallel_context()
def test_dynamic_programming_1x1x1x8():
"""
Features: test dynamic programming for conv2d/bn/maxpool/conv2d_transpose
Description: the fixed strategy is 1x1x1x8
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0,
search_mode="dynamic_programming")
strategy = ((1, 1, 1, 8),)
net = Net(_w1, _w2, out_channel=8, strategy=strategy)
strategies = compile_net(net)
for (k, v) in strategies.items():
if re.search("Conv2D", k) is not None:
assert v == [[8, 1, 1, 1], [1, 1, 1, 1]]
elif re.search("BatchNorm", k) is not None:
assert v == [[8, 1, 1, 1], [1], [1], [1], [1]]
elif re.search("MaxPool", k) is not None:
assert v == [[8, 1, 1, 1],]
elif re.search("Conv2DTranspose", k) is not None:
assert v == [[8, 1, 1, 1], [1, 1, 1, 1]]
context.reset_auto_parallel_context()
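
To make the expected Conv2D strategy concrete, here is a small NumPy sketch (illustrative only; the input and weight shapes are hypothetical) of what the [[8, 1, 1, 1], [1, 1, 1, 1]] assignment asserted above means: each of the 8 devices holds one batch slice of the input and a full copy of the weight.

    import numpy as np

    x = np.zeros((32, 16, 224, 224), dtype=np.float32)  # hypothetical input (N, C-in, H, W)
    w = np.zeros((8, 16, 3, 3), dtype=np.float32)       # hypothetical weight (C-out, C-in, k1, k2)

    # Strategy (8, 1, 1, 1): split only the batch axis 8 ways, one slice per device.
    slices = np.split(x, 8, axis=0)
    assert slices[0].shape == (4, 16, 224, 224)

    # Strategy (1, 1, 1, 1): the weight is not split; every device keeps the whole tensor.
    assert w.shape == (8, 16, 3, 3)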