forked from mindspore-Ecosystem/mindspore
check platform for resizebilinear
This commit is contained in:
parent
5deccfe64b
commit
43e6e16da3
@@ -69,7 +69,14 @@ Status ResizeBilinearInfo::CheckStrategy(const StrategyPtr &strategy) {
    return FAILED;
  }

  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (input_strategy[3] != 1) {
    if (backend == kGPUDevice) {
      MS_LOG(ERROR) << name_ << ": Do not support split W dimension in GPU platform";
      return FAILED;
    }
    need_exchange_overlap_ = true;
    MS_LOG(INFO) << name_ << ": Split the w dimension";
  }

@@ -253,3 +253,19 @@ def test_bilinear_shard_n_c_w():
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
              strategy1=strategy1, strategy2=strategy2)
    compile_net(net)


def test_resizebilinear_shard_W_in_GPU():
    """
    Feature: test ResizeBilinear
    Description: the platform is GPU, and shard n/c/w
    Expectation: compile failed, can not shard h or w dimension in GPU
    """
    context.set_context(device_target="GPU")
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=3)
    strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
    strategy2 = ((2, 2, 1, 2),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
              strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)
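
For comparison, below is a hypothetical companion test that is not part of this commit: a minimal sketch of the case the new check still allows on GPU, where the ResizeBilinear strategy leaves the h and w dimensions unsplit. It assumes the same Net, _w1 and compile_net helpers used by the tests above, and the strategy values are illustrative only.

def test_resizebilinear_shard_n_c_in_GPU():
    """
    Feature: test ResizeBilinear (hypothetical sketch, not in this commit)
    Description: the platform is GPU, and shard only n/c
    Expectation: compile success, since neither h nor w is split
    """
    context.set_context(device_target="GPU")
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=3)
    strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
    # keep the last (w) dimension of the ResizeBilinear strategy at 1 so the GPU check passes
    strategy2 = ((2, 4, 1, 1),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
              strategy1=strategy1, strategy2=strategy2)
    compile_net(net)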