!44859 CSRTensor & functional SparseMatrixSparseMatMul
Merge pull request !44859 from mamba_ni/master
This commit is contained in:
commit
dfd2239118
|
@ -418,6 +418,7 @@ Array操作
|
||||||
mindspore.ops.dense_to_sparse_coo
|
mindspore.ops.dense_to_sparse_coo
|
||||||
mindspore.ops.dense_to_sparse_csr
|
mindspore.ops.dense_to_sparse_csr
|
||||||
mindspore.ops.csr_add
|
mindspore.ops.csr_add
|
||||||
|
mindspore.ops.csr_mm
|
||||||
mindspore.ops.csr_softmax
|
mindspore.ops.csr_softmax
|
||||||
mindspore.ops.csr_to_coo
|
mindspore.ops.csr_to_coo
|
||||||
mindspore.ops.sparse_add
|
mindspore.ops.sparse_add
|
||||||
|
|
|
@ -78,19 +78,19 @@ mindspore.CSRTensor
|
||||||
|
|
||||||
返回每个非零元素所占字节数。
|
返回每个非零元素所占字节数。
|
||||||
|
|
||||||
.. py:method:: mm(dense_matrix: Tensor)
|
.. py:method:: mm(matrix: Union[Tensor, CSRTensor])
|
||||||
|
|
||||||
返回CSRTensor右乘稠密矩阵的矩阵乘法运算结果。
|
返回CSRTensor右乘稀疏矩阵或稠密矩阵的矩阵乘法运算结果。
|
||||||
shape为 `[M, N]` 的CSRTensor,需要适配shape为 `[N, K]` 的稠密矩阵,得到结果为 `[M, K]` 的稠密矩阵。
|
shape为 `[M, N]` 的CSRTensor,需要适配shape为 `[N, K]` 的稠密矩阵或稀疏矩阵,得到结果为 `[M, K]` 的稠密矩阵或稀疏矩阵。
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
如果运行后端是CPU,那么仅支持在安装了LLVM12.0.1的机器运行。
|
如果运行后端是CPU,那么仅支持在安装了LLVM12.0.1的机器运行。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
- **dense_matrix** (Tensor) - shape为 `[N,K]` 的二维矩阵,其中N等于CSRTensor的列数。
|
- **matrix** (Tensor or CSRTensor) - shape为 `[N,K]` 的二维矩阵,其中N等于CSRTensor的列数。
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
Tensor。
|
Tensor or CSRTensor。
|
||||||
|
|
||||||
.. py:method:: mv(dense_vector: Tensor)
|
.. py:method:: mv(dense_vector: Tensor)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
mindspore.ops.csr_mm
|
||||||
|
=================================
|
||||||
|
|
||||||
|
.. py:function:: mindspore.ops.csr_mm(a: CSRTensor, b: CSRTensor or Tensor, trans_a: bool, trans_b: bool, adjoint_a: bool, adjoint_b: bool)
|
||||||
|
|
||||||
|
返回稀疏矩阵a与稀疏矩阵或稠密矩阵b的矩阵乘法结果。
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
若右矩阵为Tensor,则仅支持安装了LLVM12.0.1的CPU后端或GPU后端。
|
||||||
|
若右矩阵为CSRTensor, 则仅支持GPU后端。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- **a** (CSRTensor) - 稀疏的 CSRTensor。
|
||||||
|
- **b** (CSRTensor 或 Tensor) - 稀疏的 CSRTensor或稠密矩阵。
|
||||||
|
- **trans_a** (Tensor) - 是否对矩阵a进行转置。
|
||||||
|
- **trans_b** (Tensor) - 是否对矩阵b进行转置。
|
||||||
|
- **adjoint_a** (Tensor) - 是否对矩阵a进行共轭。
|
||||||
|
- **adjoint_b** (Tensor) - 是否对矩阵b进行共轭。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
返回稀疏矩阵,类型为CSRTensor。
|
|
@ -418,6 +418,7 @@ Sparse Functions
|
||||||
mindspore.ops.dense_to_sparse_coo
|
mindspore.ops.dense_to_sparse_coo
|
||||||
mindspore.ops.dense_to_sparse_csr
|
mindspore.ops.dense_to_sparse_csr
|
||||||
mindspore.ops.csr_add
|
mindspore.ops.csr_add
|
||||||
|
mindspore.ops.csr_mm
|
||||||
mindspore.ops.csr_softmax
|
mindspore.ops.csr_softmax
|
||||||
mindspore.ops.csr_to_coo
|
mindspore.ops.csr_to_coo
|
||||||
mindspore.ops.sparse_add
|
mindspore.ops.sparse_add
|
||||||
|
|
|
@ -36,8 +36,8 @@ from ...ops.composite.multitype_ops import _constexpr_utils as const_utils
|
||||||
from ...ops.composite.multitype_ops import _compile_utils as compile_utils
|
from ...ops.composite.multitype_ops import _compile_utils as compile_utils
|
||||||
from ...ops.operations.math_ops import Median
|
from ...ops.operations.math_ops import Median
|
||||||
from ...ops.operations._inner_ops import Format
|
from ...ops.operations._inner_ops import Format
|
||||||
from ...ops.operations import _map_tensor_ops
|
|
||||||
from ...ops.operations import _csr_ops
|
from ...ops.operations import _csr_ops
|
||||||
|
from ...ops.operations import _map_tensor_ops
|
||||||
from ...ops.primitive import constexpr
|
from ...ops.primitive import constexpr
|
||||||
from ...common import dtype as mstype
|
from ...common import dtype as mstype
|
||||||
|
|
||||||
|
@ -3178,9 +3178,11 @@ def csr_mv(x, dense_vector):
|
||||||
return F.csr_mv(x, dense_vector)
|
return F.csr_mv(x, dense_vector)
|
||||||
|
|
||||||
|
|
||||||
def csr_mm(x, dense):
|
def csr_mm(x, matrix):
|
||||||
"""Implementation of `mm` for CSRTensor."""
|
"""Implementation of `mm` for CSRTensor."""
|
||||||
return _csr_mm(x.indptr, x.indices, x.values, x.shape, dense)
|
if isinstance(matrix, CSRTensor):
|
||||||
|
return F.csr_mm(x, matrix)
|
||||||
|
return _csr_mm(x.indptr, x.indices, x.values, x.shape, matrix)
|
||||||
|
|
||||||
|
|
||||||
def csr_to_tuple(x):
|
def csr_to_tuple(x):
|
||||||
|
|
|
@ -775,23 +775,25 @@ class CSRTensor(CSRTensor_):
|
||||||
validator.check_value_type('dense_vector', dense_vector, (Tensor_,), 'CSRTensor.mv')
|
validator.check_value_type('dense_vector', dense_vector, (Tensor_,), 'CSRTensor.mv')
|
||||||
return tensor_operator_registry.get("csr_mv")(self, dense_vector)
|
return tensor_operator_registry.get("csr_mv")(self, dense_vector)
|
||||||
|
|
||||||
def mm(self, dense_matrix: Tensor) -> Tensor:
|
def mm(self, matrix: Union[Tensor, CSRTensor]) -> Union[Tensor, CSRTensor]:
|
||||||
"""
|
"""
|
||||||
Return the matrix multiplication result of the right-multiply dense matrix of the CSRTensor.
|
Return the matrix multiplication result of the right-multiply matrix(dense or CSRTensor) of the CSRTensor.
|
||||||
The CSRTensor with shape `[M, N]` needs to adapt the dense matrix with shape `[N, K]`
|
The CSRTensor with shape `[M, N]` needs to adapt the right matrix with shape `[N, K]`
|
||||||
to get the dense matrix with result `[M, K]`.
|
to get the dense matrix or CSRTensor with result `[M, K]`.
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
Currently only supports CPU backend with LLVM 12.0.1 installed.
|
If right matrix is CSRTensor, currently only supports GPU backend.
|
||||||
|
if right matrix is Tensor, currently supports CPU backend with LLVM 12.0.1 or GPU backend.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dense_matrix (Tensor): A dense Tensor, its shape[0] should be equal to csr_tensor.shape[1]
|
matrix (Tensor or CSRTensor): A dense Tensor or CSRTensor,
|
||||||
|
its shape[0] should be equal to csr_tensor.shape[1]
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor.
|
Tensor or CSRTensor.
|
||||||
|
|
||||||
Supported Platforms:
|
Supported Platforms:
|
||||||
``GPU`` ``CPU``
|
``CPU`` ``GPU``
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> from mindspore import Tensor, CSRTensor
|
>>> from mindspore import Tensor, CSRTensor
|
||||||
|
@ -806,9 +808,11 @@ class CSRTensor(CSRTensor_):
|
||||||
[[2. 4.]
|
[[2. 4.]
|
||||||
[1. 2.]]
|
[1. 2.]]
|
||||||
"""
|
"""
|
||||||
validator.check_value_type('dense_matrix', dense_matrix, (Tensor_,), 'CSRTensor.mm')
|
if isinstance(matrix, CSRTensor):
|
||||||
return tensor_operator_registry.get("csr_mm")()(self.indptr, self.indices, self.values,
|
return tensor_operator_registry.get("csr_mm")(self, matrix)
|
||||||
self.shape, dense_matrix)
|
validator.check_value_type('matrix', matrix, (Tensor_,), 'CSRTensor.mm')
|
||||||
|
return tensor_operator_registry.get("csr_mm_akg")()(self.indptr, self.indices, self.values,
|
||||||
|
self.shape, matrix)
|
||||||
|
|
||||||
def sum(self, axis: int) -> Tensor:
|
def sum(self, axis: int) -> Tensor:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -376,6 +376,7 @@ from .sparse_func import (
|
||||||
csr_gather,
|
csr_gather,
|
||||||
csr_mul,
|
csr_mul,
|
||||||
csr_mv,
|
csr_mv,
|
||||||
|
csr_mm,
|
||||||
csr_reduce_sum,
|
csr_reduce_sum,
|
||||||
csr_to_coo,
|
csr_to_coo,
|
||||||
csr2coo,
|
csr2coo,
|
||||||
|
|
|
@ -24,6 +24,7 @@ from mindspore.ops.operations.sparse_ops import (
|
||||||
SparseAdd,
|
SparseAdd,
|
||||||
SparseMatrixAdd,
|
SparseMatrixAdd,
|
||||||
SparseMatrixSoftmax,
|
SparseMatrixSoftmax,
|
||||||
|
SparseMatrixSparseMatMul,
|
||||||
CSRSparseMatrixToDense
|
CSRSparseMatrixToDense
|
||||||
)
|
)
|
||||||
from mindspore.common import dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
|
@ -227,6 +228,84 @@ def csr_mv(csr_tensor: CSRTensor, dense: Tensor) -> Tensor:
|
||||||
return _csr_ops.CSRMV()(csr_tensor.indptr, csr_tensor.indices, csr_tensor_values, csr_tensor.shape, dense)
|
return _csr_ops.CSRMV()(csr_tensor.indptr, csr_tensor.indices, csr_tensor_values, csr_tensor.shape, dense)
|
||||||
|
|
||||||
|
|
||||||
|
def csr_mm(a: CSRTensor, b: CSRTensor, trans_a: bool = False, trans_b: bool = False,
|
||||||
|
adjoint_a: bool = False, adjoint_b: bool = False):
|
||||||
|
"""
|
||||||
|
Return the matrix multiplication result of the right-multiply matrix(dense or CSRTensor) of the CSRTensor.
|
||||||
|
The CSRTensor with shape `[M, N]` needs to adapt the right matrix with shape `[N, K]`
|
||||||
|
to get the dense matrix or CSRTensor with result `[M, K]`.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Currently supports GPU backend with right matrix is CSRTensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (CSRTensor): Sparse CSR Tensor, rank should be 2.
|
||||||
|
b (CSRTensor): Sparse CSR Tensor, rank should be 2.
|
||||||
|
trans_a (bool): whether to transpose CSRTensor a.
|
||||||
|
trans_a (bool): whether to transpose CSRTensor b.
|
||||||
|
adjoint_a (bool): whether to adjoint CSRTensor a.
|
||||||
|
adjoint_b (bool): whether to adjoint CSRTensor b.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CSRTensor.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``GPU``
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> from mindspore import Tensor, CSRTensor
|
||||||
|
>>> from mindspore import dtype as mstype
|
||||||
|
>>> a_shape = (4, 5)
|
||||||
|
>>> a_indptr = Tensor([0, 1, 1, 3, 4], dtype=mstype.int32)
|
||||||
|
>>> a_indices = Tensor([0, 3, 4, 0],dtype=mstype.int32)
|
||||||
|
>>> a_values = Tensor([1.0, 5.0, -1.0, -2.0], dtype=mstype.float32)
|
||||||
|
>>> b_shape = (5, 3)
|
||||||
|
>>> b_indptr = Tensor([0, 1, 1, 3, 3, 3], dtype=mstype.int32)
|
||||||
|
>>> b_indices = Tensor([0, 0, 1],dtype=mstype.int32)
|
||||||
|
>>> b_values = Tensor([2.0, 7.0, 8.0], dtype=mstype.float32)
|
||||||
|
>>> a = CSRTensor(a_indptr, a_indices, a_values, a_shape)
|
||||||
|
>>> b = CSRTensor(b_indptr, b_indices, b_values, b_shape)
|
||||||
|
>>> c = csr_mm(a, b)
|
||||||
|
>>> print(c.shape)
|
||||||
|
(4, 3)
|
||||||
|
>>> print(c.values)
|
||||||
|
[2. -4.]
|
||||||
|
>>> print(c.indptr)
|
||||||
|
[0 1 1 1 2]
|
||||||
|
>>> print(c.indices)
|
||||||
|
[0 0]
|
||||||
|
"""
|
||||||
|
if isinstance(a, CSRTensor) and isinstance(b, CSRTensor):
|
||||||
|
a_batch_pointers = make_tensor([0, a.values.shape[0]], a.indices.dtype)
|
||||||
|
b_batch_pointers = make_tensor([0, b.values.shape[0]], b.indices.dtype)
|
||||||
|
a_shape = make_tensor(a.shape, a.indices.dtype)
|
||||||
|
b_shape = make_tensor(b.shape, b.indices.dtype)
|
||||||
|
sparse_matrix_sparse_matmul = SparseMatrixSparseMatMul(transpose_a=trans_a,
|
||||||
|
transpose_b=trans_b,
|
||||||
|
adjoint_a=adjoint_a,
|
||||||
|
adjoint_b=adjoint_b)
|
||||||
|
|
||||||
|
_, _, c_indptr, c_indices, c_values = sparse_matrix_sparse_matmul(a_shape,
|
||||||
|
a_batch_pointers,
|
||||||
|
a.indptr,
|
||||||
|
a.indices,
|
||||||
|
a.values,
|
||||||
|
b_shape,
|
||||||
|
b_batch_pointers,
|
||||||
|
b.indptr,
|
||||||
|
b.indices,
|
||||||
|
b.values)
|
||||||
|
m, a2 = a.shape
|
||||||
|
b1, k = b.shape
|
||||||
|
if trans_a or adjoint_a:
|
||||||
|
m = a2
|
||||||
|
if trans_b or adjoint_b:
|
||||||
|
k = b1
|
||||||
|
return CSRTensor(c_indptr, c_indices, c_values, (m, k))
|
||||||
|
raise_type_error("For functional operator csr_mm, inputs a and b must be type of CSRTensor currently.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def csr_reduce_sum(csr_tensor: CSRTensor, axis: int) -> Tensor:
|
def csr_reduce_sum(csr_tensor: CSRTensor, axis: int) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Reduces a dimension of a CSRTensor by summing all elements in the dimension.
|
Reduces a dimension of a CSRTensor by summing all elements in the dimension.
|
||||||
|
|
|
@ -28,11 +28,12 @@ from mindspore._checkparam import Rel
|
||||||
from mindspore.ops import _constants
|
from mindspore.ops import _constants
|
||||||
from mindspore.ops.function import *
|
from mindspore.ops.function import *
|
||||||
from mindspore.ops.function.sparse_func import sparse_add
|
from mindspore.ops.function.sparse_func import sparse_add
|
||||||
|
from mindspore.ops.function.sparse_func import csr_mm
|
||||||
from mindspore.ops.primitive import constexpr, Primitive
|
from mindspore.ops.primitive import constexpr, Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.ops.operations import _grad_ops
|
from mindspore.ops.operations import _grad_ops
|
||||||
from mindspore.ops.operations import _csr_ops
|
|
||||||
from mindspore.ops.operations import _inner_ops
|
from mindspore.ops.operations import _inner_ops
|
||||||
|
from mindspore.ops.operations import _csr_ops
|
||||||
from mindspore.ops.operations import linalg_ops
|
from mindspore.ops.operations import linalg_ops
|
||||||
from mindspore.ops.operations.math_ops import Median
|
from mindspore.ops.operations.math_ops import Median
|
||||||
from mindspore.ops.operations.array_ops import UniqueConsecutive
|
from mindspore.ops.operations.array_ops import UniqueConsecutive
|
||||||
|
@ -472,7 +473,8 @@ tensor_operator_registry.register('csr2coo', csr2coo)
|
||||||
tensor_operator_registry.register('coo2csr', coo2csr)
|
tensor_operator_registry.register('coo2csr', coo2csr)
|
||||||
tensor_operator_registry.register('csr_div', csr_div)
|
tensor_operator_registry.register('csr_div', csr_div)
|
||||||
tensor_operator_registry.register('csr_mv', csr_mv)
|
tensor_operator_registry.register('csr_mv', csr_mv)
|
||||||
tensor_operator_registry.register('csr_mm', _csr_ops.CSRMM)
|
tensor_operator_registry.register('csr_mm_akg', _csr_ops.CSRMM)
|
||||||
|
tensor_operator_registry.register('csr_mm', csr_mm)
|
||||||
tensor_operator_registry.register('csr_reduce_sum', csr_reduce_sum)
|
tensor_operator_registry.register('csr_reduce_sum', csr_reduce_sum)
|
||||||
tensor_operator_registry.register('dense_to_sparse_csr', dense_to_sparse_csr)
|
tensor_operator_registry.register('dense_to_sparse_csr', dense_to_sparse_csr)
|
||||||
tensor_operator_registry.register('dense_to_sparse_coo', dense_to_sparse_coo)
|
tensor_operator_registry.register('dense_to_sparse_coo', dense_to_sparse_coo)
|
||||||
|
|
Loading…
Reference in New Issue