support group for conv2d

This commit is contained in:
yangzhenzhang 2022-01-18 11:39:50 +08:00
parent 76af1f861f
commit 5514189257
3 changed files with 40 additions and 9 deletions

View File

@ -273,19 +273,21 @@ Status Conv2DInfo::CheckStrategyBase(const StrategyPtr &strategy) {
int64_t input_except_n_shards =
std::accumulate(input_strategy.begin() + 1, input_strategy.end(), 1, std::multiplies<int64_t>());
int64_t weight_shards =
std::accumulate(weight_strategy.begin() + 1, weight_strategy.end(), 1, std::multiplies<int64_t>());
std::accumulate(weight_strategy.begin(), weight_strategy.end(), 1, std::multiplies<int64_t>());
bool is_data_parallel = (input_except_n_shards * weight_shards == 1);
if (!is_data_parallel) {
if (std::any_of(dilation_.begin(), dilation_.end(), [](int64_t value) { return value != 1; })) {
MS_LOG(ERROR) << name_ << ": If it is not data parallel, the value of dilation must be 1, but got " << dilation_;
MS_LOG(ERROR) << name_ << ": It is not data parallel, the value of dilation must be 1, but got " << dilation_;
return FAILED;
}
}
if (group_ != 1) {
MS_LOG(ERROR) << name_ << ": If it is not data parallel, the group must be 1, but got " << group_;
return FAILED;
}
if (group_ != 1 && (weight_strategy[0] != 1 || weight_strategy[1] != 1)) {
MS_LOG(ERROR) << name_ << ": The group is " << group_
<< ", the cout and cin can not be split, but the shard num of cout is " << weight_strategy[0]
<< ", the shard num of cin is " << weight_strategy[1];
return FAILED;
}
return SUCCESS;
}

View File

@ -143,7 +143,7 @@ def test_conv2d_model_parallel_dilation():
def test_conv2d_model_parallel_group():
"""
Feature: test conv2d model parallel and group is not 1
Description: model parallel and group is not 1
Description: split cin and cout, and group is not 1
Expectation: compile failed
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
@ -155,6 +155,20 @@ def test_conv2d_model_parallel_group():
compile_net(net)
def test_conv2d_model_parallel_group2():
    """
    Feature: test conv2d model parallel when group is not 1
    Description: cin and cout are not split while group is not 1
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    # Conv2D input is sharded only on N/H/W; weight (cout, cin) stays unsplit,
    # which is the supported configuration when group != 1.
    conv_strategy = ((2, 1, 2, 2), (1, 1, 1, 1))
    neg_strategy = ((8, 1, 1, 1),)
    net = Net(_w4, out_channel=8, kernel_size=2, pad_mode="same", stride=1, group=2,
              strategy1=conv_strategy, strategy2=neg_strategy)
    compile_net(net)
def test_conv2d_model_parallel2():
"""
Feature: same mode, stride = kernel_size, no need exchange

View File

@ -41,11 +41,11 @@ class Net(Cell):
class Net2(Cell):
def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride,
def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, group=1,
strategy1=None, strategy2=None):
super().__init__()
self.conv2d_transpose = P.Conv2DTranspose(out_channel=out_channel, kernel_size=kernel_size,
pad_mode=pad_mode, stride=stride).shard(strategy1)
pad_mode=pad_mode, stride=stride, group=group).shard(strategy1)
self.neg = P.Neg().shard(strategy2)
self.weight = Parameter(conv2d_weight, "w1")
@ -60,6 +60,7 @@ _w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
_w2 = Tensor(np.ones([8, 16, 4, 4]), dtype=ms.float32)
_w3 = Tensor(np.ones([8, 16, 10, 10]), dtype=ms.float32)
_w4 = Tensor(np.ones([8, 16, 3, 3]), dtype=ms.float32)
_w5 = Tensor(np.ones([8, 8, 4, 4]), dtype=ms.float32)
_b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
@ -85,6 +86,20 @@ def test_conv2d_transpose_data_parallel():
compile_net(net)
def test_conv2d_transpose_group():
    """
    Feature: test conv2d transpose when group is not 1
    Description: shard n/h/w only, with group set to 2
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    # Only N/H/W of the input are sharded; the weight is fully replicated,
    # so the grouped transpose convolution remains valid under sharding.
    transpose_strategy = ((2, 1, 2, 2), (1, 1, 1, 1))
    neg_strategy = ((8, 1, 1, 1),)
    net = Net2(_w5, out_channel=8, kernel_size=4, pad_mode="same", stride=2, group=2,
               strategy1=transpose_strategy, strategy2=neg_strategy)
    compile_net(net)
def test_conv2d_transpose_model_parallel1():
"""
Feature: test model parallel strategy