forked from mindspore-Ecosystem/mindspore
support group for conv2d
This commit is contained in:
parent
76af1f861f
commit
5514189257
|
@ -273,19 +273,21 @@ Status Conv2DInfo::CheckStrategyBase(const StrategyPtr &strategy) {
|
|||
int64_t input_except_n_shards =
|
||||
std::accumulate(input_strategy.begin() + 1, input_strategy.end(), 1, std::multiplies<int64_t>());
|
||||
int64_t weight_shards =
|
||||
std::accumulate(weight_strategy.begin() + 1, weight_strategy.end(), 1, std::multiplies<int64_t>());
|
||||
std::accumulate(weight_strategy.begin(), weight_strategy.end(), 1, std::multiplies<int64_t>());
|
||||
|
||||
bool is_data_parallel = (input_except_n_shards * weight_shards == 1);
|
||||
if (!is_data_parallel) {
|
||||
if (std::any_of(dilation_.begin(), dilation_.end(), [](int64_t value) { return value != 1; })) {
|
||||
MS_LOG(ERROR) << name_ << ": If it is not data parallel, the value of dilation must be 1, but got " << dilation_;
|
||||
MS_LOG(ERROR) << name_ << ": It is not data parallel, the value of dilation must be 1, but got " << dilation_;
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
if (group_ != 1) {
|
||||
MS_LOG(ERROR) << name_ << ": If it is not data parallel, the group must be 1, but got " << group_;
|
||||
return FAILED;
|
||||
}
|
||||
if (group_ != 1 && (weight_strategy[0] != 1 || weight_strategy[1] != 1)) {
|
||||
MS_LOG(ERROR) << name_ << ": The group is " << group_
|
||||
<< ", the cout and cin can not be split, but the shard num of cout is " << weight_strategy[0]
|
||||
<< ", the shard num of cin is " << weight_strategy[1];
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
|
|
@ -143,7 +143,7 @@ def test_conv2d_model_parallel_dilation():
|
|||
def test_conv2d_model_parallel_group():
|
||||
"""
|
||||
Feature: test conv2d model parallel and group is not 1
|
||||
Description: model parallel and group is not 1
|
||||
Description: split cin and cout, and group is not 1
|
||||
Expectation: compile failed
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
|
@ -155,6 +155,20 @@ def test_conv2d_model_parallel_group():
|
|||
compile_net(net)
|
||||
|
||||
|
||||
def test_conv2d_model_parallel_group2():
|
||||
"""
|
||||
Feature: test conv2d model parallel and group is not 1
|
||||
Description: has not to split cin and cout, and group is not 1
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 1, 2, 2), (1, 1, 1, 1))
|
||||
strategy2 = ((8, 1, 1, 1),)
|
||||
net = Net(_w4, out_channel=8, kernel_size=2, pad_mode="same", stride=1, group=2,
|
||||
strategy1=strategy1, strategy2=strategy2)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_conv2d_model_parallel2():
|
||||
"""
|
||||
Feature: same mode, stride = kernel_size, no need exchange
|
||||
|
|
|
@ -41,11 +41,11 @@ class Net(Cell):
|
|||
|
||||
|
||||
class Net2(Cell):
|
||||
def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride,
|
||||
def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, group=1,
|
||||
strategy1=None, strategy2=None):
|
||||
super().__init__()
|
||||
self.conv2d_transpose = P.Conv2DTranspose(out_channel=out_channel, kernel_size=kernel_size,
|
||||
pad_mode=pad_mode, stride=stride).shard(strategy1)
|
||||
pad_mode=pad_mode, stride=stride, group=group).shard(strategy1)
|
||||
self.neg = P.Neg().shard(strategy2)
|
||||
self.weight = Parameter(conv2d_weight, "w1")
|
||||
|
||||
|
@ -60,6 +60,7 @@ _w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
|
|||
_w2 = Tensor(np.ones([8, 16, 4, 4]), dtype=ms.float32)
|
||||
_w3 = Tensor(np.ones([8, 16, 10, 10]), dtype=ms.float32)
|
||||
_w4 = Tensor(np.ones([8, 16, 3, 3]), dtype=ms.float32)
|
||||
_w5 = Tensor(np.ones([8, 8, 4, 4]), dtype=ms.float32)
|
||||
_b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
|
||||
|
||||
|
||||
|
@ -85,6 +86,20 @@ def test_conv2d_transpose_data_parallel():
|
|||
compile_net(net)
|
||||
|
||||
|
||||
def test_conv2d_transpose_group():
|
||||
"""
|
||||
Feature: test group is not 1
|
||||
Description: shard n/h/w, and group is 2
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 1, 2, 2), (1, 1, 1, 1))
|
||||
strategy2 = ((8, 1, 1, 1),)
|
||||
net = Net2(_w5, out_channel=8, kernel_size=4, pad_mode="same", stride=2, group=2, strategy1=strategy1,
|
||||
strategy2=strategy2)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_conv2d_transpose_model_parallel1():
|
||||
"""
|
||||
Feature: test model parallel strategy
|
||||
|
|
Loading…
Reference in New Issue