modify sparsematrixsoftmax api

This commit is contained in:
chenghaowang 2022-08-02 17:26:24 +08:00 committed by cheng-hao-wang
parent cf07fbc41b
commit ea16266ddb
5 changed files with 64 additions and 28 deletions

View File

@ -473,6 +473,7 @@ Array操作
mindspore.ops.dense_to_sparse_coo
mindspore.ops.dense_to_sparse_csr
mindspore.ops.csr_softmax
mindspore.ops.csr_to_coo
mindspore.ops.sparse_add

View File

@ -0,0 +1,19 @@
mindspore.ops.csr_softmax
=================================
.. py:function:: mindspore.ops.csr_softmax(logits, dtype)
计算 CSRTensorMatrix 的 softmax 。
参数:
- **logits** CSRTensor - 输入稀疏的 CSRTensor。
- **dtype** dtype - 输入的数据类型。
返回:
- **CSRTensor** CSRTensor - 一个 csr_tensor 包含
- **indptr** - 指示每行中非零值的起始点和结束点。
- **indices** - 输入中所有非零值的列位置。
- **values** - 稠密张量的非零值。
- **shape** - csrtensor 的形状.

View File

@ -476,6 +476,7 @@ Sparse Functions
mindspore.ops.dense_to_sparse_coo
mindspore.ops.dense_to_sparse_csr
mindspore.ops.csr_to_coo
mindspore.ops.csr_softmax
mindspore.ops.sparse_add
Gradient Clipping

View File

@ -518,11 +518,12 @@ def csr_softmax(logits, dtype):
dtype (dtype): Data type.
Returns:
CSRTensor. a csr_tensor containing:
indptr: indicates the start and end point for `values` in each row.
indices: the column positions of all non-zero values of the input.
values: the non-zero values of the dense tensor.
shape: the shape of the csr_tensor.
CSRTensor, a csr_tensor containing
- **indptr** - indicates the start and end point for `values` in each row.
- **indices** - the column positions of all non-zero values of the input.
- **values** - the non-zero values of the dense tensor.
- **shape** - the shape of the csr_tensor.
Supported Platforms:
``GPU`` ``CPU``
@ -531,19 +532,18 @@ def csr_softmax(logits, dtype):
>>> import mindspore as ms
>>> import mindspore.common.dtype as mstype
>>> from mindspore import Tensor, CSRTensor
>>> from mindspore.ops.functional import sparse_matrix_softmax
>>> from mindspore.ops.function import csr_softmax
>>> logits_indptr = Tensor([0, 4, 6], dtype=mstype.int32)
>>> logits_indices = Tensor([0, 2, 3, 4, 3, 4], dtype=mstype.int32)
>>> logits_values = Tensor([1, 2, 3, 4, 1, 2], dtype=mstype.float32)
>>> shape = (2, 6)
>>> logits = CSRTensor(logits_indptr, logits_indices, logits_values, shape)
>>> out = logits.sparse_matrix_softmax(dtype)
>>> out = csr_softmax(logits, dtype=mstype.float32)
>>> print(out)
CSRTensor(shape=[2,6], dtype=Float32,
indptr=Tensor(shape=[3], dtype=Int64, value = [0, 4, 6]),
indices=Tensor(shape=[2], dtype=Int64, value = [0, 2, 3, 4, 3, 4]),
values=Tensor(shape=[2], dtype=Float32,
value = [3.2058e-02, 8.7144e-02, 2.3688e-01, 6.4391e-01, 2.6894e-01, 7.310e-01]))
CSRTensor(shape=[2, 6], dtype=Float32, indptr=Tensor(shape=[3], dtype=Int32, value=[0 4 6]),
indices=Tensor(shape=[6], dtype=Int32, value=[0 2 3 4 3 4]),
values=Tensor(shape=[6], dtype=Float32, value=[ 3.20586003e-02 8.71443152e-02 2.36882806e-01
6.43914223e-01 2.68941432e-01 7.31058598e-01]))
"""
if not isinstance(logits, CSRTensor):
raise_type_error("For functional operator sparse_matrix_softmax, logits must be type of CSRTensor.")

View File

@ -715,16 +715,25 @@ class SparseMatrixSoftmax(Primitive):
"""
Calculates the softmax of a CSRTensorMatrix.
Args:
logits (CSRTensor): Sparse CSR Tensor.
dtype (dtype): Data type.
.. warning::
This is an experimental prototype that is subject to change and/or deletion.
Returns:
CSRTensor. a csr_tensor containing:
indptr: indicates the start and end point for `values` in each row.
indices: the column positions of all non-zero values of the input.
values: the non-zero values of the dense tensor.
shape: the shape of the csr_tensor.
Args:
dtype (dtype.Number) - The valid data type. Only constant value is allowed.
Inputs:
- **x_dense_shape** (Tensor) - Input shape of the original Dense matrix.
- **x_batch_pointers** (Tensor) - The number of rows in the input matrix.
- **x_row_pointers** (Tensor) - Input the column coordinates of nonzero elements.
- **x_col_indices** (Tensor) - The number of input nonzero elements up to that line.
- **x_values** (Tensor) - The value of the input nonzero element.
Outputs:
- **y_dense_shape** (Tensor) - Output shape of the original Dense matrix.
- **y_batch_pointers** (Tensor) - The number of rows in the output matrix.
- **y_row_pointers** (Tensor) - Output the column coordinates of nonzero elements.
- **y_col_indices** (Tensor) - The number of output nonzero elements up to that line.
- **y_values** (Tensor) - The value of the input nonzero element.
Supported Platforms:
``GPU`` ``CPU``
@ -733,21 +742,27 @@ class SparseMatrixSoftmax(Primitive):
>>> import mindspore as ms
>>> import mindspore.common.dtype as mstype
>>> from mindspore import Tensor, CSRTensor
>>> from mindspore.ops.functional import sparse_matrix_softmax
>>> from mindspore.ops.operations.sparse_ops import SparseMatrixSoftmax
>>> logits_indptr = Tensor([0, 4, 6], dtype=mstype.int32)
>>> logits_indices = Tensor([0, 2, 3, 4, 3, 4], dtype=mstype.int32)
>>> logits_values = Tensor([1, 2, 3, 4, 1, 2], dtype=mstype.float32)
>>> shape = (2, 6)
>>> logits = CSRTensor(logits_indptr, logits_indices, logits_values, shape)
>>> out = logits.sparse_matrix_softmax(dtype)
>>> net = SparseMatrixSoftmax(mstype.float32)
>>> logits_pointers =Tensor(logits.values.shape[0], mstype.int32)
>>> out = net(Tensor(logits.shape, dtype=mstype.int32), logits_pointers,
... logits.indptr, logits.indices, logits.values)
>>> print(out)
CSRTensor(shape=[2,6], dtype=Float32,
indptr=Tensor(shape=[3], dtype=Int64, value = [0, 4, 6]),
indices=Tensor(shape=[2], dtype=Int64, value = [0, 2, 3, 4, 3, 4]),
values=Tensor(shape=[2], dtype=Float32,
value = [3.2058e-02, 8.7144e-02, 2.3688e-01, 6.4391e-01, 2.6894e-01, 7.310e-01]))
(Tensor(shape=[2], dtype=Int32, value= [2, 6]),
Tensor(shape=[], dtype=Int32, value= 6),
Tensor(shape=[3], dtype=Int32, value= [0, 4, 6]),
Tensor(shape=[6], dtype=Int32, value= [0, 2, 3, 4, 3, 4]),
Tensor(shape=[6], dtype=Float32, value= [ 3.20586003e-02, 8.71443152e-02,
2.36882806e-01, 6.43914223e-01, 2.68941432e-01, 7.31058598e-01]))
"""
@prim_attr_register
def __init__(self, dtype):
'''Initialize for SparseMatrixSoftmax'''