add sparse bprop

yanglf1121 2022-04-06 16:53:42 +08:00
parent 4524185778
commit b950e05d47
10 changed files with 402 additions and 65 deletions

View File

@@ -295,7 +295,6 @@ opt::OptPassConfig GetOptPassA1(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.updatestate_pure_node_eliminater_,
irpass.load_eliminater_,
irpass.stopgrad_eliminater_,
irpass.sparse_tensor_eliminate_,
});
}

View File

@@ -1914,7 +1914,7 @@ def filter_(fun, iter_):
def csr_astype(x, dtype):
"""Implementation of `astype` for CSRTensor."""
data = F.cast(x.values, dtype)
data = x.values.astype(dtype)
return F.make_csr_tensor(x.indptr, x.indices, data, x.shape)
@@ -1931,7 +1931,7 @@ def csr_abs(x):
def csr_mv(x, dense_vector):
"""Implementation of `abs` for CSRTensor."""
check_value_type('dense_vector', dense_vector, (Tensor_,), 'CSRTensor.mv')
check_value_type('dense_vector', dense_vector, (Tensor,), 'CSRTensor.mv')
return F.csr_mv(x, dense_vector)
@@ -1943,7 +1943,7 @@ def csr_to_tuple(x):
def coo_astype(x, dtype):
"""Implementation of `astype` for COOTensor."""
data = F.cast(x.values, dtype)
data = x.values.astype(dtype)
return F.make_coo_tensor(x.indices, data, x.shape)

View File

@@ -2590,7 +2590,7 @@ class COOTensor(COOTensor_):
Return a copy of the COOTensor, cast its values to a specified type.
Args:
dtype (:class:`mindspore.dtype`): Designated tensor dtype.
dtype (Union[:class:`mindspore.dtype`, numpy.dtype, str]): Designated tensor dtype.
Returns:
COOTensor.
@@ -2798,7 +2798,7 @@ class CSRTensor(CSRTensor_):
Return a copy of the CSRTensor, cast its values to a specified type.
Args:
dtype (:class:`mindspore.dtype`): Designated tensor dtype.
dtype (Union[:class:`mindspore.dtype`, numpy.dtype, str]): Designated tensor dtype.
Returns:
CSRTensor.

View File

@@ -19,43 +19,37 @@ from .. import functional as F
from .. import operations as P
from ..operations import _csr_ops
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..composite.multitype_ops._constexpr_utils import infer_out_shape
from .grad_base import bprops, bprop_getters
# Unused parameters are placeholders.
@bprops.register("MakeCSRTensor")
def bprop_make_csr_tensor(indptr, indices, values, dense_shape, out, dout):
"""Backpropagator for primitive `MakeCSRTensor`."""
res = (zeros_like(indptr), zeros_like(indices), F.csr_tensor_get_values(dout), ())
return res
# COOTensor Bprop Methods
@bprops.register("MakeCOOTensor")
def bprop_make_coo_tensor(indices, values, dense_shape, out, dout):
"""Backpropagator for primitive `MakeCOOTensor`."""
return zeros_like(indices), F.coo_tensor_get_values(dout), ()
return (zeros_like(indices), dout.values,)
@bprops.register("COOTensorGetIndices")
def bprop_sparse_tensor_get_indices(sparse_tensor, out, dout):
def bprop_coo_tensor_get_indices(coo_tensor, out, dout):
"""Backpropagator for primitive `COOTensorGetIndices`."""
return (zeros_like(sparse_tensor),)
return (F.make_coo_tensor(dout, zeros_like(coo_tensor.values), coo_tensor.shape),)
@bprops.register("COOTensorGetValues")
def bprop_sparse_tensor_get_values(sparse_tensor, out, dout):
def bprop_coo_tensor_get_values(coo_tensor, out, dout):
"""Backpropagator for primitive `COOTensorGetValues`."""
return F.make_coo_tensor(F.coo_tensor_get_indices(sparse_tensor),
dout,
F.coo_tensor_get_dense_shape(sparse_tensor))
return (F.make_coo_tensor(zeros_like(coo_tensor.indices), dout, coo_tensor.shape),)
@bprops.register("COOTensorrGetDenseShape")
def bprop_sparse_tensor_get_dense_shape(sparse_tensor, out, dout):
@bprops.register("COOTensorGetDenseShape")
def bprop_coo_tensor_get_dense_shape(coo_tensor, out, dout):
"""Backpropagator for primitive `COOTensorGetDenseShape`."""
return (zeros_like(sparse_tensor),)
return (zeros_like(coo_tensor),)
@bprop_getters.register(P.SparseToDense)
@@ -89,6 +83,40 @@ def get_bprop_sparse_tensor_dense_matmul(self):
return bprop
# CSRTensor Bprop Methods
@bprops.register("MakeCSRTensor")
def bprop_make_csr_tensor(indptr, indices, values, dense_shape, out, dout):
"""Backpropagator for primitive `MakeCSRTensor`."""
res = (zeros_like(indptr), zeros_like(indices), dout.values, dout.shape)
return res
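For illustration, a minimal gradient sketch (it mirrors the smoke tests added later in this changeset; the names here are illustrative, not part of the diff): differentiating through CSRTensor construction routes the output gradient's values back into `values`, while the integer inputs `indptr` and `indices` receive zeros.

import numpy as np
from mindspore import Tensor, CSRTensor, ms_function, ops
from mindspore.common import dtype as mstype

grad_op = ops.GradOperation(get_all=True)

@grad_op
@ms_function
def make_csr_grad(indptr, indices, values, dense_shape):
    # CSRTensor construction is differentiable w.r.t. `values` only.
    return CSRTensor(indptr, indices, values, dense_shape)

indptr = Tensor([0, 1, 2], dtype=mstype.int32)
indices = Tensor([0, 1], dtype=mstype.int32)
values = Tensor(np.arange(2), dtype=mstype.float32)
# Expected: (zeros_like(indptr), zeros_like(indices), ones_like(values))
grads = make_csr_grad(indptr, indices, values, (2, 4))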
@bprops.register("CSRTensorGetIndptr")
def bprop_csr_tensor_get_indptr(csr_tensor, out, dout):
"""Backpropagator for primitive `CSRTensorGetIndptr`."""
return (F.make_csr_tensor(dout, zeros_like(csr_tensor.indices), zeros_like(csr_tensor.values), csr_tensor.shape),)
@bprops.register("CSRTensorGetIndices")
def bprop_csr_tensor_get_indices(csr_tensor, out, dout):
"""Backpropagator for primitive `CSRTensorGetIndices`."""
return (F.make_csr_tensor(zeros_like(csr_tensor.indptr), dout, zeros_like(csr_tensor.values), csr_tensor.shape),)
@bprops.register("CSRTensorGetValues")
def bprop_csr_tensor_get_values(csr_tensor, out, dout):
"""Backpropagator for primitive `CSRTensorGetValues`."""
return (F.make_csr_tensor(zeros_like(csr_tensor.indptr), zeros_like(csr_tensor.indices), dout, csr_tensor.shape),)
@bprops.register("CSRTensorGetDenseShape")
def bprop_csr_tensor_get_dense_shape(csr_tensor, out, dout):
"""Backpropagator for primitive `CSRTensorGetDenseShape`."""
return (zeros_like(csr_tensor),)
@bprop_getters.register(_csr_ops.CSRReduceSum)
def get_bprop_csr_reduce_sum(self):
"Back-propagation for CSRReduceSum."
@@ -134,7 +162,14 @@ def get_bprop_csr_mv(self):
@bprop_getters.register(_csr_ops.CSRMul)
def get_bprop_csr_mul(self):
"Back-propagation for CSRMul."
"""
Back-propagation for CSRMul.
Note: Broadcast of first dimension of the dense input is not supported for `CSRDiv`,
because this would require sparse reduce sum on the first axis, which is not logically contiguous
for the CSR storage format. If broadcast of first dimension should be desired, the operator `/`
could be used instead, which bypass the constraint by making use of the indices in the CSR input
to index the dense input.
"""
def bprop(csr_tensor, dense, out, dout):
indptr = csr_tensor.indptr
indices = csr_tensor.indices
@@ -146,8 +181,9 @@ def get_bprop_csr_mul(self):
dense_grad_value = F.mul(dout, values)
dense_grad = F.make_csr_tensor(indptr, indices, dense_grad_value, shape)
if len(dense.shape) == 1 or dense.shape[0] == 1:
dense_grad = F.csr_reduce_sum(dense_grad, 0)
elif dense.shape[1] == 1:
raise ValueError(
"Backpropagation for CSRMul does not support broadcasting the dense input along the first dimension! Use `*` instead.")
if dense.shape[1] == 1:
dense_grad = F.csr_reduce_sum(dense_grad, 1)
else:
row = F.csr2coo(indptr, indices.shape[0])
@@ -157,15 +193,61 @@ def get_bprop_csr_mul(self):
return bprop
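To make the note above concrete, a hedged usage sketch (not part of this changeset) of the recommended workaround: when the dense operand broadcasts along the first dimension, go through the `*` operator, which indexes the dense input with the CSR indices, instead of calling `_csr_ops.CSRMul` directly.

import numpy as np
from mindspore import Tensor, CSRTensor, ms_function, ops
from mindspore.common import dtype as mstype

grad_op = ops.GradOperation(get_all=True)

@grad_op
@ms_function
def mul_broadcast_dim0(csr_tensor, dense):
    # `*` gathers dense entries by the CSR indices, so broadcasting the first
    # dimension works; CSRMul's bprop would raise ValueError for this shape.
    return csr_tensor * dense

indptr = Tensor([0, 1, 4, 6], dtype=mstype.int32)
indices = Tensor([3, 0, 1, 2, 1, 3], dtype=mstype.int32)
values = Tensor(np.arange(6), dtype=mstype.float32)
csr_tensor = CSRTensor(indptr, indices, values, (3, 4))
row = Tensor(np.ones((1, 4)), dtype=mstype.float32)  # broadcasts over dim 0
grads = mul_broadcast_dim0(csr_tensor, row)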
@bprop_getters.register(_csr_ops.CSRDiv)
def get_bprop_csr_div(self):
"""
Back-propagation for CSRDiv.
Note: Broadcast of first dimension of the dense input is not supported for `CSRDiv`,
because this would require sparse reduce sum on the first axis, which is not logically contiguous
for the CSR storage format. If broadcast of first dimension should be desired, the operator `/`
could be used instead, which bypass the constraint by making use of the indices in the CSR input
to index the dense input.
"""
def bprop(csr_tensor, dense, out, dout):
indptr = csr_tensor.indptr
indices = csr_tensor.indices
shape = csr_tensor.shape
batch_dim_csr_start = 2
batch_dim_dense_start = len(dense.shape) - (len(shape) - batch_dim_csr_start)
if batch_dim_dense_start < 0:
batch_dim_dense_start = 0
feature_dim = infer_out_shape(shape[:batch_dim_csr_start], dense.shape[:batch_dim_dense_start])
shape_x = feature_dim + shape[batch_dim_csr_start:]
shape_y = feature_dim + dense.shape[batch_dim_dense_start:]
reduce_x, reduce_y = F.broadcast_gradient_args(shape_x, shape_y)
csr_tensor_grad_value = F.csr_div(F.make_csr_tensor(indptr, indices, dout, shape), dense)
if reduce_x:
csr_tensor_grad_value = P.ReduceSum(True)(csr_tensor_grad_value, reduce_x)
csr_tensor_grad = F.make_csr_tensor(indptr, indices, csr_tensor_grad_value, shape)
dense_grad_value = F.neg_tensor(F.mul(out, csr_tensor_grad_value))
dense_grad = F.make_csr_tensor(indptr, indices, dense_grad_value, shape)
if len(dense.shape) == 1 or dense.shape[0] == 1:
raise ValueError(
"Backpropagation for CSRDiv does not support broadcasting the dense input along the first dimension! Use `/` instead.")
if dense.shape[1] == 1:
dense_grad = F.csr_reduce_sum(dense_grad, 1)
else:
row = F.csr2coo(indptr, indices.shape[0])
coo_idx = P.Stack(-1)((row, indices))
dense_grad = F.tensor_scatter_update(zeros_like(dense), coo_idx, dense_grad_value)
if reduce_y:
dense_grad = P.ReduceSum(True)(dense_grad, reduce_y)
return csr_tensor_grad, dense_grad
return bprop
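For reference, the arithmetic this bprop computes on the stored positions, as a dense NumPy sketch (an editor's sketch, not part of the changeset): for z = x / y, dz/dx = dout / y stays on the CSR positions, and dz/dy = -out * (dout / y) = -x * dout / y**2 is scattered back to the dense input.

import numpy as np

def csr_div_grads_reference(x, y, dout, mask):
    """Dense reference for the CSRDiv gradients, restricted to the CSR mask."""
    out = np.where(mask, x / y, 0.0)
    grad_x = np.where(mask, dout / y, 0.0)       # csr_tensor_grad_value above
    grad_y = np.where(mask, -out * grad_x, 0.0)  # dense_grad_value above
    return grad_x, grad_y

# Reproduces csr_div_expect_2_1 and csr_div_expect_2_2 from the smoke test below:
mask = np.array([[0, 0, 0, 1], [1, 1, 1, 0], [0, 1, 0, 1]], dtype=bool)
x = np.array([[0, 0, 0, 0], [1, 2, 3, 0], [0, 4, 0, 5]], dtype=np.float32)
y = np.arange(1, 13, dtype=np.float32).reshape(3, 4)
grad_x, grad_y = csr_div_grads_reference(x, y, np.ones_like(y), mask)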
@bprop_getters.register(_csr_ops.CSR2COO)
def get_bprop_csr2coo(self):
def bprop(indptr, nnz, out, dout):
return zeros_like(dout)
return zeros_like(indptr), zeros_like(nnz)
return bprop
@bprop_getters.register(_csr_ops.COO2CSR)
def get_bprop_coo2csr(self):
def bprop(row_indices, height, out, dout):
return zeros_like(dout)
return zeros_like(row_indices), zeros_like(height)
return bprop

View File

@@ -361,4 +361,19 @@ def _add_csrtensor_csrtensor(x, y):
"""
return F.make_csr_tensor(x.indptr, x.indices, x.values + y.values, x.shape)
@_add_backward.register("COOTensor", "COOTensor")
def _add_cootensor_cootensor(x, y):
"""
Adds COOTensor and COOTensor.
Args:
x (COOTensor): x
y (COOTensor): y
Returns:
COOTensor.
"""
return F.make_coo_tensor(x.indices, x.values + y.values, x.shape)
hyper_add = base.HyperMap(_add_backward)

View File

@@ -1026,7 +1026,7 @@ make_csr_tensor = Primitive('MakeCSRTensor')
csr_tensor_get_values = Primitive('CSRTensorGetValues')
csr_tensor_get_indices = Primitive('CSRTensorGetIndices')
csr_tensor_get_indptr = Primitive('CSRTensorGetIndptr')
csr_tensor_get_shape = Primitive('CSRTensorGetDenseShape')
csr_tensor_get_dense_shape = Primitive('CSRTensorGetDenseShape')
tensor_operator_registry.register('all', P.ReduceAll)
tensor_operator_registry.register('any', P.ReduceAny)

View File

@@ -0,0 +1,21 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""setup for pytest in mindspore.sparse"""
import mindspore.context as context
# pylint: disable=unused-argument
def setup_module(module):
context.set_context(mode=context.GRAPH_MODE)

View File

@@ -0,0 +1,26 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""common utils for sparse tests"""
import platform
def get_platform():
return platform.system().lower()
def compare_res(tensor_tup, numpy_tup):
assert len(tensor_tup) == len(numpy_tup)
for item in zip(tensor_tup, numpy_tup):
assert (item[0].asnumpy() == item[1]).all()
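A minimal usage sketch of `compare_res` (illustrative, matching how the tests below call it): each returned Tensor is compared elementwise against its expected numpy array.

import numpy as np
from mindspore import Tensor

grads = (Tensor(np.ones(3, np.float32)),)
compare_res(grads, (np.ones(3, np.float32),))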

View File

@@ -14,19 +14,15 @@
# ============================================================================
"""smoke tests for COO operations"""
import platform
import pytest
import numpy as np
from mindspore import Tensor, COOTensor, ms_function, nn, context
from mindspore import Tensor, COOTensor, ms_function, nn, ops
from mindspore.common import dtype as mstype
from mindspore.ops import functional as F
from .sparse_utils import get_platform, compare_res
context.set_context(mode=context.GRAPH_MODE)
def get_platform():
return platform.system().lower()
def compare_coo(coo1, coo2):
assert isinstance(coo1, COOTensor)
@@ -219,3 +215,96 @@ def test_coo_attr():
assert (py_tuple[i].asnumpy() == g_tuple[i].asnumpy()).all()
else:
assert py_tuple[i] == g_tuple[i]
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_coo_bprop():
"""
Feature: Test back-propagation with COO-related Ops.
Description: Test back-propagation of make_coo, coo.attributes, coo.methods().
Expectation: Success.
"""
if get_platform() != "linux":
return
grad_op = ops.GradOperation(get_all=True)
indices = Tensor([[0, 1], [1, 2]], dtype=mstype.int32)
values = Tensor([-1, 2], dtype=mstype.float32)
dense_shape = (3, 4)
@grad_op
@ms_function
def test_coo_tensor(indices, values, dense_shape):
coo_tensor = COOTensor(indices, values, dense_shape)
return coo_tensor
@grad_op
@ms_function
def test_coo_indices(indices, values, dense_shape):
coo_tensor = COOTensor(indices, values, dense_shape)
return coo_tensor.indices
@grad_op
@ms_function
def test_coo_values(indices, values, dense_shape):
coo_tensor = COOTensor(indices, values, dense_shape)
return coo_tensor.values
@grad_op
@ms_function
def test_coo_shape(indices, values, dense_shape):
coo_tensor = COOTensor(indices, values, dense_shape)
return coo_tensor.shape
@grad_op
@ms_function
def test_coo_cast(indices, values, dense_shape):
coo_tensor = COOTensor(indices, values, dense_shape)
return coo_tensor.astype(mstype.int32)
@grad_op
@ms_function
def test_coo_dtype(indices, values, dense_shape):
coo_tensor = COOTensor(indices, values, dense_shape)
return coo_tensor.dtype
@grad_op
@ms_function
def test_coo_to_tuple(indices, values, dense_shape):
coo_tensor = COOTensor(indices, values, dense_shape)
return coo_tensor.to_tuple()
@grad_op
@ms_function
def test_coo_to_abs(indices, values, dense_shape):
coo_tensor = COOTensor(indices, values, dense_shape)
return coo_tensor.abs()
@grad_op
@ms_function
def test_coo_to_csr(indices, values, dense_shape):
coo_tensor = COOTensor(indices, values, dense_shape)
return coo_tensor.to_csr()
@grad_op
@ms_function
def test_coo_to_dense(indices, values, dense_shape):
coo_tensor = COOTensor(indices, values, dense_shape)
return coo_tensor.to_dense()
all_zero = (np.zeros(indices.shape, np.int32), np.zeros(values.shape, np.float32))
values_on = (np.zeros(indices.shape, np.int32), np.ones(values.shape, np.float32))
values_absgrad = (np.zeros(indices.shape, np.int32), np.sign(values.asnumpy()))
compare_res(test_coo_tensor(indices, values, dense_shape), values_on)
compare_res(test_coo_indices(indices, values, dense_shape), all_zero)
compare_res(test_coo_values(indices, values, dense_shape), values_on)
compare_res(test_coo_shape(indices, values, dense_shape), all_zero)
compare_res(test_coo_cast(indices, values, dense_shape), values_on)
compare_res(test_coo_dtype(indices, values, dense_shape), all_zero)
compare_res(test_coo_to_tuple(indices, values, dense_shape), values_on)
compare_res(test_coo_to_abs(indices, values, dense_shape), values_absgrad)
compare_res(test_coo_to_csr(indices, values, dense_shape), values_on)
compare_res(test_coo_to_dense(indices, values, dense_shape), values_on)

View File

@@ -15,18 +15,17 @@
"""smoke tests for CSR operations"""
import os
import platform
import pytest
import numpy as np
from mindspore import Tensor, CSRTensor, ms_function, nn, context, ops
from mindspore import Tensor, CSRTensor, ms_function, nn, ops
from mindspore.ops.operations import _csr_ops
from mindspore.common import dtype as mstype
from mindspore.train.serialization import export, load
from mindspore.ops import functional as F
from .sparse_utils import get_platform, compare_res
context.set_context(mode=context.GRAPH_MODE)
def compare_csr(csr1, csr2):
assert isinstance(csr1, CSRTensor)
@@ -36,8 +35,6 @@ def compare_csr(csr1, csr2):
assert (csr1.values.asnumpy() == csr2.values.asnumpy()).all()
assert csr1.shape == csr2.shape
def get_platform():
return platform.system().lower()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@@ -494,10 +491,10 @@ def test_dtype_csr_tensor():
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_csr_bprop():
def test_bprop():
"""
Feature: Test back-propagation with CSR-related Ops.
Description: Test CSRReduceSum, CSRMul, CSRMV, CSRTensor.to_coo(), CSRTensor.to_dense().
Description: Test CSRReduceSum, CSRMul, CSRDiv, CSRMV, CSRTensor.to_coo(), CSRTensor.to_dense().
Expectation: Success.
"""
if get_platform() != "linux":
@@ -506,22 +503,26 @@ def test_csr_bprop():
csrmv = _csr_ops.CSRMV()
grad_op = ops.GradOperation(get_all=True)
@grad_op
@ms_function
def test_csr_mul(csr_tensor, dense):
return csr_tensor * dense
@grad_op
@ms_function
def test_csr_div(csr_tensor, dense):
return csr_tensor / dense
@grad_op
@ms_function
def test_csr_reduce_sum(csr_tensor, axis):
return csr_reduce_sum(csr_tensor, axis)
@grad_op
@ms_function
def test_csrmv(csr_tensor, dense):
return csrmv(csr_tensor, dense)
test_csr_mul_grad_pynative = grad_op(test_csr_mul)
test_csr_mul_grad_graph = ms_function(test_csr_mul_grad_pynative)
test_csr_reduce_sum_grad_pynative = grad_op(test_csr_reduce_sum)
test_csr_reduce_sum_grad_graph = ms_function(test_csr_reduce_sum_grad_pynative)
test_csrmv_grad_pynative = grad_op(test_csrmv)
test_csrmv_grad_graph = ms_function(test_csrmv_grad_pynative)
indptr = Tensor([0, 1, 4, 6], dtype=mstype.int32)
indices = Tensor([3, 0, 1, 2, 1, 3], dtype=mstype.int32)
values = Tensor(np.arange(6), dtype=mstype.float32)
@@ -531,39 +532,43 @@ def test_csr_bprop():
csr_mv_arg = Tensor([[1], [2], [3], [4]], dtype=mstype.float32)
csr_mv_expect_1 = np.array([4, 1, 2, 3, 2, 4], dtype=np.float32)
csr_mv_expect_2 = np.array([[1], [6], [3], [5]], dtype=np.float32)
csr_mv_output_1, csr_mv_output_2 = test_csrmv_grad_pynative(csr_tensor, csr_mv_arg)
assert np.allclose(csr_mv_output_1.values.asnumpy(), csr_mv_expect_1)
assert np.allclose(csr_mv_output_2.asnumpy(), csr_mv_expect_2)
csr_mv_output_1, csr_mv_output_2 = test_csrmv_grad_graph(csr_tensor, csr_mv_arg)
csr_mv_output_1, csr_mv_output_2 = test_csrmv(csr_tensor, csr_mv_arg)
assert np.allclose(csr_mv_output_1.values.asnumpy(), csr_mv_expect_1)
assert np.allclose(csr_mv_output_2.asnumpy(), csr_mv_expect_2)
csr_reduce_sum_expect = np.ones(6, dtype=np.float32)
csr_reduce_sum_output = test_csr_reduce_sum_grad_pynative(csr_tensor, 1)
assert np.allclose(csr_reduce_sum_output[0].values.asnumpy(), csr_reduce_sum_expect)
csr_reduce_sum_output = test_csr_reduce_sum_grad_graph(csr_tensor, 1)
assert np.allclose(csr_reduce_sum_output[0].values.asnumpy(), csr_reduce_sum_expect)
csr_reduce_sum_expect_1 = np.ones(6, dtype=np.float32)
csr_reduce_sum_output_1 = test_csr_reduce_sum(csr_tensor, 1)
assert np.allclose(csr_reduce_sum_output_1[0].values.asnumpy(), csr_reduce_sum_expect_1)
csr_mul_arg_1 = Tensor([[1], [2], [3]], dtype=mstype.float32)
csr_mul_expect_1_1 = np.array([1, 2, 2, 2, 3, 3], dtype=np.float32)
csr_mul_expect_1_2 = np.array([[0], [6], [9]], dtype=np.float32)
csr_mul_output_1_1, csr_mul_output_1_2 = test_csr_mul_grad_pynative(csr_tensor, csr_mul_arg_1)
assert np.allclose(csr_mul_output_1_1.values.asnumpy(), csr_mul_expect_1_1)
assert np.allclose(csr_mul_output_1_2.asnumpy(), csr_mul_expect_1_2)
csr_mul_output_1_1, csr_mul_output_1_2 = test_csr_mul_grad_graph(csr_tensor, csr_mul_arg_1)
csr_mul_output_1_1, csr_mul_output_1_2 = test_csr_mul(csr_tensor, csr_mul_arg_1)
assert np.allclose(csr_mul_output_1_1.values.asnumpy(), csr_mul_expect_1_1)
assert np.allclose(csr_mul_output_1_2.asnumpy(), csr_mul_expect_1_2)
csr_mul_arg_2 = Tensor(np.arange(12).reshape(3, 4), dtype=mstype.float32)
csr_mul_expect_2_1 = np.array([3, 4, 5, 6, 9, 11], dtype=np.float32)
csr_mul_expect_2_2 = np.array([[0, 0, 0, 0], [1, 2, 3, 0], [0, 4, 0, 5]], np.float32)
csr_mul_output_2_1, csr_mul_output_2_2 = test_csr_mul_grad_pynative(csr_tensor, csr_mul_arg_2)
assert np.allclose(csr_mul_output_2_1.values.asnumpy(), csr_mul_expect_2_1)
assert np.allclose(csr_mul_output_2_2.asnumpy(), csr_mul_expect_2_2)
csr_mul_output_2_1, csr_mul_output_2_2 = test_csr_mul_grad_graph(csr_tensor, csr_mul_arg_2)
csr_mul_output_2_1, csr_mul_output_2_2 = test_csr_mul(csr_tensor, csr_mul_arg_2)
assert np.allclose(csr_mul_output_2_1.values.asnumpy(), csr_mul_expect_2_1)
assert np.allclose(csr_mul_output_2_2.asnumpy(), csr_mul_expect_2_2)
csr_div_expect_1_1 = np.array([1, 0.5, 0.5, 0.5, 0.3333333, 0.3333333], dtype=np.float32)
csr_div_expect_1_2 = np.array([[0], [-1.5], [-1]], dtype=np.float32)
csr_div_arg_1 = Tensor([[1], [2], [3]], dtype=mstype.float32)
csr_div_output_1_1, csr_div_output_1_2 = test_csr_div(csr_tensor, csr_div_arg_1)
assert np.allclose(csr_div_output_1_1.values.asnumpy(), csr_div_expect_1_1)
assert np.allclose(csr_div_output_1_2.asnumpy(), csr_div_expect_1_2)
csr_div_arg_2 = Tensor(np.arange(1, 13).reshape(3, 4), dtype=mstype.float32)
csr_div_expect_2_1 = np.array([0.25, 0.2, 0.16666667, 0.14285715, 0.1, 0.0833333], dtype=np.float32)
csr_div_expect_2_2 = np.array(
[[0, 0, 0, 0], [-0.04, -0.05555556, -0.06122449, 0], [0, -0.04, 0, -0.03472222]], dtype=np.float32)
csr_div_output_2_1, csr_div_output_2_2 = test_csr_div(csr_tensor, csr_div_arg_2)
assert np.allclose(csr_div_output_2_1.values.asnumpy(), csr_div_expect_2_1)
assert np.allclose(csr_div_output_2_2.asnumpy(), csr_div_expect_2_2)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@@ -600,3 +605,103 @@ def test_csr_method():
to_dense_output = CSRToDenseNet()(csr_tensor)
to_dense_expect = np.array([[0, 0, 0, 0], [1, 2, 3, 0], [0, 4, 0, 5]], np.float32)
assert np.allclose(to_dense_output.asnumpy(), to_dense_expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_bprop2():
"""
Feature: Test back-propagation with CSR-related Ops.
Description: Test back-propagation of make_csr, csr.attributes, csr.methods().
Expectation: Success.
"""
if get_platform() != "linux":
return
grad_op = ops.GradOperation(get_all=True)
indptr = Tensor([0, 1, 4, 6], dtype=mstype.int32)
indices = Tensor([3, 0, 1, 2, 1, 3], dtype=mstype.int32)
values = Tensor(np.arange(6) - 3.5, dtype=mstype.float32)
dense_shape = (3, 4)
@grad_op
@ms_function
def test_csr_tensor(indptr, indices, values, dense_shape):
csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
return csr_tensor
@grad_op
@ms_function
def test_csr_indptr(indptr, indices, values, dense_shape):
csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
return csr_tensor.indptr
@grad_op
@ms_function
def test_csr_indices(indptr, indices, values, dense_shape):
csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
return csr_tensor.indices
@grad_op
@ms_function
def test_csr_values(indptr, indices, values, dense_shape):
csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
return csr_tensor.values
@grad_op
@ms_function
def test_csr_shape(indptr, indices, values, dense_shape):
csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
return csr_tensor.shape
@grad_op
@ms_function
def test_csr_cast(indptr, indices, values, dense_shape):
csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
return csr_tensor.astype(mstype.int32)
@grad_op
@ms_function
def test_csr_dtype(indptr, indices, values, dense_shape):
csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
return csr_tensor.dtype
@grad_op
@ms_function
def test_csr_to_tuple(indptr, indices, values, dense_shape):
csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
return csr_tensor.to_tuple()
@grad_op
@ms_function
def test_csr_to_abs(indptr, indices, values, dense_shape):
csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
return csr_tensor.abs()
@grad_op
@ms_function
def test_csr_to_coo(indptr, indices, values, dense_shape):
csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
return csr_tensor.to_coo()
@grad_op
@ms_function
def test_csr_to_dense(indptr, indices, values, dense_shape):
csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
return csr_tensor.to_dense()
all_zero = (np.zeros(indptr.shape, np.int32), np.zeros(indices.shape, np.int32), np.zeros(values.shape, np.float32))
values_on = (np.zeros(indptr.shape, np.int32), np.zeros(indices.shape, np.int32), np.ones(values.shape, np.float32))
values_absgrad = (np.zeros(indptr.shape, np.int32), np.zeros(indices.shape, np.int32), np.sign(values.asnumpy()))
compare_res(test_csr_tensor(indptr, indices, values, dense_shape), values_on)
compare_res(test_csr_indptr(indptr, indices, values, dense_shape), all_zero)
compare_res(test_csr_indices(indptr, indices, values, dense_shape), all_zero)
compare_res(test_csr_values(indptr, indices, values, dense_shape), values_on)
compare_res(test_csr_cast(indptr, indices, values, dense_shape), values_on)
compare_res(test_csr_shape(indptr, indices, values, dense_shape), all_zero)
compare_res(test_csr_dtype(indptr, indices, values, dense_shape), all_zero)
compare_res(test_csr_to_tuple(indptr, indices, values, dense_shape), values_on)
compare_res(test_csr_to_abs(indptr, indices, values, dense_shape), values_absgrad)
compare_res(test_csr_to_coo(indptr, indices, values, dense_shape), values_on)
compare_res(test_csr_to_dense(indptr, indices, values, dense_shape), values_on)