!40309 Fix vmap test fail of deformable_conv on windows

Merge pull request !40309 from YuJianfeng/master
This commit is contained in:
i-robot 2022-08-15 02:12:43 +00:00 committed by Gitee
commit 984a92c0b8
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 9 additions and 8 deletions

View File

@ -22,6 +22,7 @@ from mindspore import Tensor
import mindspore.ops as ops
import mindspore.common.dtype as mstype
from mindspore.ops import functional as F
from mindspore.ops.operations import nn_ops as NN
context.set_context(device_target='CPU')
@ -341,20 +342,20 @@ def test_vmap():
"""
kh, kw = 3, 3
def cal_deformable_conv2d(x, weight, offsets):
return ops.deformable_conv2d(x, weight, offsets, (kh, kw), (1, 1, 1, 1), (0, 0, 0, 0))
def cal_deformable_offsets(x, offsets):
deformable_offsets = NN.DeformableOffsets((1, 1, 1, 1), (0, 0, 0, 0), (kh, kw))
return deformable_offsets(x, offsets)
x = Tensor(np.arange(2 * 2 * 3 * 5 * 5).reshape(2, 2, 3, 5, 5), mstype.float32)
weight = Tensor(np.arange(5 * 3 * kh * kw).reshape(5, 3, kh, kw), mstype.float32)
offsets = Tensor(np.ones((2, 2, 3 * kh * kw, 3, 3)), mstype.float32)
vmap_deformable_conv2d = F.vmap(cal_deformable_conv2d, in_axes=(0, None, 0), out_axes=0)
out1 = vmap_deformable_conv2d(x, weight, offsets)
vmap_deformable_offsets = F.vmap(cal_deformable_offsets, in_axes=(0, 0), out_axes=0)
out1 = vmap_deformable_offsets(x, offsets)
def manually_batched(x, weight, offsets):
def manually_batched(x, offsets):
output = []
for i in range(x.shape[0]):
output.append(cal_deformable_conv2d(x[i], weight, offsets[i]))
output.append(cal_deformable_offsets(x[i], offsets[i]))
return F.stack(output)
out2 = manually_batched(x, weight, offsets)
out2 = manually_batched(x, offsets)
assert np.allclose(out1.asnumpy(), out2.asnumpy())