forked from mindspore-Ecosystem/mindspore

add sparse bprop

This commit is contained in:
parent 4524185778
commit b950e05d47
@@ -295,7 +295,6 @@ opt::OptPassConfig GetOptPassA1(const opt::irpass::OptimizeIRPassLib &irpass) {
     irpass.updatestate_pure_node_eliminater_,
     irpass.load_eliminater_,
     irpass.stopgrad_eliminater_,
-    irpass.sparse_tensor_eliminate_,
   });
 }
@@ -1914,7 +1914,7 @@ def filter_(fun, iter_):
 
 def csr_astype(x, dtype):
     """Implementation of `astype` for CSRTensor."""
-    data = F.cast(x.values, dtype)
+    data = x.values.astype(dtype)
     return F.make_csr_tensor(x.indptr, x.indices, data, x.shape)
 
@@ -1931,7 +1931,7 @@ def csr_abs(x):
 
 def csr_mv(x, dense_vector):
     """Implementation of `mv` for CSRTensor."""
-    check_value_type('dense_vector', dense_vector, (Tensor_,), 'CSRTensor.mv')
+    check_value_type('dense_vector', dense_vector, (Tensor,), 'CSRTensor.mv')
     return F.csr_mv(x, dense_vector)
 
@@ -1943,7 +1943,7 @@ def csr_to_tuple(x):
 
 def coo_astype(x, dtype):
     """Implementation of `astype` for COOTensor."""
-    data = F.cast(x.values, dtype)
+    data = x.values.astype(dtype)
     return F.make_coo_tensor(x.indices, data, x.shape)
 
@@ -2590,7 +2590,7 @@ class COOTensor(COOTensor_):
         Return a copy of the COOTensor, cast its values to a specified type.
 
         Args:
-            dtype (:class:`mindspore.dtype`): Designated tensor dtype.
+            dtype (Union[:class:`mindspore.dtype`, numpy.dtype, str]): Designated tensor dtype.
 
         Returns:
             COOTensor.
@@ -2798,7 +2798,7 @@ class CSRTensor(CSRTensor_):
         Return a copy of the CSRTensor, cast its values to a specified type.
 
         Args:
-            dtype (:class:`mindspore.dtype`): Designated tensor dtype.
+            dtype (Union[:class:`mindspore.dtype`, numpy.dtype, str]): Designated tensor dtype.
 
         Returns:
             CSRTensor.
@@ -19,43 +19,37 @@ from .. import functional as F
 from .. import operations as P
 from ..operations import _csr_ops
 from ..composite.multitype_ops.zeros_like_impl import zeros_like
+from ..composite.multitype_ops._constexpr_utils import infer_out_shape
 from .grad_base import bprops, bprop_getters
 
 # Unused parameters are placeholders.
 
-@bprops.register("MakeCSRTensor")
-def bprop_make_csr_tensor(indptr, indices, values, dense_shape, out, dout):
-    """Backpropagator for primitive `MakeCSRTensor`."""
-    res = (zeros_like(indptr), zeros_like(indices), F.csr_tensor_get_values(dout), ())
-    return res
-
+# COOTensor Bprop Methods
 
 @bprops.register("MakeCOOTensor")
 def bprop_make_coo_tensor(indices, values, dense_shape, out, dout):
     """Backpropagator for primitive `MakeCOOTensor`."""
-    return zeros_like(indices), F.coo_tensor_get_values(dout), ()
+    return (zeros_like(indices), dout.values,)
 
 
 @bprops.register("COOTensorGetIndices")
-def bprop_sparse_tensor_get_indices(sparse_tensor, out, dout):
+def bprop_coo_tensor_get_indices(coo_tensor, out, dout):
     """Backpropagator for primitive `COOTensorGetIndices`."""
-    return (zeros_like(sparse_tensor),)
+    return (F.make_coo_tensor(dout, zeros_like(coo_tensor.values), coo_tensor.shape),)
 
 
 @bprops.register("COOTensorGetValues")
-def bprop_sparse_tensor_get_values(sparse_tensor, out, dout):
+def bprop_coo_tensor_get_values(coo_tensor, out, dout):
     """Backpropagator for primitive `COOTensorGetValues`."""
-    return F.make_coo_tensor(F.coo_tensor_get_indices(sparse_tensor),
-                             dout,
-                             F.coo_tensor_get_dense_shape(sparse_tensor))
+    return (F.make_coo_tensor(zeros_like(coo_tensor.indices), dout, coo_tensor.shape),)
 
 
-@bprops.register("COOTensorrGetDenseShape")
-def bprop_sparse_tensor_get_dense_shape(sparse_tensor, out, dout):
+@bprops.register("COOTensorGetDenseShape")
+def bprop_coo_tensor_get_dense_shape(coo_tensor, out, dout):
     """Backpropagator for primitive `COOTensorGetDenseShape`."""
-    return (zeros_like(sparse_tensor),)
+    return (zeros_like(coo_tensor),)
 
 
 @bprop_getters.register(P.SparseToDense)
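These COO backpropagators route `dout` into the slot that produced the output and return zeros elsewhere, so gradients only ever flow through `values`. A minimal sketch of how this surfaces through `ops.GradOperation` (a hedged example mirroring the smoke tests further down, not part of the commit):

import mindspore as ms
from mindspore import Tensor, COOTensor, ops

grad_all = ops.GradOperation(get_all=True)

def take_values(indices, values, dense_shape):
    # Forward: build a COOTensor and read back its values.
    return COOTensor(indices, values, dense_shape).values

indices = Tensor([[0, 1], [1, 2]], dtype=ms.int32)
values = Tensor([-1.0, 2.0], dtype=ms.float32)
grads = grad_all(take_values)(indices, values, (3, 4))
# Expected: zeros for the integer indices, ones for values, i.e. the
# composition of bprop_coo_tensor_get_values and bprop_make_coo_tensor.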
@@ -89,6 +83,40 @@ def get_bprop_sparse_tensor_dense_matmul(self):
     return bprop
 
 
+# CSRTensor Bprop Methods
+
+
+@bprops.register("MakeCSRTensor")
+def bprop_make_csr_tensor(indptr, indices, values, dense_shape, out, dout):
+    """Backpropagator for primitive `MakeCSRTensor`."""
+    res = (zeros_like(indptr), zeros_like(indices), dout.values, dout.shape)
+    return res
+
+
+@bprops.register("CSRTensorGetIndptr")
+def bprop_csr_tensor_get_indptr(csr_tensor, out, dout):
+    """Backpropagator for primitive `CSRTensorGetIndptr`."""
+    return (F.make_csr_tensor(dout, zeros_like(csr_tensor.indices), zeros_like(csr_tensor.values), csr_tensor.shape),)
+
+
+@bprops.register("CSRTensorGetIndices")
+def bprop_csr_tensor_get_indices(csr_tensor, out, dout):
+    """Backpropagator for primitive `CSRTensorGetIndices`."""
+    return (F.make_csr_tensor(zeros_like(csr_tensor.indptr), dout, zeros_like(csr_tensor.values), csr_tensor.shape),)
+
+
+@bprops.register("CSRTensorGetValues")
+def bprop_csr_tensor_get_values(csr_tensor, out, dout):
+    """Backpropagator for primitive `CSRTensorGetValues`."""
+    return (F.make_csr_tensor(zeros_like(csr_tensor.indptr), zeros_like(csr_tensor.indices), dout, csr_tensor.shape),)
+
+
+@bprops.register("CSRTensorGetDenseShape")
+def bprop_csr_tensor_get_dense_shape(csr_tensor, out, dout):
+    """Backpropagator for primitive `CSRTensorGetDenseShape`."""
+    return (zeros_like(csr_tensor),)
+
+
 @bprop_getters.register(_csr_ops.CSRReduceSum)
 def get_bprop_csr_reduce_sum(self):
     "Back-propagation for CSRReduceSum."
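Each CSR getter bprop reassembles a CSRTensor with `dout` placed in the matching slot and zeros elsewhere, and `MakeCSRTensor` now takes its gradient straight from `dout.values`. A hedged sketch of the end-to-end effect (mirrors test_bprop2 below):

import numpy as np
import mindspore as ms
from mindspore import Tensor, CSRTensor, ops

grad_all = ops.GradOperation(get_all=True)

def take_values(indptr, indices, values, dense_shape):
    return CSRTensor(indptr, indices, values, dense_shape).values

indptr = Tensor([0, 1, 4, 6], dtype=ms.int32)
indices = Tensor([3, 0, 1, 2, 1, 3], dtype=ms.int32)
values = Tensor(np.arange(6), dtype=ms.float32)
grads = grad_all(take_values)(indptr, indices, values, (3, 4))
# Expected: zero gradients for indptr and indices, ones for values.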
@@ -134,7 +162,14 @@ def get_bprop_csr_mv(self):
 
 @bprop_getters.register(_csr_ops.CSRMul)
 def get_bprop_csr_mul(self):
-    "Back-propagation for CSRMul."
+    """
+    Back-propagation for CSRMul.
+    Note: broadcasting over the first dimension of the dense input is not supported for `CSRMul`,
+    because it would require a sparse reduce-sum along the first axis, which is not logically
+    contiguous for the CSR storage format. If broadcasting over the first dimension is desired,
+    the operator `*` can be used instead, which bypasses the constraint by using the indices in
+    the CSR input to index the dense input.
+    """
     def bprop(csr_tensor, dense, out, dout):
         indptr = csr_tensor.indptr
         indices = csr_tensor.indices
@@ -146,8 +181,9 @@ def get_bprop_csr_mul(self):
         dense_grad_value = F.mul(dout, values)
         dense_grad = F.make_csr_tensor(indptr, indices, dense_grad_value, shape)
         if len(dense.shape) == 1 or dense.shape[0] == 1:
-            dense_grad = F.csr_reduce_sum(dense_grad, 0)
-        elif dense.shape[1] == 1:
+            raise ValueError(
+                "Backpropagation for CSRMul with broadcast for the first dimension is not supported! Use `*` instead")
+        if dense.shape[1] == 1:
             dense_grad = F.csr_reduce_sum(dense_grad, 1)
         else:
             row = F.csr2coo(indptr, indices.shape[0])
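For `csr * dense`, the value-side gradient gathers `dense` at the stored positions, while the dense-side gradient scatters `dout * values` back to those positions. A hedged numpy sketch (not the MindSpore kernel) of the scatter branch, using the same fixture as the tests below:

import numpy as np

indptr = np.array([0, 1, 4, 6])
indices = np.array([3, 0, 1, 2, 1, 3])
values = np.arange(6, dtype=np.float32)
dout = np.ones(6, dtype=np.float32)

# csr2coo analogue: expand indptr into one explicit row index per stored value.
row = np.repeat(np.arange(len(indptr) - 1), np.diff(indptr))

dense_grad = np.zeros((3, 4), dtype=np.float32)
dense_grad[row, indices] = dout * values  # tensor_scatter_update analogue
# dense_grad == [[0, 0, 0, 0], [1, 2, 3, 0], [0, 4, 0, 5]],
# matching csr_mul_expect_2_2 in the tests below.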
@@ -157,15 +193,61 @@ def get_bprop_csr_mul(self):
     return bprop
 
 
+@bprop_getters.register(_csr_ops.CSRDiv)
+def get_bprop_csr_div(self):
+    """
+    Back-propagation for CSRDiv.
+    Note: broadcasting over the first dimension of the dense input is not supported for `CSRDiv`,
+    because it would require a sparse reduce-sum along the first axis, which is not logically
+    contiguous for the CSR storage format. If broadcasting over the first dimension is desired,
+    the operator `/` can be used instead, which bypasses the constraint by using the indices in
+    the CSR input to index the dense input.
+    """
+    def bprop(csr_tensor, dense, out, dout):
+        indptr = csr_tensor.indptr
+        indices = csr_tensor.indices
+        shape = csr_tensor.shape
+
+        batch_dim_csr_start = 2
+        batch_dim_dense_start = len(dense.shape) - (len(shape) - batch_dim_csr_start)
+        if batch_dim_dense_start < 0:
+            batch_dim_dense_start = 0
+        feature_dim = infer_out_shape(shape[:batch_dim_csr_start], dense.shape[:batch_dim_dense_start])
+
+        shape_x = feature_dim + shape[batch_dim_csr_start:]
+        shape_y = feature_dim + dense.shape[batch_dim_dense_start:]
+        reduce_x, reduce_y = F.broadcast_gradient_args(shape_x, shape_y)
+
+        csr_tensor_grad_value = F.csr_div(F.make_csr_tensor(indptr, indices, dout, shape), dense)
+        if reduce_x:
+            csr_tensor_grad_value = P.ReduceSum(True)(csr_tensor_grad_value, reduce_x)
+        csr_tensor_grad = F.make_csr_tensor(indptr, indices, csr_tensor_grad_value, shape)
+        dense_grad_value = F.neg_tensor(F.mul(out, csr_tensor_grad_value))
+        dense_grad = F.make_csr_tensor(indptr, indices, dense_grad_value, shape)
+        if len(dense.shape) == 1 or dense.shape[0] == 1:
+            raise ValueError(
+                "Backpropagation for CSRDiv with broadcast for the first dimension is not supported! Use `/` instead")
+        if dense.shape[1] == 1:
+            dense_grad = F.csr_reduce_sum(dense_grad, 1)
+        else:
+            row = F.csr2coo(indptr, indices.shape[0])
+            coo_idx = P.Stack(-1)((row, indices))
+            dense_grad = F.tensor_scatter_update(zeros_like(dense), coo_idx, dense_grad_value)
+        if reduce_y:
+            dense_grad = P.ReduceSum(True)(dense_grad, reduce_y)
+        return csr_tensor_grad, dense_grad
+    return bprop
+
+
 @bprop_getters.register(_csr_ops.CSR2COO)
 def get_bprop_csr2coo(self):
     def bprop(indptr, nnz, out, dout):
-        return zeros_like(dout)
+        return zeros_like(indptr), zeros_like(nnz)
     return bprop
 
 
 @bprop_getters.register(_csr_ops.COO2CSR)
 def get_bprop_coo2csr(self):
     def bprop(row_indices, height, out, dout):
-        return zeros_like(dout)
+        return zeros_like(row_indices), zeros_like(height)
     return bprop
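The CSRDiv bprop follows the quotient rule: for out = x / y, d(out)/dx = dout / y and d(out)/dy = -x * dout / y**2 = -out * (dout / y). A hedged numpy check, independent of MindSpore, that reproduces the reference gradients used in the tests below:

import numpy as np

indptr = np.array([0, 1, 4, 6])
indices = np.array([3, 0, 1, 2, 1, 3])
values = np.arange(6, dtype=np.float32)
dense = np.arange(1, 13, dtype=np.float32).reshape(3, 4)
dout = np.ones(6, dtype=np.float32)

row = np.repeat(np.arange(3), np.diff(indptr))
y = dense[row, indices]                   # dense gathered at the stored positions
csr_grad = dout / y                       # gradient w.r.t. csr.values
dense_grad = np.zeros_like(dense)
dense_grad[row, indices] = -values * dout / y ** 2  # gradient w.r.t. dense
# csr_grad   matches csr_div_expect_2_1 in the tests below
# dense_grad matches csr_div_expect_2_2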
@@ -361,4 +361,19 @@ def _add_csrtensor_csrtensor(x, y):
     """
     return F.make_csr_tensor(x.indptr, x.indices, x.values + y.values, x.shape)
 
+
+@_add_backward.register("COOTensor", "COOTensor")
+def _add_cootensor_cootensor(x, y):
+    """
+    Adds COOTensor and COOTensor.
+
+    Args:
+        x (COOTensor): x
+        y (COOTensor): y
+
+    Returns:
+        COOTensor.
+    """
+    return F.make_coo_tensor(x.indices, x.values + y.values, x.shape)
+
 hyper_add = base.HyperMap(_add_backward)
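Note that `_add_cootensor_cootensor` adds the value arrays positionally and keeps `x.indices`; this is only correct when both operands carry identical index layouts, which holds for gradient accumulation where both addends stem from the same forward graph. A minimal sketch:

import mindspore as ms
from mindspore import Tensor, COOTensor

indices = Tensor([[0, 1], [1, 2]], dtype=ms.int32)
a = COOTensor(indices, Tensor([1.0, 2.0], ms.float32), (3, 4))
b = COOTensor(indices, Tensor([3.0, 4.0], ms.float32), (3, 4))
# Positional addition, as in the registered helper: values become [4.0, 6.0].
summed = COOTensor(a.indices, a.values + b.values, a.shape)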
@@ -1026,7 +1026,7 @@ make_csr_tensor = Primitive('MakeCSRTensor')
 csr_tensor_get_values = Primitive('CSRTensorGetValues')
 csr_tensor_get_indices = Primitive('CSRTensorGetIndices')
 csr_tensor_get_indptr = Primitive('CSRTensorGetIndptr')
-csr_tensor_get_shape = Primitive('CSRTensorGetDenseShape')
+csr_tensor_get_dense_shape = Primitive('CSRTensorGetDenseShape')
 
 tensor_operator_registry.register('all', P.ReduceAll)
 tensor_operator_registry.register('any', P.ReduceAny)
@@ -0,0 +1,21 @@
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""setup for pytest in mindspore.sparse"""
+import mindspore.context as context
+
+
+# pylint: disable=unused-argument
+def setup_module(module):
+    context.set_context(mode=context.GRAPH_MODE)
@@ -0,0 +1,26 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""common utils for sparse tests"""
+import platform
+
+
+def get_platform():
+    return platform.system().lower()
+
+
+def compare_res(tensor_tup, numpy_tup):
+    assert len(tensor_tup) == len(numpy_tup)
+    for item in zip(tensor_tup, numpy_tup):
+        assert (item[0].asnumpy() == item[1]).all()
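A usage sketch for the helper above: `compare_res` zips a tuple of MindSpore tensors with numpy references and checks elementwise equality (assuming `compare_res` as defined in this new file is in scope):

import numpy as np
from mindspore import Tensor

out = (Tensor(np.zeros(4, np.int32)), Tensor(np.ones(6, np.float32)))
ref = (np.zeros(4, np.int32), np.ones(6, np.float32))
compare_res(out, ref)  # passes: lengths match and every element compares equal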
@@ -14,19 +14,15 @@
 # ============================================================================
 """smoke tests for COO operations"""
 
-import platform
 import pytest
 import numpy as np
 
-from mindspore import Tensor, COOTensor, ms_function, nn, context
+from mindspore import Tensor, COOTensor, ms_function, nn, ops
 from mindspore.common import dtype as mstype
 from mindspore.ops import functional as F
 
+from .sparse_utils import get_platform, compare_res
 
-context.set_context(mode=context.GRAPH_MODE)
-
-def get_platform():
-    return platform.system().lower()
 
 def compare_coo(coo1, coo2):
     assert isinstance(coo1, COOTensor)
@@ -219,3 +215,96 @@ def test_coo_attr():
             assert (py_tuple[i].asnumpy() == g_tuple[i].asnumpy()).all()
         else:
             assert py_tuple[i] == g_tuple[i]
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_coo_bprop():
+    """
+    Feature: Test back-propagation with COO-related Ops.
+    Description: Test back-propagation of make_coo, coo.attributes, coo.methods().
+    Expectation: Success.
+    """
+    if get_platform() != "linux":
+        return
+    grad_op = ops.GradOperation(get_all=True)
+    indices = Tensor([[0, 1], [1, 2]], dtype=mstype.int32)
+    values = Tensor([-1, 2], dtype=mstype.float32)
+    dense_shape = (3, 4)
+
+    @grad_op
+    @ms_function
+    def test_coo_tensor(indices, values, dense_shape):
+        coo_tensor = COOTensor(indices, values, dense_shape)
+        return coo_tensor
+
+    @grad_op
+    @ms_function
+    def test_coo_indices(indices, values, dense_shape):
+        coo_tensor = COOTensor(indices, values, dense_shape)
+        return coo_tensor.indices
+
+    @grad_op
+    @ms_function
+    def test_coo_values(indices, values, dense_shape):
+        coo_tensor = COOTensor(indices, values, dense_shape)
+        return coo_tensor.values
+
+    @grad_op
+    @ms_function
+    def test_coo_shape(indices, values, dense_shape):
+        coo_tensor = COOTensor(indices, values, dense_shape)
+        return coo_tensor.shape
+
+    @grad_op
+    @ms_function
+    def test_coo_cast(indices, values, dense_shape):
+        coo_tensor = COOTensor(indices, values, dense_shape)
+        return coo_tensor.astype(mstype.int32)
+
+    @grad_op
+    @ms_function
+    def test_coo_dtype(indices, values, dense_shape):
+        coo_tensor = COOTensor(indices, values, dense_shape)
+        return coo_tensor.dtype
+
+    @grad_op
+    @ms_function
+    def test_coo_to_tuple(indices, values, dense_shape):
+        coo_tensor = COOTensor(indices, values, dense_shape)
+        return coo_tensor.to_tuple()
+
+    @grad_op
+    @ms_function
+    def test_coo_to_abs(indices, values, dense_shape):
+        coo_tensor = COOTensor(indices, values, dense_shape)
+        return coo_tensor.abs()
+
+    @grad_op
+    @ms_function
+    def test_coo_to_csr(indices, values, dense_shape):
+        coo_tensor = COOTensor(indices, values, dense_shape)
+        return coo_tensor.to_csr()
+
+    @grad_op
+    @ms_function
+    def test_coo_to_dense(indices, values, dense_shape):
+        coo_tensor = COOTensor(indices, values, dense_shape)
+        return coo_tensor.to_dense()
+
+    all_zero = (np.zeros(indices.shape, np.int32), np.zeros(values.shape, np.float32))
+    values_on = (np.zeros(indices.shape, np.int32), np.ones(values.shape, np.float32))
+    values_absgrad = (np.zeros(indices.shape, np.int32), np.sign(values.asnumpy()))
+
+    compare_res(test_coo_tensor(indices, values, dense_shape), values_on)
+    compare_res(test_coo_indices(indices, values, dense_shape), all_zero)
+    compare_res(test_coo_values(indices, values, dense_shape), values_on)
+    compare_res(test_coo_shape(indices, values, dense_shape), all_zero)
+    compare_res(test_coo_cast(indices, values, dense_shape), values_on)
+    compare_res(test_coo_dtype(indices, values, dense_shape), all_zero)
+    compare_res(test_coo_to_tuple(indices, values, dense_shape), values_on)
+    compare_res(test_coo_to_abs(indices, values, dense_shape), values_absgrad)
+    compare_res(test_coo_to_csr(indices, values, dense_shape), values_on)
+    compare_res(test_coo_to_dense(indices, values, dense_shape), values_on)
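The three reference tuples encode the expected gradients: constructors and value-preserving methods pass `dout` straight through to `values` (ones), integer attributes get zeros, and `abs()` applies the chain rule d|x|/dx = sign(x). A one-line numpy check of the last case:

import numpy as np

values = np.array([-1.0, 2.0], dtype=np.float32)
# values_absgrad above: the gradient of abs is the elementwise sign.
assert (np.sign(values) == np.array([-1.0, 1.0], dtype=np.float32)).all()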
@@ -15,18 +15,17 @@
 """smoke tests for CSR operations"""
 
 import os
-import platform
 import pytest
 import numpy as np
 
-from mindspore import Tensor, CSRTensor, ms_function, nn, context, ops
+from mindspore import Tensor, CSRTensor, ms_function, nn, ops
 from mindspore.ops.operations import _csr_ops
 from mindspore.common import dtype as mstype
 from mindspore.train.serialization import export, load
 from mindspore.ops import functional as F
 
+from .sparse_utils import get_platform, compare_res
+
 
-context.set_context(mode=context.GRAPH_MODE)
-
 def compare_csr(csr1, csr2):
     assert isinstance(csr1, CSRTensor)
@@ -36,8 +35,6 @@ def compare_csr(csr1, csr2):
     assert (csr1.values.asnumpy() == csr2.values.asnumpy()).all()
     assert csr1.shape == csr2.shape
 
-def get_platform():
-    return platform.system().lower()
 
 @pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training
@@ -494,10 +491,10 @@ def test_dtype_csr_tensor():
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard
-def test_csr_bprop():
+def test_bprop():
     """
     Feature: Test back-propagation with CSR-related Ops.
-    Description: Test CSRReduceSum, CSRMul, CSRMV, CSRTensor.to_coo(), CSRTensor.to_dense().
+    Description: Test CSRReduceSum, CSRMul, CSRDiv, CSRMV, CSRTensor.to_coo(), CSRTensor.to_dense().
     Expectation: Success.
     """
     if get_platform() != "linux":
@@ -506,22 +503,26 @@ def test_bprop():
     csrmv = _csr_ops.CSRMV()
     grad_op = ops.GradOperation(get_all=True)
 
+    @grad_op
+    @ms_function
     def test_csr_mul(csr_tensor, dense):
         return csr_tensor * dense
 
+    @grad_op
+    @ms_function
+    def test_csr_div(csr_tensor, dense):
+        return csr_tensor / dense
+
+    @grad_op
+    @ms_function
     def test_csr_reduce_sum(csr_tensor, axis):
         return csr_reduce_sum(csr_tensor, axis)
 
+    @grad_op
+    @ms_function
     def test_csrmv(csr_tensor, dense):
         return csrmv(csr_tensor, dense)
 
-    test_csr_mul_grad_pynative = grad_op(test_csr_mul)
-    test_csr_mul_grad_graph = ms_function(test_csr_mul_grad_pynative)
-    test_csr_reduce_sum_grad_pynative = grad_op(test_csr_reduce_sum)
-    test_csr_reduce_sum_grad_graph = ms_function(test_csr_reduce_sum_grad_pynative)
-    test_csrmv_grad_pynative = grad_op(test_csrmv)
-    test_csrmv_grad_graph = ms_function(test_csrmv_grad_pynative)
-
     indptr = Tensor([0, 1, 4, 6], dtype=mstype.int32)
     indices = Tensor([3, 0, 1, 2, 1, 3], dtype=mstype.int32)
     values = Tensor(np.arange(6), dtype=mstype.float32)
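Stacking `@grad_op` over `@ms_function` replaces the explicit `*_grad_pynative`/`*_grad_graph` wrappers deleted above: the decorated name now directly returns gradients and runs as a compiled graph. A sketch of the equivalence, with a hypothetical `forward`:

from mindspore import ops, ms_function

grad_op = ops.GradOperation(get_all=True)

def forward(x, y):
    return x * y

# The decorator stack used in the test is shorthand for explicit wrapping:
forward_grad = grad_op(ms_function(forward))
# forward_grad(x, y) returns the gradients w.r.t. x and y, not x * y.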
@@ -531,39 +532,43 @@ def test_bprop():
     csr_mv_arg = Tensor([[1], [2], [3], [4]], dtype=mstype.float32)
     csr_mv_expect_1 = np.array([4, 1, 2, 3, 2, 4], dtype=np.float32)
     csr_mv_expect_2 = np.array([[1], [6], [3], [5]], dtype=np.float32)
-    csr_mv_output_1, csr_mv_output_2 = test_csrmv_grad_pynative(csr_tensor, csr_mv_arg)
-    assert np.allclose(csr_mv_output_1.values.asnumpy(), csr_mv_expect_1)
-    assert np.allclose(csr_mv_output_2.asnumpy(), csr_mv_expect_2)
-    csr_mv_output_1, csr_mv_output_2 = test_csrmv_grad_graph(csr_tensor, csr_mv_arg)
+    csr_mv_output_1, csr_mv_output_2 = test_csrmv(csr_tensor, csr_mv_arg)
     assert np.allclose(csr_mv_output_1.values.asnumpy(), csr_mv_expect_1)
     assert np.allclose(csr_mv_output_2.asnumpy(), csr_mv_expect_2)
 
-    csr_reduce_sum_expect = np.ones(6, dtype=np.float32)
-    csr_reduce_sum_output = test_csr_reduce_sum_grad_pynative(csr_tensor, 1)
-    assert np.allclose(csr_reduce_sum_output[0].values.asnumpy(), csr_reduce_sum_expect)
-    csr_reduce_sum_output = test_csr_reduce_sum_grad_graph(csr_tensor, 1)
-    assert np.allclose(csr_reduce_sum_output[0].values.asnumpy(), csr_reduce_sum_expect)
+    csr_reduce_sum_expect_1 = np.ones(6, dtype=np.float32)
+    csr_reduce_sum_output_1 = test_csr_reduce_sum(csr_tensor, 1)
+    assert np.allclose(csr_reduce_sum_output_1[0].values.asnumpy(), csr_reduce_sum_expect_1)
 
     csr_mul_arg_1 = Tensor([[1], [2], [3]], dtype=mstype.float32)
     csr_mul_expect_1_1 = np.array([1, 2, 2, 2, 3, 3], dtype=np.float32)
     csr_mul_expect_1_2 = np.array([[0], [6], [9]], dtype=np.float32)
-    csr_mul_output_1_1, csr_mul_output_1_2 = test_csr_mul_grad_pynative(csr_tensor, csr_mul_arg_1)
-    assert np.allclose(csr_mul_output_1_1.values.asnumpy(), csr_mul_expect_1_1)
-    assert np.allclose(csr_mul_output_1_2.asnumpy(), csr_mul_expect_1_2)
-    csr_mul_output_1_1, csr_mul_output_1_2 = test_csr_mul_grad_graph(csr_tensor, csr_mul_arg_1)
+    csr_mul_output_1_1, csr_mul_output_1_2 = test_csr_mul(csr_tensor, csr_mul_arg_1)
     assert np.allclose(csr_mul_output_1_1.values.asnumpy(), csr_mul_expect_1_1)
     assert np.allclose(csr_mul_output_1_2.asnumpy(), csr_mul_expect_1_2)
 
     csr_mul_arg_2 = Tensor(np.arange(12).reshape(3, 4), dtype=mstype.float32)
     csr_mul_expect_2_1 = np.array([3, 4, 5, 6, 9, 11], dtype=np.float32)
     csr_mul_expect_2_2 = np.array([[0, 0, 0, 0], [1, 2, 3, 0], [0, 4, 0, 5]], np.float32)
-    csr_mul_output_2_1, csr_mul_output_2_2 = test_csr_mul_grad_pynative(csr_tensor, csr_mul_arg_2)
-    assert np.allclose(csr_mul_output_2_1.values.asnumpy(), csr_mul_expect_2_1)
-    assert np.allclose(csr_mul_output_2_2.asnumpy(), csr_mul_expect_2_2)
-    csr_mul_output_2_1, csr_mul_output_2_2 = test_csr_mul_grad_graph(csr_tensor, csr_mul_arg_2)
+    csr_mul_output_2_1, csr_mul_output_2_2 = test_csr_mul(csr_tensor, csr_mul_arg_2)
     assert np.allclose(csr_mul_output_2_1.values.asnumpy(), csr_mul_expect_2_1)
     assert np.allclose(csr_mul_output_2_2.asnumpy(), csr_mul_expect_2_2)
 
+    csr_div_expect_1_1 = np.array([1, 0.5, 0.5, 0.5, 0.3333333, 0.3333333], dtype=np.float32)
+    csr_div_expect_1_2 = np.array([[0], [-1.5], [-1]], dtype=np.float32)
+    csr_div_arg_1 = Tensor([[1], [2], [3]], dtype=mstype.float32)
+    csr_div_output_1_1, csr_div_output_1_2 = test_csr_div(csr_tensor, csr_div_arg_1)
+    assert np.allclose(csr_div_output_1_1.values.asnumpy(), csr_div_expect_1_1)
+    assert np.allclose(csr_div_output_1_2.asnumpy(), csr_div_expect_1_2)
+
+    csr_div_arg_2 = Tensor(np.arange(1, 13).reshape(3, 4), dtype=mstype.float32)
+    csr_div_expect_2_1 = np.array([0.25, 0.2, 0.16666667, 0.14285715, 0.1, 0.0833333], dtype=np.float32)
+    csr_div_expect_2_2 = np.array(
+        [[0, 0, 0, 0], [-0.04, -0.05555556, -0.06122449, 0], [0, -0.04, 0, -0.03472222]], dtype=np.float32)
+    csr_div_output_2_1, csr_div_output_2_2 = test_csr_div(csr_tensor, csr_div_arg_2)
+    assert np.allclose(csr_div_output_2_1.values.asnumpy(), csr_div_expect_2_1)
+    assert np.allclose(csr_div_output_2_2.asnumpy(), csr_div_expect_2_2)
+
 
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
@@ -600,3 +605,103 @@ def test_csr_method():
     to_dense_output = CSRToDenseNet()(csr_tensor)
     to_dense_expect = np.array([[0, 0, 0, 0], [1, 2, 3, 0], [0, 4, 0, 5]], np.float32)
     assert np.allclose(to_dense_output.asnumpy(), to_dense_expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_bprop2():
+    """
+    Feature: Test back-propagation with CSR-related Ops.
+    Description: Test back-propagation of make_csr, csr.attributes, csr.methods().
+    Expectation: Success.
+    """
+    if get_platform() != "linux":
+        return
+    grad_op = ops.GradOperation(get_all=True)
+    indptr = Tensor([0, 1, 4, 6], dtype=mstype.int32)
+    indices = Tensor([3, 0, 1, 2, 1, 3], dtype=mstype.int32)
+    values = Tensor(np.arange(6) - 3.5, dtype=mstype.float32)
+    dense_shape = (3, 4)
+
+    @grad_op
+    @ms_function
+    def test_csr_tensor(indptr, indices, values, dense_shape):
+        csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
+        return csr_tensor
+
+    @grad_op
+    @ms_function
+    def test_csr_indptr(indptr, indices, values, dense_shape):
+        csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
+        return csr_tensor.indptr
+
+    @grad_op
+    @ms_function
+    def test_csr_indices(indptr, indices, values, dense_shape):
+        csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
+        return csr_tensor.indices
+
+    @grad_op
+    @ms_function
+    def test_csr_values(indptr, indices, values, dense_shape):
+        csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
+        return csr_tensor.values
+
+    @grad_op
+    @ms_function
+    def test_csr_shape(indptr, indices, values, dense_shape):
+        csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
+        return csr_tensor.shape
+
+    @grad_op
+    @ms_function
+    def test_csr_cast(indptr, indices, values, dense_shape):
+        csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
+        return csr_tensor.astype(mstype.int32)
+
+    @grad_op
+    @ms_function
+    def test_csr_dtype(indptr, indices, values, dense_shape):
+        csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
+        return csr_tensor.dtype
+
+    @grad_op
+    @ms_function
+    def test_csr_to_tuple(indptr, indices, values, dense_shape):
+        csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
+        return csr_tensor.to_tuple()
+
+    @grad_op
+    @ms_function
+    def test_csr_to_abs(indptr, indices, values, dense_shape):
+        csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
+        return csr_tensor.abs()
+
+    @grad_op
+    @ms_function
+    def test_csr_to_coo(indptr, indices, values, dense_shape):
+        csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
+        return csr_tensor.to_coo()
+
+    @grad_op
+    @ms_function
+    def test_csr_to_dense(indptr, indices, values, dense_shape):
+        csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
+        return csr_tensor.to_dense()
+
+    all_zero = (np.zeros(indptr.shape, np.int32), np.zeros(indices.shape, np.int32), np.zeros(values.shape, np.float32))
+    values_on = (np.zeros(indptr.shape, np.int32), np.zeros(indices.shape, np.int32), np.ones(values.shape, np.float32))
+    values_absgrad = (np.zeros(indptr.shape, np.int32), np.zeros(indices.shape, np.int32), np.sign(values.asnumpy()))
+    compare_res(test_csr_tensor(indptr, indices, values, dense_shape), values_on)
+    compare_res(test_csr_indptr(indptr, indices, values, dense_shape), all_zero)
+    compare_res(test_csr_indices(indptr, indices, values, dense_shape), all_zero)
+    compare_res(test_csr_values(indptr, indices, values, dense_shape), values_on)
+    compare_res(test_csr_cast(indptr, indices, values, dense_shape), values_on)
+    compare_res(test_csr_shape(indptr, indices, values, dense_shape), all_zero)
+    compare_res(test_csr_dtype(indptr, indices, values, dense_shape), all_zero)
+    compare_res(test_csr_to_tuple(indptr, indices, values, dense_shape), values_on)
+    compare_res(test_csr_to_abs(indptr, indices, values, dense_shape), values_absgrad)
+    compare_res(test_csr_to_coo(indptr, indices, values, dense_shape), values_on)
+    compare_res(test_csr_to_dense(indptr, indices, values, dense_shape), values_on)