forked from mindspore-Ecosystem/mindspore
!29265 fix resize_bilinear infer
Merge pull request !29265 from jiangzhenguang/resize_bilinear
This commit is contained in:
commit
7bb5819889
|
@ -2339,7 +2339,7 @@ class ParallelResizeBilinearGrad(PrimitiveWithInfer):
|
|||
x_dtype = x['dtype']
|
||||
validator.check_tensor_dtype_valid("grad_dtype", grad_dtype, [mstype.float16, mstype.float32], self.name)
|
||||
validator.check_tensor_dtype_valid("x_dtype", x_dtype, [mstype.float16, mstype.float32], self.name)
|
||||
if size_val is not None:
|
||||
if size_val is None:
|
||||
raise ValueError("size should be const input")
|
||||
output_shape = [grad_shape[0], grad_shape[1], x_shape[2], x_shape[3]]
|
||||
|
||||
|
|
|
@ -239,3 +239,17 @@ def test_neighbor_auto_parallel():
|
|||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
||||
net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_bilinear_shard_n_c_w():
|
||||
"""
|
||||
Feature: test ResizeBilinear shard n/c/w
|
||||
Description: shard n/c/w
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=3)
|
||||
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
|
||||
strategy2 = ((2, 2, 1, 2),)
|
||||
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
|
||||
strategy1=strategy1, strategy2=strategy2)
|
||||
compile_net(net)
|
||||
|
|
Loading…
Reference in New Issue