!44859 CSRTensor & functional SparseMatrixSparseMatMul
Merge pull request !44859 from mamba_ni/master
This commit is contained in:
commit
dfd2239118
|
@ -418,6 +418,7 @@ Array操作
|
|||
mindspore.ops.dense_to_sparse_coo
|
||||
mindspore.ops.dense_to_sparse_csr
|
||||
mindspore.ops.csr_add
|
||||
mindspore.ops.csr_mm
|
||||
mindspore.ops.csr_softmax
|
||||
mindspore.ops.csr_to_coo
|
||||
mindspore.ops.sparse_add
|
||||
|
|
|
@ -78,19 +78,19 @@ mindspore.CSRTensor
|
|||
|
||||
返回每个非零元素所占字节数。
|
||||
|
||||
.. py:method:: mm(dense_matrix: Tensor)
|
||||
.. py:method:: mm(matrix: Union[Tensor, CSRTensor])
|
||||
|
||||
返回CSRTensor右乘稠密矩阵的矩阵乘法运算结果。
|
||||
shape为 `[M, N]` 的CSRTensor,需要适配shape为 `[N, K]` 的稠密矩阵,得到结果为 `[M, K]` 的稠密矩阵。
|
||||
返回CSRTensor右乘稀疏矩阵或稠密矩阵的矩阵乘法运算结果。
|
||||
shape为 `[M, N]` 的CSRTensor,需要适配shape为 `[N, K]` 的稠密矩阵或稀疏矩阵,得到结果为 `[M, K]` 的稠密矩阵或稀疏矩阵。
|
||||
|
||||
.. note::
|
||||
如果运行后端是CPU,那么仅支持在安装了LLVM12.0.1的机器运行。
|
||||
|
||||
参数:
|
||||
- **dense_matrix** (Tensor) - shape为 `[N,K]` 的二维矩阵,其中N等于CSRTensor的列数。
|
||||
- **matrix** (Tensor or CSRTensor) - shape为 `[N,K]` 的二维矩阵,其中N等于CSRTensor的列数。
|
||||
|
||||
返回:
|
||||
Tensor。
|
||||
Tensor or CSRTensor。
|
||||
|
||||
.. py:method:: mv(dense_vector: Tensor)
|
||||
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
mindspore.ops.csr_mm
|
||||
=================================
|
||||
|
||||
.. py:function:: mindspore.ops.csr_mm(a: CSRTensor, b: CSRTensor or Tensor, trans_a: bool, trans_b: bool, adjoint_a: bool, adjoint_b: bool)
|
||||
|
||||
返回稀疏矩阵a与稀疏矩阵或稠密矩阵b的矩阵乘法结果。
|
||||
|
||||
.. note::
|
||||
若右矩阵为Tensor,则仅支持安装了LLVM12.0.1的CPU后端或GPU后端。
|
||||
若右矩阵为CSRTensor, 则仅支持GPU后端。
|
||||
|
||||
参数:
|
||||
- **a** (CSRTensor) - 稀疏的 CSRTensor。
|
||||
- **b** (CSRTensor 或 Tensor) - 稀疏的 CSRTensor或稠密矩阵。
|
||||
- **trans_a** (Tensor) - 是否对矩阵a进行转置。
|
||||
- **trans_b** (Tensor) - 是否对矩阵b进行转置。
|
||||
- **adjoint_a** (Tensor) - 是否对矩阵a进行共轭。
|
||||
- **adjoint_b** (Tensor) - 是否对矩阵b进行共轭。
|
||||
|
||||
返回:
|
||||
返回稀疏矩阵,类型为CSRTensor。
|
|
@ -418,6 +418,7 @@ Sparse Functions
|
|||
mindspore.ops.dense_to_sparse_coo
|
||||
mindspore.ops.dense_to_sparse_csr
|
||||
mindspore.ops.csr_add
|
||||
mindspore.ops.csr_mm
|
||||
mindspore.ops.csr_softmax
|
||||
mindspore.ops.csr_to_coo
|
||||
mindspore.ops.sparse_add
|
||||
|
|
|
@ -36,8 +36,8 @@ from ...ops.composite.multitype_ops import _constexpr_utils as const_utils
|
|||
from ...ops.composite.multitype_ops import _compile_utils as compile_utils
|
||||
from ...ops.operations.math_ops import Median
|
||||
from ...ops.operations._inner_ops import Format
|
||||
from ...ops.operations import _map_tensor_ops
|
||||
from ...ops.operations import _csr_ops
|
||||
from ...ops.operations import _map_tensor_ops
|
||||
from ...ops.primitive import constexpr
|
||||
from ...common import dtype as mstype
|
||||
|
||||
|
@ -3178,9 +3178,11 @@ def csr_mv(x, dense_vector):
|
|||
return F.csr_mv(x, dense_vector)
|
||||
|
||||
|
||||
def csr_mm(x, dense):
|
||||
def csr_mm(x, matrix):
|
||||
"""Implementation of `mm` for CSRTensor."""
|
||||
return _csr_mm(x.indptr, x.indices, x.values, x.shape, dense)
|
||||
if isinstance(matrix, CSRTensor):
|
||||
return F.csr_mm(x, matrix)
|
||||
return _csr_mm(x.indptr, x.indices, x.values, x.shape, matrix)
|
||||
|
||||
|
||||
def csr_to_tuple(x):
|
||||
|
|
|
@ -775,23 +775,25 @@ class CSRTensor(CSRTensor_):
|
|||
validator.check_value_type('dense_vector', dense_vector, (Tensor_,), 'CSRTensor.mv')
|
||||
return tensor_operator_registry.get("csr_mv")(self, dense_vector)
|
||||
|
||||
def mm(self, dense_matrix: Tensor) -> Tensor:
|
||||
def mm(self, matrix: Union[Tensor, CSRTensor]) -> Union[Tensor, CSRTensor]:
|
||||
"""
|
||||
Return the matrix multiplication result of the right-multiply dense matrix of the CSRTensor.
|
||||
The CSRTensor with shape `[M, N]` needs to adapt the dense matrix with shape `[N, K]`
|
||||
to get the dense matrix with result `[M, K]`.
|
||||
Return the matrix multiplication result of the right-multiply matrix(dense or CSRTensor) of the CSRTensor.
|
||||
The CSRTensor with shape `[M, N]` needs to adapt the right matrix with shape `[N, K]`
|
||||
to get the dense matrix or CSRTensor with result `[M, K]`.
|
||||
|
||||
Note:
|
||||
Currently only supports CPU backend with LLVM 12.0.1 installed.
|
||||
If right matrix is CSRTensor, currently only supports GPU backend.
|
||||
if right matrix is Tensor, currently supports CPU backend with LLVM 12.0.1 or GPU backend.
|
||||
|
||||
Args:
|
||||
dense_matrix (Tensor): A dense Tensor, its shape[0] should be equal to csr_tensor.shape[1]
|
||||
matrix (Tensor or CSRTensor): A dense Tensor or CSRTensor,
|
||||
its shape[0] should be equal to csr_tensor.shape[1]
|
||||
|
||||
Returns:
|
||||
Tensor.
|
||||
Tensor or CSRTensor.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU`` ``CPU``
|
||||
``CPU`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore import Tensor, CSRTensor
|
||||
|
@ -806,9 +808,11 @@ class CSRTensor(CSRTensor_):
|
|||
[[2. 4.]
|
||||
[1. 2.]]
|
||||
"""
|
||||
validator.check_value_type('dense_matrix', dense_matrix, (Tensor_,), 'CSRTensor.mm')
|
||||
return tensor_operator_registry.get("csr_mm")()(self.indptr, self.indices, self.values,
|
||||
self.shape, dense_matrix)
|
||||
if isinstance(matrix, CSRTensor):
|
||||
return tensor_operator_registry.get("csr_mm")(self, matrix)
|
||||
validator.check_value_type('matrix', matrix, (Tensor_,), 'CSRTensor.mm')
|
||||
return tensor_operator_registry.get("csr_mm_akg")()(self.indptr, self.indices, self.values,
|
||||
self.shape, matrix)
|
||||
|
||||
def sum(self, axis: int) -> Tensor:
|
||||
"""
|
||||
|
|
|
@ -376,6 +376,7 @@ from .sparse_func import (
|
|||
csr_gather,
|
||||
csr_mul,
|
||||
csr_mv,
|
||||
csr_mm,
|
||||
csr_reduce_sum,
|
||||
csr_to_coo,
|
||||
csr2coo,
|
||||
|
|
|
@ -24,6 +24,7 @@ from mindspore.ops.operations.sparse_ops import (
|
|||
SparseAdd,
|
||||
SparseMatrixAdd,
|
||||
SparseMatrixSoftmax,
|
||||
SparseMatrixSparseMatMul,
|
||||
CSRSparseMatrixToDense
|
||||
)
|
||||
from mindspore.common import dtype as mstype
|
||||
|
@ -227,6 +228,84 @@ def csr_mv(csr_tensor: CSRTensor, dense: Tensor) -> Tensor:
|
|||
return _csr_ops.CSRMV()(csr_tensor.indptr, csr_tensor.indices, csr_tensor_values, csr_tensor.shape, dense)
|
||||
|
||||
|
||||
def csr_mm(a: CSRTensor, b: CSRTensor, trans_a: bool = False, trans_b: bool = False,
|
||||
adjoint_a: bool = False, adjoint_b: bool = False):
|
||||
"""
|
||||
Return the matrix multiplication result of the right-multiply matrix(dense or CSRTensor) of the CSRTensor.
|
||||
The CSRTensor with shape `[M, N]` needs to adapt the right matrix with shape `[N, K]`
|
||||
to get the dense matrix or CSRTensor with result `[M, K]`.
|
||||
|
||||
Note:
|
||||
Currently supports GPU backend with right matrix is CSRTensor.
|
||||
|
||||
Args:
|
||||
a (CSRTensor): Sparse CSR Tensor, rank should be 2.
|
||||
b (CSRTensor): Sparse CSR Tensor, rank should be 2.
|
||||
trans_a (bool): whether to transpose CSRTensor a.
|
||||
trans_a (bool): whether to transpose CSRTensor b.
|
||||
adjoint_a (bool): whether to adjoint CSRTensor a.
|
||||
adjoint_b (bool): whether to adjoint CSRTensor b.
|
||||
|
||||
Returns:
|
||||
CSRTensor.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore import Tensor, CSRTensor
|
||||
>>> from mindspore import dtype as mstype
|
||||
>>> a_shape = (4, 5)
|
||||
>>> a_indptr = Tensor([0, 1, 1, 3, 4], dtype=mstype.int32)
|
||||
>>> a_indices = Tensor([0, 3, 4, 0],dtype=mstype.int32)
|
||||
>>> a_values = Tensor([1.0, 5.0, -1.0, -2.0], dtype=mstype.float32)
|
||||
>>> b_shape = (5, 3)
|
||||
>>> b_indptr = Tensor([0, 1, 1, 3, 3, 3], dtype=mstype.int32)
|
||||
>>> b_indices = Tensor([0, 0, 1],dtype=mstype.int32)
|
||||
>>> b_values = Tensor([2.0, 7.0, 8.0], dtype=mstype.float32)
|
||||
>>> a = CSRTensor(a_indptr, a_indices, a_values, a_shape)
|
||||
>>> b = CSRTensor(b_indptr, b_indices, b_values, b_shape)
|
||||
>>> c = csr_mm(a, b)
|
||||
>>> print(c.shape)
|
||||
(4, 3)
|
||||
>>> print(c.values)
|
||||
[2. -4.]
|
||||
>>> print(c.indptr)
|
||||
[0 1 1 1 2]
|
||||
>>> print(c.indices)
|
||||
[0 0]
|
||||
"""
|
||||
if isinstance(a, CSRTensor) and isinstance(b, CSRTensor):
|
||||
a_batch_pointers = make_tensor([0, a.values.shape[0]], a.indices.dtype)
|
||||
b_batch_pointers = make_tensor([0, b.values.shape[0]], b.indices.dtype)
|
||||
a_shape = make_tensor(a.shape, a.indices.dtype)
|
||||
b_shape = make_tensor(b.shape, b.indices.dtype)
|
||||
sparse_matrix_sparse_matmul = SparseMatrixSparseMatMul(transpose_a=trans_a,
|
||||
transpose_b=trans_b,
|
||||
adjoint_a=adjoint_a,
|
||||
adjoint_b=adjoint_b)
|
||||
|
||||
_, _, c_indptr, c_indices, c_values = sparse_matrix_sparse_matmul(a_shape,
|
||||
a_batch_pointers,
|
||||
a.indptr,
|
||||
a.indices,
|
||||
a.values,
|
||||
b_shape,
|
||||
b_batch_pointers,
|
||||
b.indptr,
|
||||
b.indices,
|
||||
b.values)
|
||||
m, a2 = a.shape
|
||||
b1, k = b.shape
|
||||
if trans_a or adjoint_a:
|
||||
m = a2
|
||||
if trans_b or adjoint_b:
|
||||
k = b1
|
||||
return CSRTensor(c_indptr, c_indices, c_values, (m, k))
|
||||
raise_type_error("For functional operator csr_mm, inputs a and b must be type of CSRTensor currently.")
|
||||
return None
|
||||
|
||||
|
||||
def csr_reduce_sum(csr_tensor: CSRTensor, axis: int) -> Tensor:
|
||||
"""
|
||||
Reduces a dimension of a CSRTensor by summing all elements in the dimension.
|
||||
|
|
|
@ -28,11 +28,12 @@ from mindspore._checkparam import Rel
|
|||
from mindspore.ops import _constants
|
||||
from mindspore.ops.function import *
|
||||
from mindspore.ops.function.sparse_func import sparse_add
|
||||
from mindspore.ops.function.sparse_func import csr_mm
|
||||
from mindspore.ops.primitive import constexpr, Primitive
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations import _grad_ops
|
||||
from mindspore.ops.operations import _csr_ops
|
||||
from mindspore.ops.operations import _inner_ops
|
||||
from mindspore.ops.operations import _csr_ops
|
||||
from mindspore.ops.operations import linalg_ops
|
||||
from mindspore.ops.operations.math_ops import Median
|
||||
from mindspore.ops.operations.array_ops import UniqueConsecutive
|
||||
|
@ -472,7 +473,8 @@ tensor_operator_registry.register('csr2coo', csr2coo)
|
|||
tensor_operator_registry.register('coo2csr', coo2csr)
|
||||
tensor_operator_registry.register('csr_div', csr_div)
|
||||
tensor_operator_registry.register('csr_mv', csr_mv)
|
||||
tensor_operator_registry.register('csr_mm', _csr_ops.CSRMM)
|
||||
tensor_operator_registry.register('csr_mm_akg', _csr_ops.CSRMM)
|
||||
tensor_operator_registry.register('csr_mm', csr_mm)
|
||||
tensor_operator_registry.register('csr_reduce_sum', csr_reduce_sum)
|
||||
tensor_operator_registry.register('dense_to_sparse_csr', dense_to_sparse_csr)
|
||||
tensor_operator_registry.register('dense_to_sparse_coo', dense_to_sparse_coo)
|
||||
|
|
Loading…
Reference in New Issue