!29265 fix resize_bilinear infer

Merge pull request !29265 from jiangzhenguang/resize_bilinear
This commit is contained in:
i-robot 2022-01-20 06:51:03 +00:00 committed by Gitee
commit 7bb5819889
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 15 additions and 1 deletions

View File

@ -2339,7 +2339,7 @@ class ParallelResizeBilinearGrad(PrimitiveWithInfer):
x_dtype = x['dtype']
validator.check_tensor_dtype_valid("grad_dtype", grad_dtype, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_dtype_valid("x_dtype", x_dtype, [mstype.float16, mstype.float32], self.name)
if size_val is not None:
if size_val is None:
raise ValueError("size should be const input")
output_shape = [grad_shape[0], grad_shape[1], x_shape[2], x_shape[3]]

View File

@ -239,3 +239,17 @@ def test_neighbor_auto_parallel():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
compile_net(net)
def test_bilinear_shard_n_c_w():
"""
Feature: test ResizeBilinear shard n/c/w
Description: shard n/c/w
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=3)
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
strategy2 = ((2, 2, 1, 2),)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
strategy1=strategy1, strategy2=strategy2)
compile_net(net)