!40309 Fix vmap test fail of deformable_conv on windows
Merge pull request !40309 from YuJianfeng/master
This commit is contained in:
commit
984a92c0b8
|
@ -22,6 +22,7 @@ from mindspore import Tensor
|
|||
import mindspore.ops as ops
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops.operations import nn_ops as NN
|
||||
|
||||
context.set_context(device_target='CPU')
|
||||
|
||||
|
@ -341,20 +342,20 @@ def test_vmap():
|
|||
"""
|
||||
kh, kw = 3, 3
|
||||
|
||||
def cal_deformable_conv2d(x, weight, offsets):
|
||||
return ops.deformable_conv2d(x, weight, offsets, (kh, kw), (1, 1, 1, 1), (0, 0, 0, 0))
|
||||
def cal_deformable_offsets(x, offsets):
|
||||
deformable_offsets = NN.DeformableOffsets((1, 1, 1, 1), (0, 0, 0, 0), (kh, kw))
|
||||
return deformable_offsets(x, offsets)
|
||||
|
||||
x = Tensor(np.arange(2 * 2 * 3 * 5 * 5).reshape(2, 2, 3, 5, 5), mstype.float32)
|
||||
weight = Tensor(np.arange(5 * 3 * kh * kw).reshape(5, 3, kh, kw), mstype.float32)
|
||||
offsets = Tensor(np.ones((2, 2, 3 * kh * kw, 3, 3)), mstype.float32)
|
||||
vmap_deformable_conv2d = F.vmap(cal_deformable_conv2d, in_axes=(0, None, 0), out_axes=0)
|
||||
out1 = vmap_deformable_conv2d(x, weight, offsets)
|
||||
vmap_deformable_offsets = F.vmap(cal_deformable_offsets, in_axes=(0, 0), out_axes=0)
|
||||
out1 = vmap_deformable_offsets(x, offsets)
|
||||
|
||||
def manually_batched(x, weight, offsets):
|
||||
def manually_batched(x, offsets):
|
||||
output = []
|
||||
for i in range(x.shape[0]):
|
||||
output.append(cal_deformable_conv2d(x[i], weight, offsets[i]))
|
||||
output.append(cal_deformable_offsets(x[i], offsets[i]))
|
||||
return F.stack(output)
|
||||
|
||||
out2 = manually_batched(x, weight, offsets)
|
||||
out2 = manually_batched(x, offsets)
|
||||
assert np.allclose(out1.asnumpy(), out2.asnumpy())
|
||||
|
|
Loading…
Reference in New Issue