!45779 Rectify the case adaptation error.

Merge pull request !45779 from Margaret_wangrui/vmap_parameter
This commit is contained in:
i-robot 2022-11-22 02:33:19 +00:00 committed by Gitee
commit 9f96e3db9f
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 2 additions and 2 deletions

View File

@ -135,7 +135,7 @@ def test_vmap_gradient():
assert np.allclose(vmap_jvp_y.asnumpy(), expect_y_jvp.asnumpy(), 0.0001, 0.0001)
@pytest.mark.level1
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@ -179,7 +179,7 @@ def test_vmap_monad():
def construct(self, assign_add_val, scatter_indices, scatter_updates):
output = vmap(self.net, (0, 1, 0, 0, None), 1)(assign_add_val, self.assign_add_var,
self.scatter_ref, scatter_indices, scatter_updates)
return output, self.assign_add_var.value()
return output, self.assign_add_var
assign_add_val = Tensor([[[1, 1, 1], [2, 2, 2]], [[1, 1, 1], [2, 2, 2]], [[1, 1, 1], [2, 2, 2]]], mstype.float32)
scatter_indices = Tensor([[[0, 1], [1, 1]], [[0, 1], [0, 1]], [[1, 1], [1, 0]]], mstype.int32)