forked from mindspore-Ecosystem/mindspore
!39339 add csrtodense func and tensor method
Merge pull request !39339 from wangrao124/pr_add_csr_to_dense_func
This commit is contained in:
commit
f79c1ec33e
|
@ -2110,8 +2110,7 @@ def csr_to_coo(x):
|
|||
|
||||
def csr_to_dense(x):
|
||||
"""convert csr to dense."""
|
||||
coo_tensor = x.to_coo()
|
||||
return coo_tensor.to_dense()
|
||||
return F.csr_to_dense(x)
|
||||
|
||||
|
||||
def random_categorical_(x, num_sample, seed=0, dtype=mstype.int64):
|
||||
|
|
|
@ -1991,10 +1991,10 @@ class Tensor(Tensor_):
|
|||
perm = tuple(range(0, self.ndim))
|
||||
if axis2 + 1 < self.ndim:
|
||||
new_perm = perm[0:axis1] + perm[axis2:axis2 + 1] + \
|
||||
perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1] + perm[axis2 + 1:]
|
||||
perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1] + perm[axis2 + 1:]
|
||||
else:
|
||||
new_perm = perm[0:axis1] + perm[axis2:axis2 + 1] + \
|
||||
perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1]
|
||||
perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1]
|
||||
|
||||
return tensor_operator_registry.get('transpose')()(self, new_perm)
|
||||
|
||||
|
@ -5373,8 +5373,7 @@ class CSRTensor(CSRTensor_):
|
|||
Supported Platforms:
|
||||
``GPU``
|
||||
"""
|
||||
coo_tensor = self.to_coo()
|
||||
return coo_tensor.to_dense()
|
||||
return tensor_operator_registry.get("csr_to_dense")(self)
|
||||
|
||||
def astype(self, dtype):
|
||||
"""
|
||||
|
@ -5511,7 +5510,6 @@ class CSRTensor(CSRTensor_):
|
|||
data = self.values.abs()
|
||||
return CSRTensor(self.indptr, self.indices, data, self.shape)
|
||||
|
||||
|
||||
def add(self, b, alpha, beta):
|
||||
"""
|
||||
Addition of two CSR Tensors : C = alpha * A + beta * B
|
||||
|
|
|
@ -14,16 +14,17 @@
|
|||
# ============================================================================
|
||||
|
||||
"""bprop primitives"""
|
||||
from .._utils.utils import is_shape_unknown
|
||||
from .grad_base import bprops, bprop_getters
|
||||
from ..composite.multitype_ops._constexpr_utils import infer_out_shape
|
||||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||
from ..operations._sparse_grad_ops import SparseAddGrad
|
||||
from ...common import dtype as mstype
|
||||
from .. import functional as F
|
||||
from .. import operations as P
|
||||
from ..operations import _csr_ops
|
||||
from ..operations.sparse_ops import SparseAdd
|
||||
from ..operations._sparse_grad_ops import SparseAddGrad
|
||||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||
from ..composite.multitype_ops._constexpr_utils import infer_out_shape
|
||||
from .grad_base import bprops, bprop_getters
|
||||
from .._utils.utils import is_shape_unknown
|
||||
from ..operations.sparse_ops import SparseAdd, CSRSparseMatrixToDense, CSRSparseMatrixToSparseTensor, \
|
||||
DenseToCSRSparseMatrix
|
||||
|
||||
# Unused parameters are placeholders.
|
||||
|
||||
|
@ -99,6 +100,7 @@ def get_bprop_sparse_add(self):
|
|||
shape_op = P.Shape()
|
||||
dyn_shape_op = P.TensorShape()
|
||||
reshape = P.Reshape()
|
||||
|
||||
def bprop(x1_indices, x1_values, x1_shape, x2_indices, x2_values, x2_shape, thresh, out, dout):
|
||||
dx1, dx2 = sparse_add_grad(dout[1], x1_indices, x2_indices, out[0])
|
||||
shp = shape_op(x1_indices)
|
||||
|
@ -299,3 +301,10 @@ def get_bprop_coo2csr(self):
|
|||
def bprop(row_indices, height, out, dout):
|
||||
return zeros_like(row_indices), zeros_like(height)
|
||||
return bprop
|
||||
|
||||
|
||||
@bprops.register(CSRSparseMatrixToDense)
|
||||
def bprop_csr_sparse_matrix_to_dense(shape, batch, indptr, indices, values, out, dout):
|
||||
"""Backpropagator for primitive `CSRSparseMatrixToDense`."""
|
||||
index, _, _ = CSRSparseMatrixToSparseTensor()(shape, batch, indptr, indices, values)
|
||||
return DenseToCSRSparseMatrix()(dout, index)
|
||||
|
|
|
@ -327,6 +327,7 @@ from .sparse_func import (
|
|||
sparse_concat,
|
||||
csr_add,
|
||||
csr_softmax,
|
||||
csr_to_dense,
|
||||
)
|
||||
from .random_func import (
|
||||
standard_laplace,
|
||||
|
|
|
@ -22,7 +22,8 @@ from ..operations.sparse_ops import (
|
|||
SparseConcat,
|
||||
SparseAdd,
|
||||
SparseMatrixAdd,
|
||||
SparseMatrixSoftmax
|
||||
SparseMatrixSoftmax,
|
||||
CSRSparseMatrixToDense
|
||||
)
|
||||
from ..operations.array_ops import GatherNd, Coalesce
|
||||
from ..operations import _csr_ops
|
||||
|
@ -36,6 +37,7 @@ dense_to_csr = DenseToCSRSparseMatrix()
|
|||
csr_sparse_matrix_to_sparse_tensor = CSRSparseMatrixToSparseTensor()
|
||||
batch_csr_pointers_empty = Tensor([0, -1], dtype=mstype.int32)
|
||||
coalesce_op = Coalesce()
|
||||
csr_sparse_matrix_to_dense = CSRSparseMatrixToDense()
|
||||
|
||||
|
||||
@constexpr
|
||||
|
@ -254,6 +256,45 @@ def csr_to_coo(tensor):
|
|||
return COOTensor(indices, values, shape)
|
||||
|
||||
|
||||
def csr_to_dense(csr_tensor):
|
||||
"""
|
||||
Converts a CSRTensor to its dense form.
|
||||
|
||||
Note:
|
||||
Only 2-D CSRTensor is supported for now.
|
||||
|
||||
Args:
|
||||
csr_tensor: A CSRTensor, must be 2-D.
|
||||
|
||||
Returns:
|
||||
Tensor.
|
||||
|
||||
Raises:
|
||||
TypeError: If input is not a CSRTensor.
|
||||
ValueError: If input CSRTensor is not 2-D.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore import Tensor, CSRTensor, ops
|
||||
>>> indptr = Tensor([0, 1, 2]).astype("int32")
|
||||
>>> indices = Tensor([0, 1]).astype("int32")
|
||||
>>> values = Tensor([2, 1]).astype("float32")
|
||||
>>> shape = (2, 4)
|
||||
>>> x = CSRTensor(indptr, indices, values, shape)
|
||||
>>> output = ops.csr_to_dense(x)
|
||||
>>> print(output)
|
||||
"""
|
||||
if not isinstance(csr_tensor, CSRTensor):
|
||||
raise_type_error("For functional operator csr_to_dense, input argument must be a CSRTensor.")
|
||||
if len(csr_tensor.shape) != 2:
|
||||
raise_value_error("Currently only support 2-D CSRTensor when converting to COOTensor.")
|
||||
shape = csr_tensor.shape
|
||||
return csr_sparse_matrix_to_dense(Tensor(shape, dtype=mstype.int32), batch_csr_pointers_empty,
|
||||
csr_tensor.indptr, csr_tensor.indices, csr_tensor.values)
|
||||
|
||||
|
||||
# deprecated, will be removed once `csr_to_coo` supports all backends.
|
||||
csr2coo = _csr_ops.CSR2COO()
|
||||
|
||||
|
@ -644,7 +685,8 @@ __all__ = [
|
|||
'sparse_add',
|
||||
'sparse_concat',
|
||||
'csr_add',
|
||||
'csr_softmax'
|
||||
'csr_softmax',
|
||||
'csr_to_dense'
|
||||
]
|
||||
|
||||
__all__.sort()
|
||||
|
|
|
@ -448,6 +448,7 @@ tensor_operator_registry.register('csr_mm', _csr_ops.CSRMM)
|
|||
tensor_operator_registry.register('csr_reduce_sum', csr_reduce_sum)
|
||||
tensor_operator_registry.register('dense_to_sparse_csr', dense_to_sparse_csr)
|
||||
tensor_operator_registry.register('dense_to_sparse_coo', dense_to_sparse_coo)
|
||||
tensor_operator_registry.register('csr_to_dense', csr_to_dense)
|
||||
tensor_operator_registry.register('narrow', narrow)
|
||||
tensor_operator_registry.register('sort', sort)
|
||||
tensor_operator_registry.register('csr_to_coo', csr_to_coo)
|
||||
|
|
Loading…
Reference in New Issue