From b63e70f6a992f8ffb3a98fba199fd2f19f062b52 Mon Sep 17 00:00:00 2001
From: yangzhenzhang
Date: Thu, 16 Jun 2022 14:54:18 +0800
Subject: [PATCH] gen data parallel strategy for conv2d auto parallel

---
 .../frontend/parallel/ops_info/conv2d_info.cc | 15 +++++++++++-
 .../test_sharding_propagation_for_conv2ds.py  | 23 +++++++++++++++++++
 2 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc
index f2bed223e34..f45a6917f9f 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc
@@ -27,6 +27,7 @@
 #include "frontend/parallel/strategy.h"
 #include "frontend/parallel/tensor_layout/tensor_redistribution.h"
 #include "frontend/parallel/graph_util/generate_graph.h"
+#include "include/common/utils/parallel_context.h"
 #include "pipeline/jit/resource.h"
 
 namespace mindspore {
@@ -928,6 +929,19 @@ void Conv2DInfo::ReComputeBatchSplitFlagList() {
 Status Conv2DInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
 
 std::vector<StrategyPtr> Conv2DInfo::GenerateOpStrategies(int64_t stage_id) {
+  std::vector<StrategyPtr> sp_vector;
+  auto parallel_context = ParallelContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(parallel_context);
+  auto parallel_mode = parallel_context->parallel_mode();
+  auto search_mode = parallel_context->strategy_search_mode();
+  // generate data parallel strategy when the search mode is not sharding propagation
+  if (parallel_mode == parallel::kAutoParallel && search_mode != parallel::kShardingPropagation) {
+    Strategys strategy = {{stage_device_size_, 1, 1, 1}, {1, 1, 1, 1}};
+    StrategyPtr data_parallel_sp = std::make_shared<Strategy>(stage_id, strategy);
+    sp_vector.push_back(data_parallel_sp);
+    return sp_vector;
+  }
+
   // to generate the strategy for (N, C1, H, W, C2), the k1/k2 can not be split
   Shapes splittable_input = {{1, 1, 1, 1, 1}};
   Shape tmp_shape = inputs_shape_[0];
@@ -937,7 +951,6 @@ std::vector<StrategyPtr> Conv2DInfo::GenerateOpStrategies(int64_t stage_id) {
     tmp_shape.push_back(inputs_shape_[1][1]);  // the tmp shape is (N, C-out, H, W, C-in)
   }
   Shapes tmp_inputs_shape = {tmp_shape};
-  std::vector<StrategyPtr> sp_vector;
   if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_input, &sp_vector) != SUCCESS) {
     MS_LOG(EXCEPTION) << name_ << ": Generate strategies failed";
   }
diff --git a/tests/ut/python/parallel/test_sharding_propagation_for_conv2ds.py b/tests/ut/python/parallel/test_sharding_propagation_for_conv2ds.py
index 5bf4bfbcd8f..6fcee716a60 100644
--- a/tests/ut/python/parallel/test_sharding_propagation_for_conv2ds.py
+++ b/tests/ut/python/parallel/test_sharding_propagation_for_conv2ds.py
@@ -106,3 +106,26 @@ def test_sharding_propagation_1x1x1x8():
         elif re.search("Conv2DTranspose", k) is not None:
             assert v == [[1, 1, 1, 8], [1, 1, 1, 1]]
     context.reset_auto_parallel_context()
+
+
+def test_dynamic_programming_1x1x1x8():
+    """
+    Features: test dynamic programming for conv2d/bn/maxpool/conv2d_transpose
+    Description: the fixed strategy is 1x1x1x8
+    Expectation: compile success
+    """
+    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0,
+                                      search_mode="dynamic_programming")
+    strategy = ((1, 1, 1, 8),)
+    net = Net(_w1, _w2, out_channel=8, strategy=strategy)
+    strategies = compile_net(net)
+    for (k, v) in strategies.items():
+        if re.search("Conv2D", k) is not None:
+            assert v == [[8, 1, 1, 1], [1, 1, 1, 1]]
+        elif re.search("BatchNorm", k) is not None:
+            assert v == [[8, 1, 1, 1], [1], [1], [1], [1]]
+        elif re.search("MaxPool", k) is not None:
+            assert v == [[8, 1, 1, 1]]
+        elif re.search("Conv2DTranspose", k) is not None:
+            assert v == [[8, 1, 1, 1], [1, 1, 1, 1]]
+    context.reset_auto_parallel_context()