forked from mindspore-Ecosystem/mindspore
!22108 modify check strategy for conv2d
Merge pull request !22108 from yangzhenzhang/modify-check-strategy-for-conv2d
This commit is contained in:
commit
1e47ff7bc3
|
@ -103,10 +103,6 @@ Status Conv2DInfo::GetAttrsBase() {
|
|||
|
||||
// group
|
||||
group_ = GetIntAttr(GROUP);
|
||||
if (group_ != 1) {
|
||||
MS_LOG(ERROR) << name_ << ": The group must be 1, but got " << group_;
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
// format
|
||||
format_ = GetStringAttr(FORMAT);
|
||||
|
@ -176,10 +172,10 @@ Status Conv2DInfo::CheckHWStrategySameMode(int64_t h_strategy, int64_t w_strateg
|
|||
return FAILED;
|
||||
}
|
||||
|
||||
if (w_slice_shape < ((kernel_size_[1] - stride_[3] + 1) / 2)) {
|
||||
if (w_slice_shape <= ((kernel_size_[1] - stride_[3] + 1) / 2)) {
|
||||
MS_LOG(ERROR) << name_
|
||||
<< ": The 'same' mode do not support to split W when kernel_size > stride but w slice shape is "
|
||||
"smaller than (k - s + 1) / 2";
|
||||
"smaller than or equal to (k - s + 1) / 2";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
|
@ -275,6 +271,23 @@ Status Conv2DInfo::CheckStrategyBase(const StrategyPtr &strategy) {
|
|||
new_out_channel_ = out_channel_;
|
||||
}
|
||||
|
||||
int64_t input_except_n_shards =
|
||||
std::accumulate(input_strategy.begin() + 1, input_strategy.end(), 1, std::multiplies<int64_t>());
|
||||
int64_t weight_shards =
|
||||
std::accumulate(weight_strategy.begin() + 1, weight_strategy.end(), 1, std::multiplies<int64_t>());
|
||||
|
||||
bool is_data_parallel = (input_except_n_shards * weight_shards == 1);
|
||||
if (!is_data_parallel) {
|
||||
if (std::any_of(dilation_.begin(), dilation_.end(), [](int64_t value) { return value != 1; })) {
|
||||
MS_LOG(ERROR) << name_ << ": If it is not data parallel, the value of dilation must be 1, but got " << dilation_;
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (group_ != 1) {
|
||||
MS_LOG(ERROR) << name_ << ": If it is not data parallel, the group must be 1, but got " << group_;
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
|
@ -536,17 +549,17 @@ void Conv2DInfo::InferSendRecvFlag() {
|
|||
<< right_need_recv_;
|
||||
|
||||
if (left_need_send_) {
|
||||
if (left_rank_overlap_right_size_ > input_slice_shape_[3]) {
|
||||
if (left_rank_overlap_right_size_ >= input_slice_shape_[3]) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": Do not support left overlap size(" << left_rank_overlap_right_size_
|
||||
<< ") larger than slice shape in w dimension(" << input_slice_shape_[3] << ")";
|
||||
<< ") larger than or equal to slice shape in w dimension(" << input_slice_shape_[3] << ")";
|
||||
}
|
||||
send_rank_ids_.push_back(left_rank_id_);
|
||||
}
|
||||
|
||||
if (right_need_send_) {
|
||||
if (right_rank_overlap_left_size_ > input_slice_shape_[3]) {
|
||||
if (right_rank_overlap_left_size_ >= input_slice_shape_[3]) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": Do not support left overlap size(" << right_rank_overlap_left_size_
|
||||
<< ") larger than slice shape in w dimension(" << input_slice_shape_[3] << ")";
|
||||
<< ") larger than or equal to slice shape in w dimension(" << input_slice_shape_[3] << ")";
|
||||
}
|
||||
send_rank_ids_.push_back(right_rank_id_);
|
||||
}
|
||||
|
@ -862,8 +875,8 @@ Status Conv2DBackpropInputInfo::GetOutShape() {
|
|||
for (auto &element : elements) {
|
||||
MS_EXCEPTION_IF_NULL(element);
|
||||
if (element->isa<Int64Imm>()) {
|
||||
int64_t axis = element->cast<Int64ImmPtr>()->value();
|
||||
out_shape_.push_back(axis);
|
||||
int64_t ele_value = element->cast<Int64ImmPtr>()->value();
|
||||
out_shape_.push_back(ele_value);
|
||||
} else {
|
||||
MS_LOG(ERROR) << name_ << ": The value of shape must be int";
|
||||
return FAILED;
|
||||
|
|
|
@ -23,11 +23,11 @@ from mindspore.ops import operations as P
|
|||
|
||||
|
||||
class Net(Cell):
|
||||
def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride,
|
||||
def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, dilation=1, group=1,
|
||||
strategy1=None, strategy2=None):
|
||||
super().__init__()
|
||||
self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
|
||||
pad_mode=pad_mode, stride=stride).shard(strategy1)
|
||||
pad_mode=pad_mode, stride=stride, dilation=dilation, group=group).shard(strategy1)
|
||||
self.neg = P.Neg().shard(strategy2)
|
||||
self.conv2d_weight = Parameter(conv2d_weight, "w1")
|
||||
|
||||
|
@ -43,6 +43,7 @@ _w0 = Tensor(np.ones([8, 16, 1, 1]), dtype=ms.float32)
|
|||
_w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
|
||||
_w2 = Tensor(np.ones([8, 16, 3, 3]), dtype=ms.float32)
|
||||
_w3 = Tensor(np.ones([8, 16, 5, 5]), dtype=ms.float32)
|
||||
_w4 = Tensor(np.ones([8, 8, 2, 2]), dtype=ms.float32)
|
||||
_b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
|
||||
|
||||
|
||||
|
@ -63,6 +64,24 @@ def test_conv2d_data_parallel():
|
|||
compile_net(net)
|
||||
|
||||
|
||||
def test_conv2d_data_parallel_dilation():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
|
||||
strategy2 = ((8, 1, 1, 1),)
|
||||
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, dilation=2,
|
||||
strategy1=strategy1, strategy2=strategy2)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_conv2d_data_parallel_group():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
|
||||
strategy2 = ((8, 1, 1, 1),)
|
||||
net = Net(_w4, out_channel=8, kernel_size=2, pad_mode="same", stride=1, group=2,
|
||||
strategy1=strategy1, strategy2=strategy2)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_conv2d_model_parallel1():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
|
||||
|
@ -71,6 +90,26 @@ def test_conv2d_model_parallel1():
|
|||
compile_net(net)
|
||||
|
||||
|
||||
def test_conv2d_model_parallel_dilation():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
|
||||
strategy2 = ((8, 1, 1, 1),)
|
||||
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, dilation=2,
|
||||
strategy1=strategy1, strategy2=strategy2)
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_conv2d_model_parallel_group():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
|
||||
strategy2 = ((8, 1, 1, 1),)
|
||||
net = Net(_w4, out_channel=8, kernel_size=2, pad_mode="same", stride=1, group=2,
|
||||
strategy1=strategy1, strategy2=strategy2)
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_conv2d_model_parallel2():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
|
||||
strategy1 = ((2, 2, 2, 2), (2, 2, 1, 1))
|
||||
|
@ -182,6 +221,15 @@ def test_kernel_size_larger_than_stride_and_slice_too_small():
|
|||
compile_net(net)
|
||||
|
||||
|
||||
def test_conv2d_same_mode_overlap_size_equal_to_slice_shape():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
|
||||
strategy2 = ((2, 1, 1, 4),)
|
||||
net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_kernel_size_larger_than_stride_and_left_pad_is_0():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 1, 1, 4), (1, 1, 1, 1))
|
||||
|
|
|
@ -122,6 +122,16 @@ def test_conv2d_transpose_overlap_size_too_large():
|
|||
compile_net(net)
|
||||
|
||||
|
||||
def test_conv2d_transpose_overlap_size_too_large2():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
||||
strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
|
||||
strategy2 = ((2, 2, 1, 4),)
|
||||
net = Net2(_w2, out_channel=8, kernel_size=(4, 4), pad_mode="same", stride=2,
|
||||
strategy1=strategy1, strategy2=strategy2)
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_conv2d_transpose_rank0_no_need_overlap():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
||||
strategy1 = ((2, 2, 1, 4), (2, 1, 1, 1))
|
||||
|
|
Loading…
Reference in New Issue