forked from mindspore-Ecosystem/mindspore
!45779 Rectify the case adaptation error.
Merge pull request !45779 from Margaret_wangrui/vmap_parameter
This commit is contained in:
commit
9f96e3db9f
|
@ -135,7 +135,7 @@ def test_vmap_gradient():
|
|||
assert np.allclose(vmap_jvp_y.asnumpy(), expect_y_jvp.asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
|
@ -179,7 +179,7 @@ def test_vmap_monad():
|
|||
def construct(self, assign_add_val, scatter_indices, scatter_updates):
|
||||
output = vmap(self.net, (0, 1, 0, 0, None), 1)(assign_add_val, self.assign_add_var,
|
||||
self.scatter_ref, scatter_indices, scatter_updates)
|
||||
return output, self.assign_add_var.value()
|
||||
return output, self.assign_add_var
|
||||
|
||||
assign_add_val = Tensor([[[1, 1, 1], [2, 2, 2]], [[1, 1, 1], [2, 2, 2]], [[1, 1, 1], [2, 2, 2]]], mstype.float32)
|
||||
scatter_indices = Tensor([[[0, 1], [1, 1]], [[0, 1], [0, 1]], [[1, 1], [1, 0]]], mstype.int32)
|
||||
|
|
Loading…
Reference in New Issue