!44859 CSRTensor & functional SparseMatrixSparseMatMul

Merge pull request !44859 from mamba_ni/master
This commit is contained in:
i-robot 2022-11-07 02:34:14 +00:00 committed by Gitee
commit dfd2239118
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
9 changed files with 132 additions and 21 deletions

View File

@ -418,6 +418,7 @@ Array操作
mindspore.ops.dense_to_sparse_coo
mindspore.ops.dense_to_sparse_csr
mindspore.ops.csr_add
mindspore.ops.csr_mm
mindspore.ops.csr_softmax
mindspore.ops.csr_to_coo
mindspore.ops.sparse_add

View File

@ -78,19 +78,19 @@ mindspore.CSRTensor
返回每个非零元素所占字节数。
.. py:method:: mm(dense_matrix: Tensor)
.. py:method:: mm(matrix: Union[Tensor, CSRTensor])
返回CSRTensor右乘稠密矩阵的矩阵乘法运算结果。
shape为 `[M, N]` 的CSRTensor需要适配shape为 `[N, K]` 的稠密矩阵,得到结果为 `[M, K]` 的稠密矩阵。
返回CSRTensor右乘稀疏矩阵或稠密矩阵的矩阵乘法运算结果。
shape为 `[M, N]` 的CSRTensor需要适配shape为 `[N, K]` 的稠密矩阵或稀疏矩阵,得到结果为 `[M, K]` 的稠密矩阵或稀疏矩阵。
.. note::
如果运行后端是CPU那么仅支持在安装了LLVM12.0.1的机器运行。
参数:
- **dense_matrix** (Tensor) - shape为 `[NK]` 的二维矩阵其中N等于CSRTensor的列数。
- **matrix** (Tensor or CSRTensor) - shape为 `[NK]` 的二维矩阵其中N等于CSRTensor的列数。
返回:
Tensor。
Tensor or CSRTensor
.. py:method:: mv(dense_vector: Tensor)

View File

@ -0,0 +1,21 @@
mindspore.ops.csr_mm
=================================
.. py:function:: mindspore.ops.csr_mm(a: CSRTensor, b: CSRTensor or Tensor, trans_a: bool, trans_b: bool, adjoint_a: bool, adjoint_b: bool)
返回稀疏矩阵a与稀疏矩阵或稠密矩阵b的矩阵乘法结果。
.. note::
若右矩阵为Tensor则仅支持安装了LLVM12.0.1的CPU后端或GPU后端。
若右矩阵为CSRTensor 则仅支持GPU后端。
参数:
- **a** (CSRTensor) - 稀疏的 CSRTensor。
- **b** (CSRTensor 或 Tensor) - 稀疏的 CSRTensor或稠密矩阵。
- **trans_a** (Tensor) - 是否对矩阵a进行转置。
- **trans_b** (Tensor) - 是否对矩阵b进行转置。
- **adjoint_a** (Tensor) - 是否对矩阵a进行共轭。
- **adjoint_b** (Tensor) - 是否对矩阵b进行共轭。
返回:
返回稀疏矩阵类型为CSRTensor。

View File

@ -418,6 +418,7 @@ Sparse Functions
mindspore.ops.dense_to_sparse_coo
mindspore.ops.dense_to_sparse_csr
mindspore.ops.csr_add
mindspore.ops.csr_mm
mindspore.ops.csr_softmax
mindspore.ops.csr_to_coo
mindspore.ops.sparse_add

View File

@ -36,8 +36,8 @@ from ...ops.composite.multitype_ops import _constexpr_utils as const_utils
from ...ops.composite.multitype_ops import _compile_utils as compile_utils
from ...ops.operations.math_ops import Median
from ...ops.operations._inner_ops import Format
from ...ops.operations import _map_tensor_ops
from ...ops.operations import _csr_ops
from ...ops.operations import _map_tensor_ops
from ...ops.primitive import constexpr
from ...common import dtype as mstype
@ -3178,9 +3178,11 @@ def csr_mv(x, dense_vector):
return F.csr_mv(x, dense_vector)
def csr_mm(x, dense):
def csr_mm(x, matrix):
"""Implementation of `mm` for CSRTensor."""
return _csr_mm(x.indptr, x.indices, x.values, x.shape, dense)
if isinstance(matrix, CSRTensor):
return F.csr_mm(x, matrix)
return _csr_mm(x.indptr, x.indices, x.values, x.shape, matrix)
def csr_to_tuple(x):

View File

@ -775,23 +775,25 @@ class CSRTensor(CSRTensor_):
validator.check_value_type('dense_vector', dense_vector, (Tensor_,), 'CSRTensor.mv')
return tensor_operator_registry.get("csr_mv")(self, dense_vector)
def mm(self, dense_matrix: Tensor) -> Tensor:
def mm(self, matrix: Union[Tensor, CSRTensor]) -> Union[Tensor, CSRTensor]:
"""
Return the matrix multiplication result of the right-multiply dense matrix of the CSRTensor.
The CSRTensor with shape `[M, N]` needs to adapt the dense matrix with shape `[N, K]`
to get the dense matrix with result `[M, K]`.
Return the matrix multiplication result of the right-multiply matrixdense or CSRTensor of the CSRTensor.
The CSRTensor with shape `[M, N]` needs to adapt the right matrix with shape `[N, K]`
to get the dense matrix or CSRTensor with result `[M, K]`.
Note:
Currently only supports CPU backend with LLVM 12.0.1 installed.
If right matrix is CSRTensor, currently only supports GPU backend.
if right matrix is Tensor, currently supports CPU backend with LLVM 12.0.1 or GPU backend.
Args:
dense_matrix (Tensor): A dense Tensor, its shape[0] should be equal to csr_tensor.shape[1]
matrix (Tensor or CSRTensor): A dense Tensor or CSRTensor,
its shape[0] should be equal to csr_tensor.shape[1]
Returns:
Tensor.
Tensor or CSRTensor.
Supported Platforms:
``GPU`` ``CPU``
``CPU`` ``GPU``
Examples:
>>> from mindspore import Tensor, CSRTensor
@ -806,9 +808,11 @@ class CSRTensor(CSRTensor_):
[[2. 4.]
[1. 2.]]
"""
validator.check_value_type('dense_matrix', dense_matrix, (Tensor_,), 'CSRTensor.mm')
return tensor_operator_registry.get("csr_mm")()(self.indptr, self.indices, self.values,
self.shape, dense_matrix)
if isinstance(matrix, CSRTensor):
return tensor_operator_registry.get("csr_mm")(self, matrix)
validator.check_value_type('matrix', matrix, (Tensor_,), 'CSRTensor.mm')
return tensor_operator_registry.get("csr_mm_akg")()(self.indptr, self.indices, self.values,
self.shape, matrix)
def sum(self, axis: int) -> Tensor:
"""

View File

@ -376,6 +376,7 @@ from .sparse_func import (
csr_gather,
csr_mul,
csr_mv,
csr_mm,
csr_reduce_sum,
csr_to_coo,
csr2coo,

View File

@ -24,6 +24,7 @@ from mindspore.ops.operations.sparse_ops import (
SparseAdd,
SparseMatrixAdd,
SparseMatrixSoftmax,
SparseMatrixSparseMatMul,
CSRSparseMatrixToDense
)
from mindspore.common import dtype as mstype
@ -227,6 +228,84 @@ def csr_mv(csr_tensor: CSRTensor, dense: Tensor) -> Tensor:
return _csr_ops.CSRMV()(csr_tensor.indptr, csr_tensor.indices, csr_tensor_values, csr_tensor.shape, dense)
def csr_mm(a: CSRTensor, b: CSRTensor, trans_a: bool = False, trans_b: bool = False,
adjoint_a: bool = False, adjoint_b: bool = False):
"""
Return the matrix multiplication result of the right-multiply matrixdense or CSRTensor of the CSRTensor.
The CSRTensor with shape `[M, N]` needs to adapt the right matrix with shape `[N, K]`
to get the dense matrix or CSRTensor with result `[M, K]`.
Note:
Currently supports GPU backend with right matrix is CSRTensor.
Args:
a (CSRTensor): Sparse CSR Tensor, rank should be 2.
b (CSRTensor): Sparse CSR Tensor, rank should be 2.
trans_a (bool): whether to transpose CSRTensor a.
trans_a (bool): whether to transpose CSRTensor b.
adjoint_a (bool): whether to adjoint CSRTensor a.
adjoint_b (bool): whether to adjoint CSRTensor b.
Returns:
CSRTensor.
Supported Platforms:
``GPU``
Examples:
>>> from mindspore import Tensor, CSRTensor
>>> from mindspore import dtype as mstype
>>> a_shape = (4, 5)
>>> a_indptr = Tensor([0, 1, 1, 3, 4], dtype=mstype.int32)
>>> a_indices = Tensor([0, 3, 4, 0],dtype=mstype.int32)
>>> a_values = Tensor([1.0, 5.0, -1.0, -2.0], dtype=mstype.float32)
>>> b_shape = (5, 3)
>>> b_indptr = Tensor([0, 1, 1, 3, 3, 3], dtype=mstype.int32)
>>> b_indices = Tensor([0, 0, 1],dtype=mstype.int32)
>>> b_values = Tensor([2.0, 7.0, 8.0], dtype=mstype.float32)
>>> a = CSRTensor(a_indptr, a_indices, a_values, a_shape)
>>> b = CSRTensor(b_indptr, b_indices, b_values, b_shape)
>>> c = csr_mm(a, b)
>>> print(c.shape)
(4, 3)
>>> print(c.values)
[2. -4.]
>>> print(c.indptr)
[0 1 1 1 2]
>>> print(c.indices)
[0 0]
"""
if isinstance(a, CSRTensor) and isinstance(b, CSRTensor):
a_batch_pointers = make_tensor([0, a.values.shape[0]], a.indices.dtype)
b_batch_pointers = make_tensor([0, b.values.shape[0]], b.indices.dtype)
a_shape = make_tensor(a.shape, a.indices.dtype)
b_shape = make_tensor(b.shape, b.indices.dtype)
sparse_matrix_sparse_matmul = SparseMatrixSparseMatMul(transpose_a=trans_a,
transpose_b=trans_b,
adjoint_a=adjoint_a,
adjoint_b=adjoint_b)
_, _, c_indptr, c_indices, c_values = sparse_matrix_sparse_matmul(a_shape,
a_batch_pointers,
a.indptr,
a.indices,
a.values,
b_shape,
b_batch_pointers,
b.indptr,
b.indices,
b.values)
m, a2 = a.shape
b1, k = b.shape
if trans_a or adjoint_a:
m = a2
if trans_b or adjoint_b:
k = b1
return CSRTensor(c_indptr, c_indices, c_values, (m, k))
raise_type_error("For functional operator csr_mm, inputs a and b must be type of CSRTensor currently.")
return None
def csr_reduce_sum(csr_tensor: CSRTensor, axis: int) -> Tensor:
"""
Reduces a dimension of a CSRTensor by summing all elements in the dimension.

View File

@ -28,11 +28,12 @@ from mindspore._checkparam import Rel
from mindspore.ops import _constants
from mindspore.ops.function import *
from mindspore.ops.function.sparse_func import sparse_add
from mindspore.ops.function.sparse_func import csr_mm
from mindspore.ops.primitive import constexpr, Primitive
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops
from mindspore.ops.operations import _csr_ops
from mindspore.ops.operations import _inner_ops
from mindspore.ops.operations import _csr_ops
from mindspore.ops.operations import linalg_ops
from mindspore.ops.operations.math_ops import Median
from mindspore.ops.operations.array_ops import UniqueConsecutive
@ -472,7 +473,8 @@ tensor_operator_registry.register('csr2coo', csr2coo)
tensor_operator_registry.register('coo2csr', coo2csr)
tensor_operator_registry.register('csr_div', csr_div)
tensor_operator_registry.register('csr_mv', csr_mv)
tensor_operator_registry.register('csr_mm', _csr_ops.CSRMM)
tensor_operator_registry.register('csr_mm_akg', _csr_ops.CSRMM)
tensor_operator_registry.register('csr_mm', csr_mm)
tensor_operator_registry.register('csr_reduce_sum', csr_reduce_sum)
tensor_operator_registry.register('dense_to_sparse_csr', dense_to_sparse_csr)
tensor_operator_registry.register('dense_to_sparse_coo', dense_to_sparse_coo)