temp fix csrtensor testcase

This commit is contained in:
yanglf1121 2021-12-03 17:19:02 +08:00
parent a3695771aa
commit bfabc2b89d
1 changed files with 56 additions and 1 deletions

View File

@ -92,7 +92,6 @@ def test_csr_attr():
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_csr_tensor_in_while():
"""
@ -145,3 +144,59 @@ def test_csr_tensor_in_while():
assert np.allclose(out.indices.asnumpy(), indices.asnumpy(), .0, .0)
assert np.allclose((values.asnumpy() + 2) * 8, out.values.asnumpy(), .0, .0)
assert shape == out.shape
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_csr_tensor_in_while_cpu():
"""
Feature: Test CSRTensor in while loop.
Description: Test CSRTensor computation in while loop.
Expectation: Success.
"""
class CSRTensorValuesDouble(nn.Cell):
def construct(self, x):
indptr = x.indptr
indices = x.indices
values = x.values * 2
shape = x.shape
return CSRTensor(indptr, indices, values, shape)
class CSRTensorValuesAdd2(nn.Cell):
def construct(self, x):
indptr = x.indptr
indices = x.indices
values = x.values + 2
shape = x.shape
return CSRTensor(indptr, indices, values, shape)
class CSRTensorWithControlWhile(nn.Cell):
def __init__(self, shape):
super().__init__()
self.op1 = CSRTensorValuesDouble()
self.op2 = CSRTensorValuesAdd2()
self.shape = shape
@ms_function
def construct(self, a, b, indptr, indices, values):
x = CSRTensor(indptr, indices, values, self.shape)
x = self.op2(x)
while a > b:
x = self.op1(x)
b = b + 1
return x
a = Tensor(3, mstype.int32)
b = Tensor(0, mstype.int32)
indptr = Tensor([0, 1, 2])
indices = Tensor([0, 1])
values = Tensor([1, 2], dtype=mstype.float32)
shape = (2, 6)
net = CSRTensorWithControlWhile(shape)
out = net(a, b, indptr, indices, values)
assert np.allclose(out.indptr.asnumpy(), indptr.asnumpy(), .0, .0)
assert np.allclose(out.indices.asnumpy(), indices.asnumpy(), .0, .0)
assert np.allclose((values.asnumpy() + 2) * 8, out.values.asnumpy(), .0, .0)
assert shape == out.shape