!22108 modify check strategy for conv2d

Merge pull request !22108 from yangzhenzhang/modify-check-strategy-for-conv2d
i-robot 2021-08-24 12:01:23 +00:00 committed by Gitee
commit 1e47ff7bc3
3 changed files with 85 additions and 14 deletions

View File

@@ -103,10 +103,6 @@ Status Conv2DInfo::GetAttrsBase() {
   // group
   group_ = GetIntAttr(GROUP);
-  if (group_ != 1) {
-    MS_LOG(ERROR) << name_ << ": The group must be 1, but got " << group_;
-    return FAILED;
-  }
 
   // format
   format_ = GetStringAttr(FORMAT);
@@ -176,10 +172,10 @@ Status Conv2DInfo::CheckHWStrategySameMode(int64_t h_strategy, int64_t w_strateg
     return FAILED;
   }
 
-  if (w_slice_shape < ((kernel_size_[1] - stride_[3] + 1) / 2)) {
+  if (w_slice_shape <= ((kernel_size_[1] - stride_[3] + 1) / 2)) {
     MS_LOG(ERROR) << name_
                   << ": The 'same' mode do not support to split W when kernel_size > stride but w slice shape is "
-                     "smaller than (k - s + 1) / 2";
+                     "smaller than or equal to (k - s + 1) / 2";
     return FAILED;
   }
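For a concrete sense of the tightened bound (a worked example, not part of the commit): in 'same' mode with kernel_size k = 3 and stride s = 1, (k - s + 1) / 2 evaluates to 1 under integer division, so a strategy that leaves a W slice of width 1 used to slip past the old strict '<' check and is now rejected. The new test test_conv2d_same_mode_overlap_size_equal_to_slice_shape further down exercises exactly this case.

# Illustration only: the boundary case that '<=' now catches.
k, s = 3, 1                  # kernel size and stride along W, 'same' mode
w_slice_shape = 8 // 8       # input W = 8 split across 8 devices -> slice width 1
assert w_slice_shape <= (k - s + 1) // 2   # 1 <= 1, so the strategy is rejected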
@@ -275,6 +271,23 @@ Status Conv2DInfo::CheckStrategyBase(const StrategyPtr &strategy) {
     new_out_channel_ = out_channel_;
   }
 
+  int64_t input_except_n_shards =
+      std::accumulate(input_strategy.begin() + 1, input_strategy.end(), 1, std::multiplies<int64_t>());
+  int64_t weight_shards =
+      std::accumulate(weight_strategy.begin() + 1, weight_strategy.end(), 1, std::multiplies<int64_t>());
+  bool is_data_parallel = (input_except_n_shards * weight_shards == 1);
+  if (!is_data_parallel) {
+    if (std::any_of(dilation_.begin(), dilation_.end(), [](int64_t value) { return value != 1; })) {
+      MS_LOG(ERROR) << name_ << ": If it is not data parallel, the value of dilation must be 1, but got " << dilation_;
+      return FAILED;
+    }
+
+    if (group_ != 1) {
+      MS_LOG(ERROR) << name_ << ": If it is not data parallel, the group must be 1, but got " << group_;
+      return FAILED;
+    }
+  }
+
   return SUCCESS;
 }
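As a reading aid, here is a minimal Python paraphrase of the check added above (an illustration with invented names, not code from the commit): a strategy counts as data parallel when every shard count after the leading dimension of both the input (its batch dimension N) and the weight multiplies out to 1; any other strategy must keep dilation and group at their default of 1.

from functools import reduce
from operator import mul

def dilation_group_allowed(input_strategy, weight_strategy, dilation, group):
    # Mirrors the restriction added to CheckStrategyBase (illustration only).
    input_except_n_shards = reduce(mul, input_strategy[1:], 1)  # shard counts of the input after N
    weight_shards = reduce(mul, weight_strategy[1:], 1)         # shard counts of the weight after its first dim
    is_data_parallel = input_except_n_shards * weight_shards == 1
    if is_data_parallel:
        return True
    # Not data parallel: dilation and group must both stay at 1.
    return all(d == 1 for d in dilation) and group == 1

# The new tests below cover both sides of this rule:
assert dilation_group_allowed((8, 1, 1, 1), (1, 1, 1, 1), dilation=(2, 2), group=1)      # data parallel: allowed
assert not dilation_group_allowed((2, 2, 1, 1), (2, 2, 1, 1), dilation=(2, 2), group=1)  # model parallel: rejected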
@@ -536,17 +549,17 @@ void Conv2DInfo::InferSendRecvFlag() {
                 << right_need_recv_;
 
   if (left_need_send_) {
-    if (left_rank_overlap_right_size_ > input_slice_shape_[3]) {
+    if (left_rank_overlap_right_size_ >= input_slice_shape_[3]) {
       MS_LOG(EXCEPTION) << name_ << ": Do not support left overlap size(" << left_rank_overlap_right_size_
-                        << ") larger than slice shape in w dimension(" << input_slice_shape_[3] << ")";
+                        << ") larger than or equal to slice shape in w dimension(" << input_slice_shape_[3] << ")";
     }
     send_rank_ids_.push_back(left_rank_id_);
   }
 
   if (right_need_send_) {
-    if (right_rank_overlap_left_size_ > input_slice_shape_[3]) {
+    if (right_rank_overlap_left_size_ >= input_slice_shape_[3]) {
       MS_LOG(EXCEPTION) << name_ << ": Do not support left overlap size(" << right_rank_overlap_left_size_
-                        << ") larger than slice shape in w dimension(" << input_slice_shape_[3] << ")";
+                        << ") larger than or equal to slice shape in w dimension(" << input_slice_shape_[3] << ")";
     }
     send_rank_ids_.push_back(right_rank_id_);
   }
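The halo exchange set up in InferSendRecvFlag is tightened the same way: the overlap a neighbouring rank needs from this rank is now rejected not only when it exceeds the local W slice but also when it covers the whole slice. A minimal sketch of the new guard with the same boundary case as above (illustration only, names are invented):

def overlap_sendable(overlap_needed_by_neighbor, w_slice_shape):
    # New rule: the halo sent to a neighbour must be strictly smaller than the local W slice.
    return overlap_needed_by_neighbor < w_slice_shape

# 'same'-mode conv with kernel 3 / stride 1, input W = 8 split across 8 devices: each rank
# owns a single column and a neighbour needs one column of halo, so compilation now fails.
assert not overlap_sendable(1, 8 // 8)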
@@ -862,8 +875,8 @@ Status Conv2DBackpropInputInfo::GetOutShape() {
   for (auto &element : elements) {
     MS_EXCEPTION_IF_NULL(element);
     if (element->isa<Int64Imm>()) {
-      int64_t axis = element->cast<Int64ImmPtr>()->value();
-      out_shape_.push_back(axis);
+      int64_t ele_value = element->cast<Int64ImmPtr>()->value();
+      out_shape_.push_back(ele_value);
     } else {
       MS_LOG(ERROR) << name_ << ": The value of shape must be int";
       return FAILED;

View File

@@ -23,11 +23,11 @@ from mindspore.ops import operations as P
 
 
 class Net(Cell):
-    def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride,
+    def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, dilation=1, group=1,
                  strategy1=None, strategy2=None):
         super().__init__()
         self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
-                               pad_mode=pad_mode, stride=stride).shard(strategy1)
+                               pad_mode=pad_mode, stride=stride, dilation=dilation, group=group).shard(strategy1)
         self.neg = P.Neg().shard(strategy2)
         self.conv2d_weight = Parameter(conv2d_weight, "w1")
@@ -43,6 +43,7 @@ _w0 = Tensor(np.ones([8, 16, 1, 1]), dtype=ms.float32)
 _w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
 _w2 = Tensor(np.ones([8, 16, 3, 3]), dtype=ms.float32)
 _w3 = Tensor(np.ones([8, 16, 5, 5]), dtype=ms.float32)
+_w4 = Tensor(np.ones([8, 8, 2, 2]), dtype=ms.float32)
 _b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
@@ -63,6 +64,24 @@ def test_conv2d_data_parallel():
     compile_net(net)
 
 
+def test_conv2d_data_parallel_dilation():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
+    strategy2 = ((8, 1, 1, 1),)
+    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, dilation=2,
+              strategy1=strategy1, strategy2=strategy2)
+    compile_net(net)
+
+
+def test_conv2d_data_parallel_group():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
+    strategy2 = ((8, 1, 1, 1),)
+    net = Net(_w4, out_channel=8, kernel_size=2, pad_mode="same", stride=1, group=2,
+              strategy1=strategy1, strategy2=strategy2)
+    compile_net(net)
+
+
 def test_conv2d_model_parallel1():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
@@ -71,6 +90,26 @@ def test_conv2d_model_parallel1():
     compile_net(net)
 
 
+def test_conv2d_model_parallel_dilation():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
+    strategy2 = ((8, 1, 1, 1),)
+    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, dilation=2,
+              strategy1=strategy1, strategy2=strategy2)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
+def test_conv2d_model_parallel_group():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
+    strategy2 = ((8, 1, 1, 1),)
+    net = Net(_w4, out_channel=8, kernel_size=2, pad_mode="same", stride=1, group=2,
+              strategy1=strategy1, strategy2=strategy2)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
 def test_conv2d_model_parallel2():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
     strategy1 = ((2, 2, 2, 2), (2, 2, 1, 1))
@@ -182,6 +221,15 @@ def test_kernel_size_larger_than_stride_and_slice_too_small():
         compile_net(net)
 
 
+def test_conv2d_same_mode_overlap_size_equal_to_slice_shape():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
+    strategy2 = ((2, 1, 1, 4),)
+    net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
 def test_kernel_size_larger_than_stride_and_left_pad_is_0():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy1 = ((1, 1, 1, 4), (1, 1, 1, 1))

View File

@@ -122,6 +122,16 @@ def test_conv2d_transpose_overlap_size_too_large():
         compile_net(net)
 
 
+def test_conv2d_transpose_overlap_size_too_large2():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
+    strategy2 = ((2, 2, 1, 4),)
+    net = Net2(_w2, out_channel=8, kernel_size=(4, 4), pad_mode="same", stride=2,
+               strategy1=strategy1, strategy2=strategy2)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
 def test_conv2d_transpose_rank0_no_need_overlap():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
     strategy1 = ((2, 2, 1, 4), (2, 1, 1, 1))