forked from mindspore-Ecosystem/mindspore
temp fix csrtensor testcase
This commit is contained in:
parent
a3695771aa
commit
bfabc2b89d
|
@ -92,7 +92,6 @@ def test_csr_attr():
|
|||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_csr_tensor_in_while():
|
||||
"""
|
||||
|
@ -145,3 +144,59 @@ def test_csr_tensor_in_while():
|
|||
assert np.allclose(out.indices.asnumpy(), indices.asnumpy(), .0, .0)
|
||||
assert np.allclose((values.asnumpy() + 2) * 8, out.values.asnumpy(), .0, .0)
|
||||
assert shape == out.shape
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_csr_tensor_in_while_cpu():
|
||||
"""
|
||||
Feature: Test CSRTensor in while loop.
|
||||
Description: Test CSRTensor computation in while loop.
|
||||
Expectation: Success.
|
||||
"""
|
||||
class CSRTensorValuesDouble(nn.Cell):
|
||||
|
||||
def construct(self, x):
|
||||
indptr = x.indptr
|
||||
indices = x.indices
|
||||
values = x.values * 2
|
||||
shape = x.shape
|
||||
return CSRTensor(indptr, indices, values, shape)
|
||||
|
||||
class CSRTensorValuesAdd2(nn.Cell):
|
||||
|
||||
def construct(self, x):
|
||||
indptr = x.indptr
|
||||
indices = x.indices
|
||||
values = x.values + 2
|
||||
shape = x.shape
|
||||
return CSRTensor(indptr, indices, values, shape)
|
||||
|
||||
class CSRTensorWithControlWhile(nn.Cell):
|
||||
def __init__(self, shape):
|
||||
super().__init__()
|
||||
self.op1 = CSRTensorValuesDouble()
|
||||
self.op2 = CSRTensorValuesAdd2()
|
||||
self.shape = shape
|
||||
|
||||
@ms_function
|
||||
def construct(self, a, b, indptr, indices, values):
|
||||
x = CSRTensor(indptr, indices, values, self.shape)
|
||||
x = self.op2(x)
|
||||
while a > b:
|
||||
x = self.op1(x)
|
||||
b = b + 1
|
||||
return x
|
||||
a = Tensor(3, mstype.int32)
|
||||
b = Tensor(0, mstype.int32)
|
||||
indptr = Tensor([0, 1, 2])
|
||||
indices = Tensor([0, 1])
|
||||
values = Tensor([1, 2], dtype=mstype.float32)
|
||||
shape = (2, 6)
|
||||
net = CSRTensorWithControlWhile(shape)
|
||||
out = net(a, b, indptr, indices, values)
|
||||
assert np.allclose(out.indptr.asnumpy(), indptr.asnumpy(), .0, .0)
|
||||
assert np.allclose(out.indices.asnumpy(), indices.asnumpy(), .0, .0)
|
||||
assert np.allclose((values.asnumpy() + 2) * 8, out.values.asnumpy(), .0, .0)
|
||||
assert shape == out.shape
|
||||
|
|
Loading…
Reference in New Issue