!44962 fix bug for conv2d check strategy

Merge pull request !44962 from yangzhenzhang/fix-bug-for-conv2d-check-strategy
This commit is contained in:
i-robot 2022-11-02 01:50:03 +00:00 committed by Gitee
commit 90f7506eb4
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 28 additions and 0 deletions

View File

@ -1154,12 +1154,25 @@ Status Conv2DBackpropInputInfo::CheckStrategy(const StrategyPtr &strategy) {
std::vector<Dimensions> stra = strategy->GetInputDim();
Dimensions input_strategy = stra[0];
Dimensions weight_strategy = stra[1];
if (input_strategy.size() != 4 || weight_strategy.size() != 4) {
MS_LOG(ERROR) << name_
<< ": The size of input strategy or weight strategy must be 4, but the size of input strategy is "
<< input_strategy.size() << ", the size of weight strategy is " << weight_strategy.size();
return FAILED;
}
if (input_strategy[1] != weight_strategy[0]) {
MS_LOG(ERROR) << name_ << ": The shard num of c-out for input strategy is " << input_strategy[1]
<< ", but the shard num of c-out for weight strategy is " << weight_strategy[0];
return FAILED;
}
if (weight_strategy[2] != 1 || weight_strategy[3] != 1) {
MS_LOG(ERROR) << name_ << ": The kernel size can not be split, but the strategy for kernel size is ("
<< weight_strategy[2] << ", " << weight_strategy[3] << ")";
return FAILED;
}
if (input_strategy[2] != 1 || input_strategy[3] != 1) {
if (CheckHWStrategy(input_strategy[2], input_strategy[3]) != SUCCESS) {
return FAILED;

View File

@ -233,6 +233,21 @@ def test_conv2d_transpose_overlap_size_too_large():
compile_net(net)
def test_conv2d_transpose_split_kernel():
"""
Feature: the kernel size can not be split
Description: split the kernel size
Expectation: compile failed
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 1, 1, 1), (1, 1, 2, 2))
strategy2 = ((8, 1, 1, 1),)
net = Net2(_w3, out_channel=8, kernel_size=(10, 10), pad_mode="same", stride=2,
strategy1=strategy1, strategy2=strategy2)
with pytest.raises(RuntimeError):
compile_net(net)
def test_conv2d_transpose_pad_mode_no_need_exchange():
"""
Feature: pad mode, and two direction send, w = 8, o = 16, s = 2, k = 1, n = 8, pad = (0, 0, 0, 0)