!31030 fix cootensor/csrtensor docs

Merge pull request !31030 from wangrao124/fix_docs
This commit is contained in:
i-robot 2022-03-10 11:15:02 +00:00 committed by Gitee
commit 7182149955
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 17 additions and 9 deletions

View File

@ -1,12 +1,12 @@
mindspore.COOTensor
===================
.. py:class:: mindspore.COOTensor(indices=None, values=None, shape=None)
.. py:class:: mindspore.COOTensor(indices=None, values=None, shape=None, coo_tensor=None)
用来表示某一张量在给定索引上非零元素的集合,其中索引(indices)指示了每一个非零元素的位置。
.. note::
- 这是一个实验特性在未来可能会发生API的变化。
这是一个实验特性在未来可能会发生API的变化。
**参数:**
@ -15,7 +15,7 @@ mindspore.COOTensor
- **shape** (tuple(int)) - 形状为ndims的整数元组用来指定稀疏矩阵的稠密形状。
- **coo_tensor** (COOTensor) - COOTensor对象用来初始化新的COOTensor。
**输出**
**返回**
COOTensor`indices``values``shape` 组成。

View File

@ -1,13 +1,13 @@
mindspore.CSRTensor
===================
.. py:class:: mindspore.CSRTensor(indptr=None, indices=None, values=None, shape=None)
.. py:class:: mindspore.CSRTensor(indptr=None, indices=None, values=None, shape=None, csr_tensor=None)
用来表示某一张量在给定索引上非零元素的集合,其中行索引由 `indptr` 表示,列索引由 `indices`
表示,非零值由 `values` 表示。
.. note::
- 这是一个实验特性在未来可能会发生API的变化。
这是一个实验特性在未来可能会发生API的变化。
**参数:**

View File

@ -37,7 +37,7 @@ ValuePtr DTypeInferValue(const PrimitivePtr &primitive, const std::vector<Abstra
if (type->isa<TensorType>()) {
const std::set<TypePtr> valid_types = {kTensorType};
return CheckAndConvertUtils::CheckTensorTypeValid("input_x", type, valid_types, op_name);
} else {
} else if (type->isa<CSRTensorType>() || type->isa<COOTensorType>()) {
const std::set<TypePtr> valid_types = {kCSRTensorType, kCOOTensorType};
return CheckAndConvertUtils::CheckSparseTensorTypeValid("input_x", type, valid_types, op_name);
}

View File

@ -2563,7 +2563,7 @@ class COOTensor(COOTensor_):
Return a copy of the COOTensor, cast its values to a specified type.
Args:
dtype (class:`mindspore.dtype`): Designated tensor dtype.
dtype (:class:`mindspore.dtype`): Designated tensor dtype.
Returns:
COOTensor.
@ -2719,7 +2719,15 @@ class CSRTensor(CSRTensor_):
return COOTensor(coo_indices, self.values, self.shape)
def to_dense(self):
"""Return a dense Tensor."""
"""
Converts CSRTensor to Dense Tensor.
Returns:
Tensor.
Supported Platforms:
``GPU`` ``CPU``
"""
coo_tensor = self.to_coo()
return coo_tensor.to_dense()
@ -2728,7 +2736,7 @@ class CSRTensor(CSRTensor_):
Return a copy of the CSRTensor, cast its values to a specified type.
Args:
dtype (class:`mindspore.dtype`): Designated tensor dtype.
dtype (:class:`mindspore.dtype`): Designated tensor dtype.
Returns:
CSRTensor.