!40319 Fix description bugs for SparseAddmm

Merge pull request !40319 from 孟权令/SparseAddmm
This commit is contained in:
i-robot 2022-08-15 01:51:54 +00:00 committed by Gitee
commit 5fab03f5ff
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 19 additions and 16 deletions

View File

@ -810,29 +810,32 @@ class Sspaddmm(Primitive):
class SparseAddmm(Primitive):
"""
Multiplies sparse matrix `A` by dense matrix `B` * `alpha` and add dense matrix `C` * `beta`.
The rank of sparse matrix and dense matrix must equal to `2`.
Multiplies sparse matrix `x1` by dense matrix `x2` * `alpha` and add dense matrix `x3` * `beta`.
The rank of sparse matrix and dense matrix must equal to `2`. The sparse matrix `x1` is formulated by `x1_indices`,
`x1_values` and `x1_shape`.
Inputs:
- **indices** (Tensor) - A 2-D Tensor, represents the position of the element in the sparse tensor.
Support int32, int64, each element value should be a non-negative int number. The shape is :math:`(n, 2)`.
- **values** (Tensor) - A 1-D Tensor, represents the value corresponding to the position in the `indices`.
- **x1_indices** (Tensor) - A 2-D Tensor, represents the position of the element in the sparse tensor.
Support int32, int64, each element value should be a non-negative int number. The shape is :math:`(N, 2)`.
- **x1_values** (Tensor) - A 1-D Tensor, represents the value corresponding to the position in the `indices`.
Support float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64.
The shape should be :math:`(n,)`.
- **sparse_shape** (Tensor) - A positive int tuple which specifies the shape of sparse tensor.
Support int32, int64, should have 2 elements, represent sparse tensor shape is :math:`(N, C)`.
- **x2_dense** (Tensor) - A 2-D Tensor, the dtype is same as `values`.
- **x3_dense** (Tensor) - A 2-D Tensor, the dtype is same as `values`.
- **alpha** (Tensor) - A 1-D Tensor, the dtype is same as `values`.
- **beta** (Tensor) - A 1-D Tensor, the dtype is same as `values`.
The shape should be :math:`(N,)`.
- **x1_shape** (Tensor) - A positive int tuple which specifies the shape of sparse tensor.
Support int32, int64, should have 2 elements, represent sparse tensor shape is :math:`(Q, P)`.
- **x2** (Tensor) - A 2-D Dense Tensor, the dtype is same as `values`. The shape should be :math:`(P, M)`.
- **x3** (Tensor) - A 2-D Dense Tensor, the dtype is same as `values`. The shape should be :math:`(Q, M)`.
- **alpha** (Tensor) - A 1-D Tensor, the dtype is same as `values`. The shape should be :math:`(1,)`.
- **beta** (Tensor) - A 1-D Tensor, the dtype is same as `values`. The shape should be :math:`(1,)`.
Outputs:
Tensor, the dtype is the same as `values`.
Tensor, the dtype is the same as `x1_values`. The shape is the same as `x3`.
Raises:
TypeError: If dtype of `indices`, dtype of `values` and dtype of `dense` don't meet the parameter description.
ValueError: If `sparse_shape`, shape of `indices, shape of `values`, and shape of `dense` don't meet the
parameter description.
TypeError: If dtype of `x1_indices`, dtype of `x1_values` and dtype of `dense` don't meet the parameter
description.
ValueError: If shape of `x1_indices`, shape of `x1_values`, shape of `alpha`,
and shape of `beta` don't meet the parameter description.
RuntimeError: If `x1_shape`, shape of `x2`, shape of `x3` don't meet the parameter description.
Supported Platforms:
``Ascend`` ``CPU``