From bfabc2b89d18aab3c9c0dbc5a15fa513d8ff9c6c Mon Sep 17 00:00:00 2001 From: yanglf1121 Date: Fri, 3 Dec 2021 17:19:02 +0800 Subject: [PATCH] temp fix csrtensor testcase --- tests/st/sparse/test_csr.py | 57 ++++++++++++++++++++++++++++++++++++- 1 file changed, 56 insertions(+), 1 deletion(-) diff --git a/tests/st/sparse/test_csr.py b/tests/st/sparse/test_csr.py index b023ac9958f..2a80e51bf21 100644 --- a/tests/st/sparse/test_csr.py +++ b/tests/st/sparse/test_csr.py @@ -92,7 +92,6 @@ def test_csr_attr(): @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_gpu_training -@pytest.mark.platform_x86_cpu @pytest.mark.env_onecard def test_csr_tensor_in_while(): """ @@ -145,3 +144,59 @@ def test_csr_tensor_in_while(): assert np.allclose(out.indices.asnumpy(), indices.asnumpy(), .0, .0) assert np.allclose((values.asnumpy() + 2) * 8, out.values.asnumpy(), .0, .0) assert shape == out.shape + + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_csr_tensor_in_while_cpu(): + """ + Feature: Test CSRTensor in while loop. + Description: Test CSRTensor computation in while loop. + Expectation: Success. + """ + class CSRTensorValuesDouble(nn.Cell): + + def construct(self, x): + indptr = x.indptr + indices = x.indices + values = x.values * 2 + shape = x.shape + return CSRTensor(indptr, indices, values, shape) + + class CSRTensorValuesAdd2(nn.Cell): + + def construct(self, x): + indptr = x.indptr + indices = x.indices + values = x.values + 2 + shape = x.shape + return CSRTensor(indptr, indices, values, shape) + + class CSRTensorWithControlWhile(nn.Cell): + def __init__(self, shape): + super().__init__() + self.op1 = CSRTensorValuesDouble() + self.op2 = CSRTensorValuesAdd2() + self.shape = shape + + @ms_function + def construct(self, a, b, indptr, indices, values): + x = CSRTensor(indptr, indices, values, self.shape) + x = self.op2(x) + while a > b: + x = self.op1(x) + b = b + 1 + return x + a = Tensor(3, mstype.int32) + b = Tensor(0, mstype.int32) + indptr = Tensor([0, 1, 2]) + indices = Tensor([0, 1]) + values = Tensor([1, 2], dtype=mstype.float32) + shape = (2, 6) + net = CSRTensorWithControlWhile(shape) + out = net(a, b, indptr, indices, values) + assert np.allclose(out.indptr.asnumpy(), indptr.asnumpy(), .0, .0) + assert np.allclose(out.indices.asnumpy(), indices.asnumpy(), .0, .0) + assert np.allclose((values.asnumpy() + 2) * 8, out.values.asnumpy(), .0, .0) + assert shape == out.shape