modify sparsematrixsoftmax api
This commit is contained in:
parent
cf07fbc41b
commit
ea16266ddb
|
@ -473,6 +473,7 @@ Array操作
|
|||
|
||||
mindspore.ops.dense_to_sparse_coo
|
||||
mindspore.ops.dense_to_sparse_csr
|
||||
mindspore.ops.csr_softmax
|
||||
mindspore.ops.csr_to_coo
|
||||
mindspore.ops.sparse_add
|
||||
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
mindspore.ops.csr_softmax
|
||||
=================================
|
||||
|
||||
.. py:function:: mindspore.ops.csr_softmax(logits, dtype)
|
||||
|
||||
计算 CSRTensorMatrix 的 softmax 。
|
||||
|
||||
参数:
|
||||
- **logits** (CSRTensor) - 输入稀疏的 CSRTensor。
|
||||
- **dtype** (dtype) - 输入的数据类型。
|
||||
|
||||
返回:
|
||||
- **CSRTensor** (CSRTensor) - 一个 csr_tensor 包含
|
||||
|
||||
- **indptr** - 指示每行中非零值的起始点和结束点。
|
||||
- **indices** - 输入中所有非零值的列位置。
|
||||
- **values** - 稠密张量的非零值。
|
||||
- **shape** - csrtensor 的形状.
|
||||
|
|
@ -476,6 +476,7 @@ Sparse Functions
|
|||
mindspore.ops.dense_to_sparse_coo
|
||||
mindspore.ops.dense_to_sparse_csr
|
||||
mindspore.ops.csr_to_coo
|
||||
mindspore.ops.csr_softmax
|
||||
mindspore.ops.sparse_add
|
||||
|
||||
Gradient Clipping
|
||||
|
|
|
@ -518,11 +518,12 @@ def csr_softmax(logits, dtype):
|
|||
dtype (dtype): Data type.
|
||||
|
||||
Returns:
|
||||
CSRTensor. a csr_tensor containing:
|
||||
indptr: indicates the start and end point for `values` in each row.
|
||||
indices: the column positions of all non-zero values of the input.
|
||||
values: the non-zero values of the dense tensor.
|
||||
shape: the shape of the csr_tensor.
|
||||
CSRTensor, a csr_tensor containing
|
||||
|
||||
- **indptr** - indicates the start and end point for `values` in each row.
|
||||
- **indices** - the column positions of all non-zero values of the input.
|
||||
- **values** - the non-zero values of the dense tensor.
|
||||
- **shape** - the shape of the csr_tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU`` ``CPU``
|
||||
|
@ -531,19 +532,18 @@ def csr_softmax(logits, dtype):
|
|||
>>> import mindspore as ms
|
||||
>>> import mindspore.common.dtype as mstype
|
||||
>>> from mindspore import Tensor, CSRTensor
|
||||
>>> from mindspore.ops.functional import sparse_matrix_softmax
|
||||
>>> from mindspore.ops.function import csr_softmax
|
||||
>>> logits_indptr = Tensor([0, 4, 6], dtype=mstype.int32)
|
||||
>>> logits_indices = Tensor([0, 2, 3, 4, 3, 4], dtype=mstype.int32)
|
||||
>>> logits_values = Tensor([1, 2, 3, 4, 1, 2], dtype=mstype.float32)
|
||||
>>> shape = (2, 6)
|
||||
>>> logits = CSRTensor(logits_indptr, logits_indices, logits_values, shape)
|
||||
>>> out = logits.sparse_matrix_softmax(dtype)
|
||||
>>> out = csr_softmax(logits, dtype=mstype.float32)
|
||||
>>> print(out)
|
||||
CSRTensor(shape=[2,6], dtype=Float32,
|
||||
indptr=Tensor(shape=[3], dtype=Int64, value = [0, 4, 6]),
|
||||
indices=Tensor(shape=[2], dtype=Int64, value = [0, 2, 3, 4, 3, 4]),
|
||||
values=Tensor(shape=[2], dtype=Float32,
|
||||
value = [3.2058e-02, 8.7144e-02, 2.3688e-01, 6.4391e-01, 2.6894e-01, 7.310e-01]))
|
||||
CSRTensor(shape=[2, 6], dtype=Float32, indptr=Tensor(shape=[3], dtype=Int32, value=[0 4 6]),
|
||||
indices=Tensor(shape=[6], dtype=Int32, value=[0 2 3 4 3 4]),
|
||||
values=Tensor(shape=[6], dtype=Float32, value=[ 3.20586003e-02 8.71443152e-02 2.36882806e-01
|
||||
6.43914223e-01 2.68941432e-01 7.31058598e-01]))
|
||||
"""
|
||||
if not isinstance(logits, CSRTensor):
|
||||
raise_type_error("For functional operator sparse_matrix_softmax, logits must be type of CSRTensor.")
|
||||
|
|
|
@ -715,16 +715,25 @@ class SparseMatrixSoftmax(Primitive):
|
|||
"""
|
||||
Calculates the softmax of a CSRTensorMatrix.
|
||||
|
||||
Args:
|
||||
logits (CSRTensor): Sparse CSR Tensor.
|
||||
dtype (dtype): Data type.
|
||||
.. warning::
|
||||
This is an experimental prototype that is subject to change and/or deletion.
|
||||
|
||||
Returns:
|
||||
CSRTensor. a csr_tensor containing:
|
||||
indptr: indicates the start and end point for `values` in each row.
|
||||
indices: the column positions of all non-zero values of the input.
|
||||
values: the non-zero values of the dense tensor.
|
||||
shape: the shape of the csr_tensor.
|
||||
Args:
|
||||
dtype (dtype.Number) - The valid data type. Only constant value is allowed.
|
||||
|
||||
Inputs:
|
||||
- **x_dense_shape** (Tensor) - Input shape of the original Dense matrix.
|
||||
- **x_batch_pointers** (Tensor) - The number of rows in the input matrix.
|
||||
- **x_row_pointers** (Tensor) - Input the column coordinates of nonzero elements.
|
||||
- **x_col_indices** (Tensor) - The number of input nonzero elements up to that line.
|
||||
- **x_values** (Tensor) - The value of the input nonzero element.
|
||||
|
||||
Outputs:
|
||||
- **y_dense_shape** (Tensor) - Output shape of the original Dense matrix.
|
||||
- **y_batch_pointers** (Tensor) - The number of rows in the output matrix.
|
||||
- **y_row_pointers** (Tensor) - Output the column coordinates of nonzero elements.
|
||||
- **y_col_indices** (Tensor) - The number of output nonzero elements up to that line.
|
||||
- **y_values** (Tensor) - The value of the input nonzero element.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU`` ``CPU``
|
||||
|
@ -733,21 +742,27 @@ class SparseMatrixSoftmax(Primitive):
|
|||
>>> import mindspore as ms
|
||||
>>> import mindspore.common.dtype as mstype
|
||||
>>> from mindspore import Tensor, CSRTensor
|
||||
>>> from mindspore.ops.functional import sparse_matrix_softmax
|
||||
>>> from mindspore.ops.operations.sparse_ops import SparseMatrixSoftmax
|
||||
>>> logits_indptr = Tensor([0, 4, 6], dtype=mstype.int32)
|
||||
>>> logits_indices = Tensor([0, 2, 3, 4, 3, 4], dtype=mstype.int32)
|
||||
>>> logits_values = Tensor([1, 2, 3, 4, 1, 2], dtype=mstype.float32)
|
||||
>>> shape = (2, 6)
|
||||
>>> logits = CSRTensor(logits_indptr, logits_indices, logits_values, shape)
|
||||
>>> out = logits.sparse_matrix_softmax(dtype)
|
||||
>>> net = SparseMatrixSoftmax(mstype.float32)
|
||||
>>> logits_pointers =Tensor(logits.values.shape[0], mstype.int32)
|
||||
>>> out = net(Tensor(logits.shape, dtype=mstype.int32), logits_pointers,
|
||||
... logits.indptr, logits.indices, logits.values)
|
||||
>>> print(out)
|
||||
CSRTensor(shape=[2,6], dtype=Float32,
|
||||
indptr=Tensor(shape=[3], dtype=Int64, value = [0, 4, 6]),
|
||||
indices=Tensor(shape=[2], dtype=Int64, value = [0, 2, 3, 4, 3, 4]),
|
||||
values=Tensor(shape=[2], dtype=Float32,
|
||||
value = [3.2058e-02, 8.7144e-02, 2.3688e-01, 6.4391e-01, 2.6894e-01, 7.310e-01]))
|
||||
(Tensor(shape=[2], dtype=Int32, value= [2, 6]),
|
||||
Tensor(shape=[], dtype=Int32, value= 6),
|
||||
Tensor(shape=[3], dtype=Int32, value= [0, 4, 6]),
|
||||
Tensor(shape=[6], dtype=Int32, value= [0, 2, 3, 4, 3, 4]),
|
||||
Tensor(shape=[6], dtype=Float32, value= [ 3.20586003e-02, 8.71443152e-02,
|
||||
2.36882806e-01, 6.43914223e-01, 2.68941432e-01, 7.31058598e-01]))
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, dtype):
|
||||
'''Initialize for SparseMatrixSoftmax'''
|
||||
|
|
Loading…
Reference in New Issue