!31869 fix COOTensor/CSRTenosr checks and docs

Merge pull request !31869 from wangrao124/pr_coo_csr
This commit is contained in:
i-robot 2022-03-25 07:17:09 +00:00 committed by Gitee
commit d0dc332656
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 11 additions and 7 deletions

View File

@ -5,9 +5,9 @@ mindspore.COOTensor
用来表示某一张量在给定索引上非零元素的集合,其中索引(indices)指示了每一个非零元素的位置。
对一个稠密Tensor `dense` 来说它对应的COOTensor(indices, values, dense_shape),满足 `dense[indices[i]] = values[i]`
对一个稠密Tensor `dense` 来说它对应的COOTensor(indices, values, shape),满足 `dense[indices[i]] = values[i]`
如果 `indices` 是[[0, 1], [1, 2]] `values` 是[1, 2] `dense_shape` 是(3, 4)那么它对应的稠密Tensor如下
如果 `indices` 是[[0, 1], [1, 2]] `values` 是[1, 2] `shape` 是(3, 4)那么它对应的稠密Tensor如下
.. code-block::

View File

@ -389,7 +389,8 @@ AbstractBasePtr InferImplMakeCOOTensor(const AnalysisEnginePtr &, const Primitiv
}
constexpr int64_t kDimTwo = 2;
if (indices_shp[kIndexOne] != kDimTwo) {
MS_EXCEPTION(ValueError) << "COOTensor only support " << kDimTwo << " dimensions, but got " << indices_shp[1];
MS_EXCEPTION(ValueError) << "Indices must be a 2 dimensional tensor, and the second dimension must be 2, but got "
<< indices_shp[kIndexOne];
}
for (const auto &elem_type : dense_shape->ElementsType()) {
@ -545,11 +546,14 @@ AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const Primitive
int64_t axis_value = GetValue<int64_t>(axis->BuildValue());
int64_t dim = static_cast<int64_t>(sparse_shape.size());
if (axis_value < -dim || axis_value >= dim) {
MS_LOG(EXCEPTION) << "axis should be in [" << -dim << ", " << dim << "). But got axis = " << axis_value;
MS_EXCEPTION(ValueError) << "axis should be in [" << -dim << ", " << dim << "). But got axis = " << axis_value;
}
if (axis_value >= -dim && axis_value < 0) {
axis_value += dim;
}
if (axis_value != 1) {
MS_EXCEPTION(ValueError) << "axis must be 1, but got " << axis_value;
}
out_shape[LongToSize(axis_value)] = 1;
primitive->set_attr(kCSRAxis, MakeValue(axis_value));
} else {

View File

@ -2427,10 +2427,10 @@ class COOTensor(COOTensor_):
"""
A sparse representation of a set of nonzero elements from a tensor at given indices.
For a tensor dense, its COOTensor(indices, values, dense_shape) has
For a tensor dense, its COOTensor(indices, values, shape) has
`dense[indices[i]] = values[i]`.
For example, if indices is [[0, 1], [1, 2]], values is [1, 2], dense_shape is
For example, if indices is [[0, 1], [1, 2]], values is [1, 2], shape is
(3, 4), then the dense representation of the sparse tensor will be:
.. code-block::
@ -2840,7 +2840,7 @@ class CSRTensor(CSRTensor_):
axis (int) - The dimensions to reduce.
Returns:
Tensor, the dtype is the same as `sparse_tensor.values`.
Tensor, the dtype is the same as `CSRTensor.values`.
Supported Platforms:
``GPU`` ``CPU``