diff --git a/docs/api/api_python/mindspore.ops.function.rst b/docs/api/api_python/mindspore.ops.function.rst index fa9b4e18a77..1ac3fb09972 100644 --- a/docs/api/api_python/mindspore.ops.function.rst +++ b/docs/api/api_python/mindspore.ops.function.rst @@ -418,6 +418,7 @@ Array操作 mindspore.ops.dense_to_sparse_coo mindspore.ops.dense_to_sparse_csr mindspore.ops.csr_add + mindspore.ops.csr_mm mindspore.ops.csr_softmax mindspore.ops.csr_to_coo mindspore.ops.sparse_add diff --git a/docs/api/api_python/mindspore/mindspore.CSRTensor.rst b/docs/api/api_python/mindspore/mindspore.CSRTensor.rst index 0062cad914a..9e80f536375 100644 --- a/docs/api/api_python/mindspore/mindspore.CSRTensor.rst +++ b/docs/api/api_python/mindspore/mindspore.CSRTensor.rst @@ -78,19 +78,19 @@ mindspore.CSRTensor 返回每个非零元素所占字节数。 - .. py:method:: mm(dense_matrix: Tensor) + .. py:method:: mm(matrix: Union[Tensor, CSRTensor]) - 返回CSRTensor右乘稠密矩阵的矩阵乘法运算结果。 - shape为 `[M, N]` 的CSRTensor,需要适配shape为 `[N, K]` 的稠密矩阵,得到结果为 `[M, K]` 的稠密矩阵。 + 返回CSRTensor右乘稀疏矩阵或稠密矩阵的矩阵乘法运算结果。 + shape为 `[M, N]` 的CSRTensor,需要适配shape为 `[N, K]` 的稠密矩阵或稀疏矩阵,得到结果为 `[M, K]` 的稠密矩阵或稀疏矩阵。 .. note:: 如果运行后端是CPU,那么仅支持在安装了LLVM12.0.1的机器运行。 参数: - - **dense_matrix** (Tensor) - shape为 `[N,K]` 的二维矩阵,其中N等于CSRTensor的列数。 + - **matrix** (Tensor or CSRTensor) - shape为 `[N,K]` 的二维矩阵,其中N等于CSRTensor的列数。 返回: - Tensor。 + Tensor or CSRTensor。 .. py:method:: mv(dense_vector: Tensor) diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_mm.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_mm.rst new file mode 100644 index 00000000000..50ce3a20936 --- /dev/null +++ b/docs/api/api_python/ops/mindspore.ops.func_csr_mm.rst @@ -0,0 +1,21 @@ +mindspore.ops.csr_mm +================================= + +.. py:function:: mindspore.ops.csr_mm(a: CSRTensor, b: CSRTensor or Tensor, trans_a: bool, trans_b: bool, adjoint_a: bool, adjoint_b: bool) + + 返回稀疏矩阵a与稀疏矩阵或稠密矩阵b的矩阵乘法结果。 + + .. note:: + 若右矩阵为Tensor,则仅支持安装了LLVM12.0.1的CPU后端或GPU后端。 + 若右矩阵为CSRTensor, 则仅支持GPU后端。 + + 参数: + - **a** (CSRTensor) - 稀疏的 CSRTensor。 + - **b** (CSRTensor 或 Tensor) - 稀疏的 CSRTensor或稠密矩阵。 + - **trans_a** (Tensor) - 是否对矩阵a进行转置。 + - **trans_b** (Tensor) - 是否对矩阵b进行转置。 + - **adjoint_a** (Tensor) - 是否对矩阵a进行共轭。 + - **adjoint_b** (Tensor) - 是否对矩阵b进行共轭。 + + 返回: + 返回稀疏矩阵,类型为CSRTensor。 \ No newline at end of file diff --git a/docs/api/api_python_en/mindspore.ops.function.rst b/docs/api/api_python_en/mindspore.ops.function.rst index 56285fe4616..69bc4bc2e30 100644 --- a/docs/api/api_python_en/mindspore.ops.function.rst +++ b/docs/api/api_python_en/mindspore.ops.function.rst @@ -418,6 +418,7 @@ Sparse Functions mindspore.ops.dense_to_sparse_coo mindspore.ops.dense_to_sparse_csr mindspore.ops.csr_add + mindspore.ops.csr_mm mindspore.ops.csr_softmax mindspore.ops.csr_to_coo mindspore.ops.sparse_add diff --git a/mindspore/python/mindspore/_extends/parse/standard_method.py b/mindspore/python/mindspore/_extends/parse/standard_method.py index 8c3b41675e8..85111dd8bda 100644 --- a/mindspore/python/mindspore/_extends/parse/standard_method.py +++ b/mindspore/python/mindspore/_extends/parse/standard_method.py @@ -36,8 +36,8 @@ from ...ops.composite.multitype_ops import _constexpr_utils as const_utils from ...ops.composite.multitype_ops import _compile_utils as compile_utils from ...ops.operations.math_ops import Median from ...ops.operations._inner_ops import Format -from ...ops.operations import _map_tensor_ops from ...ops.operations import _csr_ops +from ...ops.operations import _map_tensor_ops from ...ops.primitive import constexpr from ...common import dtype as mstype @@ -3178,9 +3178,11 @@ def csr_mv(x, dense_vector): return F.csr_mv(x, dense_vector) -def csr_mm(x, dense): +def csr_mm(x, matrix): """Implementation of `mm` for CSRTensor.""" - return _csr_mm(x.indptr, x.indices, x.values, x.shape, dense) + if isinstance(matrix, CSRTensor): + return F.csr_mm(x, matrix) + return _csr_mm(x.indptr, x.indices, x.values, x.shape, matrix) def csr_to_tuple(x): diff --git a/mindspore/python/mindspore/common/sparse_tensor.py b/mindspore/python/mindspore/common/sparse_tensor.py index 4812e88815b..a9b580edd2d 100644 --- a/mindspore/python/mindspore/common/sparse_tensor.py +++ b/mindspore/python/mindspore/common/sparse_tensor.py @@ -775,23 +775,25 @@ class CSRTensor(CSRTensor_): validator.check_value_type('dense_vector', dense_vector, (Tensor_,), 'CSRTensor.mv') return tensor_operator_registry.get("csr_mv")(self, dense_vector) - def mm(self, dense_matrix: Tensor) -> Tensor: + def mm(self, matrix: Union[Tensor, CSRTensor]) -> Union[Tensor, CSRTensor]: """ - Return the matrix multiplication result of the right-multiply dense matrix of the CSRTensor. - The CSRTensor with shape `[M, N]` needs to adapt the dense matrix with shape `[N, K]` - to get the dense matrix with result `[M, K]`. + Return the matrix multiplication result of the right-multiply matrix(dense or CSRTensor) of the CSRTensor. + The CSRTensor with shape `[M, N]` needs to adapt the right matrix with shape `[N, K]` + to get the dense matrix or CSRTensor with result `[M, K]`. Note: - Currently only supports CPU backend with LLVM 12.0.1 installed. + If right matrix is CSRTensor, currently only supports GPU backend. + if right matrix is Tensor, currently supports CPU backend with LLVM 12.0.1 or GPU backend. Args: - dense_matrix (Tensor): A dense Tensor, its shape[0] should be equal to csr_tensor.shape[1] + matrix (Tensor or CSRTensor): A dense Tensor or CSRTensor, + its shape[0] should be equal to csr_tensor.shape[1] Returns: - Tensor. + Tensor or CSRTensor. Supported Platforms: - ``GPU`` ``CPU`` + ``CPU`` ``GPU`` Examples: >>> from mindspore import Tensor, CSRTensor @@ -806,9 +808,11 @@ class CSRTensor(CSRTensor_): [[2. 4.] [1. 2.]] """ - validator.check_value_type('dense_matrix', dense_matrix, (Tensor_,), 'CSRTensor.mm') - return tensor_operator_registry.get("csr_mm")()(self.indptr, self.indices, self.values, - self.shape, dense_matrix) + if isinstance(matrix, CSRTensor): + return tensor_operator_registry.get("csr_mm")(self, matrix) + validator.check_value_type('matrix', matrix, (Tensor_,), 'CSRTensor.mm') + return tensor_operator_registry.get("csr_mm_akg")()(self.indptr, self.indices, self.values, + self.shape, matrix) def sum(self, axis: int) -> Tensor: """ diff --git a/mindspore/python/mindspore/ops/function/__init__.py b/mindspore/python/mindspore/ops/function/__init__.py index b6d2f6ea9d5..cce4e6cbb4d 100644 --- a/mindspore/python/mindspore/ops/function/__init__.py +++ b/mindspore/python/mindspore/ops/function/__init__.py @@ -376,6 +376,7 @@ from .sparse_func import ( csr_gather, csr_mul, csr_mv, + csr_mm, csr_reduce_sum, csr_to_coo, csr2coo, diff --git a/mindspore/python/mindspore/ops/function/sparse_func.py b/mindspore/python/mindspore/ops/function/sparse_func.py index 5803e21ea00..0fa3df65ba4 100644 --- a/mindspore/python/mindspore/ops/function/sparse_func.py +++ b/mindspore/python/mindspore/ops/function/sparse_func.py @@ -24,6 +24,7 @@ from mindspore.ops.operations.sparse_ops import ( SparseAdd, SparseMatrixAdd, SparseMatrixSoftmax, + SparseMatrixSparseMatMul, CSRSparseMatrixToDense ) from mindspore.common import dtype as mstype @@ -227,6 +228,84 @@ def csr_mv(csr_tensor: CSRTensor, dense: Tensor) -> Tensor: return _csr_ops.CSRMV()(csr_tensor.indptr, csr_tensor.indices, csr_tensor_values, csr_tensor.shape, dense) +def csr_mm(a: CSRTensor, b: CSRTensor, trans_a: bool = False, trans_b: bool = False, + adjoint_a: bool = False, adjoint_b: bool = False): + """ + Return the matrix multiplication result of the right-multiply matrix(dense or CSRTensor) of the CSRTensor. + The CSRTensor with shape `[M, N]` needs to adapt the right matrix with shape `[N, K]` + to get the dense matrix or CSRTensor with result `[M, K]`. + + Note: + Currently supports GPU backend with right matrix is CSRTensor. + + Args: + a (CSRTensor): Sparse CSR Tensor, rank should be 2. + b (CSRTensor): Sparse CSR Tensor, rank should be 2. + trans_a (bool): whether to transpose CSRTensor a. + trans_a (bool): whether to transpose CSRTensor b. + adjoint_a (bool): whether to adjoint CSRTensor a. + adjoint_b (bool): whether to adjoint CSRTensor b. + + Returns: + CSRTensor. + + Supported Platforms: + ``GPU`` + + Examples: + >>> from mindspore import Tensor, CSRTensor + >>> from mindspore import dtype as mstype + >>> a_shape = (4, 5) + >>> a_indptr = Tensor([0, 1, 1, 3, 4], dtype=mstype.int32) + >>> a_indices = Tensor([0, 3, 4, 0],dtype=mstype.int32) + >>> a_values = Tensor([1.0, 5.0, -1.0, -2.0], dtype=mstype.float32) + >>> b_shape = (5, 3) + >>> b_indptr = Tensor([0, 1, 1, 3, 3, 3], dtype=mstype.int32) + >>> b_indices = Tensor([0, 0, 1],dtype=mstype.int32) + >>> b_values = Tensor([2.0, 7.0, 8.0], dtype=mstype.float32) + >>> a = CSRTensor(a_indptr, a_indices, a_values, a_shape) + >>> b = CSRTensor(b_indptr, b_indices, b_values, b_shape) + >>> c = csr_mm(a, b) + >>> print(c.shape) + (4, 3) + >>> print(c.values) + [2. -4.] + >>> print(c.indptr) + [0 1 1 1 2] + >>> print(c.indices) + [0 0] + """ + if isinstance(a, CSRTensor) and isinstance(b, CSRTensor): + a_batch_pointers = make_tensor([0, a.values.shape[0]], a.indices.dtype) + b_batch_pointers = make_tensor([0, b.values.shape[0]], b.indices.dtype) + a_shape = make_tensor(a.shape, a.indices.dtype) + b_shape = make_tensor(b.shape, b.indices.dtype) + sparse_matrix_sparse_matmul = SparseMatrixSparseMatMul(transpose_a=trans_a, + transpose_b=trans_b, + adjoint_a=adjoint_a, + adjoint_b=adjoint_b) + + _, _, c_indptr, c_indices, c_values = sparse_matrix_sparse_matmul(a_shape, + a_batch_pointers, + a.indptr, + a.indices, + a.values, + b_shape, + b_batch_pointers, + b.indptr, + b.indices, + b.values) + m, a2 = a.shape + b1, k = b.shape + if trans_a or adjoint_a: + m = a2 + if trans_b or adjoint_b: + k = b1 + return CSRTensor(c_indptr, c_indices, c_values, (m, k)) + raise_type_error("For functional operator csr_mm, inputs a and b must be type of CSRTensor currently.") + return None + + def csr_reduce_sum(csr_tensor: CSRTensor, axis: int) -> Tensor: """ Reduces a dimension of a CSRTensor by summing all elements in the dimension. diff --git a/mindspore/python/mindspore/ops/functional.py b/mindspore/python/mindspore/ops/functional.py index f134d7b9a17..bd445397692 100644 --- a/mindspore/python/mindspore/ops/functional.py +++ b/mindspore/python/mindspore/ops/functional.py @@ -28,11 +28,12 @@ from mindspore._checkparam import Rel from mindspore.ops import _constants from mindspore.ops.function import * from mindspore.ops.function.sparse_func import sparse_add +from mindspore.ops.function.sparse_func import csr_mm from mindspore.ops.primitive import constexpr, Primitive from mindspore.ops import operations as P from mindspore.ops.operations import _grad_ops -from mindspore.ops.operations import _csr_ops from mindspore.ops.operations import _inner_ops +from mindspore.ops.operations import _csr_ops from mindspore.ops.operations import linalg_ops from mindspore.ops.operations.math_ops import Median from mindspore.ops.operations.array_ops import UniqueConsecutive @@ -472,7 +473,8 @@ tensor_operator_registry.register('csr2coo', csr2coo) tensor_operator_registry.register('coo2csr', coo2csr) tensor_operator_registry.register('csr_div', csr_div) tensor_operator_registry.register('csr_mv', csr_mv) -tensor_operator_registry.register('csr_mm', _csr_ops.CSRMM) +tensor_operator_registry.register('csr_mm_akg', _csr_ops.CSRMM) +tensor_operator_registry.register('csr_mm', csr_mm) tensor_operator_registry.register('csr_reduce_sum', csr_reduce_sum) tensor_operator_registry.register('dense_to_sparse_csr', dense_to_sparse_csr) tensor_operator_registry.register('dense_to_sparse_coo', dense_to_sparse_coo)