!44962 fix bug for conv2d check strategy
Merge pull request !44962 from yangzhenzhang/fix-bug-for-conv2d-check-strategy
This commit is contained in:
commit
90f7506eb4
|
@ -1154,12 +1154,25 @@ Status Conv2DBackpropInputInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|||
std::vector<Dimensions> stra = strategy->GetInputDim();
|
||||
Dimensions input_strategy = stra[0];
|
||||
Dimensions weight_strategy = stra[1];
|
||||
if (input_strategy.size() != 4 || weight_strategy.size() != 4) {
|
||||
MS_LOG(ERROR) << name_
|
||||
<< ": The size of input strategy or weight strategy must be 4, but the size of input strategy is "
|
||||
<< input_strategy.size() << ", the size of weight strategy is " << weight_strategy.size();
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (input_strategy[1] != weight_strategy[0]) {
|
||||
MS_LOG(ERROR) << name_ << ": The shard num of c-out for input strategy is " << input_strategy[1]
|
||||
<< ", but the shard num of c-out for weight strategy is " << weight_strategy[0];
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (weight_strategy[2] != 1 || weight_strategy[3] != 1) {
|
||||
MS_LOG(ERROR) << name_ << ": The kernel size can not be split, but the strategy for kernel size is ("
|
||||
<< weight_strategy[2] << ", " << weight_strategy[3] << ")";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (input_strategy[2] != 1 || input_strategy[3] != 1) {
|
||||
if (CheckHWStrategy(input_strategy[2], input_strategy[3]) != SUCCESS) {
|
||||
return FAILED;
|
||||
|
|
|
@ -233,6 +233,21 @@ def test_conv2d_transpose_overlap_size_too_large():
|
|||
compile_net(net)
|
||||
|
||||
|
||||
def test_conv2d_transpose_split_kernel():
|
||||
"""
|
||||
Feature: the kernel size can not be split
|
||||
Description: split the kernel size
|
||||
Expectation: compile failed
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 1, 1, 1), (1, 1, 2, 2))
|
||||
strategy2 = ((8, 1, 1, 1),)
|
||||
net = Net2(_w3, out_channel=8, kernel_size=(10, 10), pad_mode="same", stride=2,
|
||||
strategy1=strategy1, strategy2=strategy2)
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_conv2d_transpose_pad_mode_no_need_exchange():
|
||||
"""
|
||||
Feature: pad mode, and two direction send, w = 8, o = 16, s = 2, k = 1, n = 8, pad = (0, 0, 0, 0)
|
||||
|
|
Loading…
Reference in New Issue