!31030 fix cootensor/csrtensor docs
Merge pull request !31030 from wangrao124/fix_docs
This commit is contained in:
commit
7182149955
|
@ -1,12 +1,12 @@
|
|||
mindspore.COOTensor
|
||||
===================
|
||||
|
||||
.. py:class:: mindspore.COOTensor(indices=None, values=None, shape=None)
|
||||
.. py:class:: mindspore.COOTensor(indices=None, values=None, shape=None, coo_tensor=None)
|
||||
|
||||
用来表示某一张量在给定索引上非零元素的集合,其中索引(indices)指示了每一个非零元素的位置。
|
||||
|
||||
.. note::
|
||||
- 这是一个实验特性,在未来可能会发生API的变化。
|
||||
这是一个实验特性,在未来可能会发生API的变化。
|
||||
|
||||
**参数:**
|
||||
|
||||
|
@ -15,7 +15,7 @@ mindspore.COOTensor
|
|||
- **shape** (tuple(int)) - 形状为ndims的整数元组,用来指定稀疏矩阵的稠密形状。
|
||||
- **coo_tensor** (COOTensor) - COOTensor对象,用来初始化新的COOTensor。
|
||||
|
||||
**输出:**
|
||||
**返回:**
|
||||
|
||||
COOTensor,由 `indices` 、 `values` 和 `shape` 组成。
|
||||
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
mindspore.CSRTensor
|
||||
===================
|
||||
|
||||
.. py:class:: mindspore.CSRTensor(indptr=None, indices=None, values=None, shape=None)
|
||||
.. py:class:: mindspore.CSRTensor(indptr=None, indices=None, values=None, shape=None, csr_tensor=None)
|
||||
|
||||
用来表示某一张量在给定索引上非零元素的集合,其中行索引由 `indptr` 表示,列索引由 `indices`
|
||||
表示,非零值由 `values` 表示。
|
||||
|
||||
.. note::
|
||||
- 这是一个实验特性,在未来可能会发生API的变化。
|
||||
这是一个实验特性,在未来可能会发生API的变化。
|
||||
|
||||
**参数:**
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ ValuePtr DTypeInferValue(const PrimitivePtr &primitive, const std::vector<Abstra
|
|||
if (type->isa<TensorType>()) {
|
||||
const std::set<TypePtr> valid_types = {kTensorType};
|
||||
return CheckAndConvertUtils::CheckTensorTypeValid("input_x", type, valid_types, op_name);
|
||||
} else {
|
||||
} else if (type->isa<CSRTensorType>() || type->isa<COOTensorType>()) {
|
||||
const std::set<TypePtr> valid_types = {kCSRTensorType, kCOOTensorType};
|
||||
return CheckAndConvertUtils::CheckSparseTensorTypeValid("input_x", type, valid_types, op_name);
|
||||
}
|
||||
|
|
|
@ -2563,7 +2563,7 @@ class COOTensor(COOTensor_):
|
|||
Return a copy of the COOTensor, cast its values to a specified type.
|
||||
|
||||
Args:
|
||||
dtype (class:`mindspore.dtype`): Designated tensor dtype.
|
||||
dtype (:class:`mindspore.dtype`): Designated tensor dtype.
|
||||
|
||||
Returns:
|
||||
COOTensor.
|
||||
|
@ -2719,7 +2719,15 @@ class CSRTensor(CSRTensor_):
|
|||
return COOTensor(coo_indices, self.values, self.shape)
|
||||
|
||||
def to_dense(self):
|
||||
"""Return a dense Tensor."""
|
||||
"""
|
||||
Converts CSRTensor to Dense Tensor.
|
||||
|
||||
Returns:
|
||||
Tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU`` ``CPU``
|
||||
"""
|
||||
coo_tensor = self.to_coo()
|
||||
return coo_tensor.to_dense()
|
||||
|
||||
|
@ -2728,7 +2736,7 @@ class CSRTensor(CSRTensor_):
|
|||
Return a copy of the CSRTensor, cast its values to a specified type.
|
||||
|
||||
Args:
|
||||
dtype (class:`mindspore.dtype`): Designated tensor dtype.
|
||||
dtype (:class:`mindspore.dtype`): Designated tensor dtype.
|
||||
|
||||
Returns:
|
||||
CSRTensor.
|
||||
|
|
Loading…
Reference in New Issue