!36056 generate data parallel strategy for conv2d when the auto parallel search mode is dynamic programming
Merge pull request !36056 from yangzhenzhang/gen-data-parallel-strategy-for-conv2d-auto-parallel
This commit is contained in:
commit
57b33197b2
|
@ -27,6 +27,7 @@
|
|||
#include "frontend/parallel/strategy.h"
|
||||
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
|
||||
#include "frontend/parallel/graph_util/generate_graph.h"
|
||||
#include "include/common/utils/parallel_context.h"
|
||||
#include "pipeline/jit/resource.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -928,6 +929,19 @@ void Conv2DInfo::ReComputeBatchSplitFlagList() {
|
|||
Status Conv2DInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
|
||||
|
||||
std::vector<StrategyPtr> Conv2DInfo::GenerateOpStrategies(int64_t stage_id) {
|
||||
std::vector<StrategyPtr> sp_vector;
|
||||
auto parallel_context = ParallelContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(parallel_context);
|
||||
auto parallel_mode = parallel_context->parallel_mode();
|
||||
auto search_mode = parallel_context->strategy_search_mode();
|
||||
// generate data parallel strategy when the search mode is not sharding propagation
|
||||
if (parallel_mode == parallel::kAutoParallel && search_mode != parallel::kShardingPropagation) {
|
||||
Strategys strategy = {{stage_device_size_, 1, 1, 1}, {1, 1, 1, 1}};
|
||||
StrategyPtr data_parallel_sp = std::make_shared<Strategy>(stage_id, strategy);
|
||||
sp_vector.push_back(data_parallel_sp);
|
||||
return sp_vector;
|
||||
}
|
||||
|
||||
// to generate the strategy for (N, C1, H, W, C2), the k1/k2 can not be split
|
||||
Shapes splittable_input = {{1, 1, 1, 1, 1}};
|
||||
Shape tmp_shape = inputs_shape_[0];
|
||||
|
@ -937,7 +951,6 @@ std::vector<StrategyPtr> Conv2DInfo::GenerateOpStrategies(int64_t stage_id) {
|
|||
tmp_shape.push_back(inputs_shape_[1][1]); // the tmp shape is (N, C-out, H, W, C-in)
|
||||
}
|
||||
Shapes tmp_inputs_shape = {tmp_shape};
|
||||
std::vector<StrategyPtr> sp_vector;
|
||||
if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_input, &sp_vector) != SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": Generate strategies failed";
|
||||
}
|
||||
|
|
|
@ -106,3 +106,26 @@ def test_sharding_propagation_1x1x1x8():
|
|||
elif re.search("Conv2DTranspose", k) is not None:
|
||||
assert v == [[1, 1, 1, 8], [1, 1, 1, 1]]
|
||||
context.reset_auto_parallel_context()
|
||||
|
||||
|
||||
def test_dynamic_programming_1x1x1x8():
|
||||
"""
|
||||
Features: test dynamic programming for conv2d/bn/maxpool/conv2d_transpose
|
||||
Description: the fixed strategy is 1x1x1x8
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0,
|
||||
search_mode="dynamic_programming")
|
||||
strategy = ((1, 1, 1, 8),)
|
||||
net = Net(_w1, _w2, out_channel=8, strategy=strategy)
|
||||
strategies = compile_net(net)
|
||||
for (k, v) in strategies.items():
|
||||
if re.search("Conv2D", k) is not None:
|
||||
assert v == [[8, 1, 1, 1], [1, 1, 1, 1]]
|
||||
elif re.search("BatchNorm", k) is not None:
|
||||
assert v == [[8, 1, 1, 1], [1], [1], [1], [1]]
|
||||
elif re.search("MaxPool", k) is not None:
|
||||
assert v == [[8, 1, 1, 1],]
|
||||
elif re.search("Conv2DTranspose", k) is not None:
|
||||
assert v == [[8, 1, 1, 1], [1, 1, 1, 1]]
|
||||
context.reset_auto_parallel_context()
|
||||
|
|
Loading…
Reference in New Issue