check platform for resizebilinear

This commit is contained in:
yangzhenzhang 2022-02-25 15:50:26 +08:00
parent 5deccfe64b
commit 43e6e16da3
2 changed files with 23 additions and 0 deletions

View File

@ -69,7 +69,14 @@ Status ResizeBilinearInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED; return FAILED;
} }
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (input_strategy[3] != 1) { if (input_strategy[3] != 1) {
if (backend == kGPUDevice) {
MS_LOG(ERROR) << name_ << ": Do not support split W dimension in GPU platform";
return FAILED;
}
need_exchange_overlap_ = true; need_exchange_overlap_ = true;
MS_LOG(INFO) << name_ << ": Split the w dimension"; MS_LOG(INFO) << name_ << ": Split the w dimension";
} }

View File

@ -253,3 +253,19 @@ def test_bilinear_shard_n_c_w():
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
strategy1=strategy1, strategy2=strategy2) strategy1=strategy1, strategy2=strategy2)
compile_net(net) compile_net(net)
def test_resizebilinear_shard_W_in_GPU():
"""
Feature: test ResizeBilinear
Description: the platform is GPU, and shard n/c/w
Expectation: compile failed, can not shard h or w dimension in GPU
"""
context.set_context(device_target="GPU")
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=3)
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
strategy2 = ((2, 2, 1, 2),)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
strategy1=strategy1, strategy2=strategy2)
with pytest.raises(RuntimeError):
compile_net(net)