forked from mindspore-Ecosystem/mindspore
!31869 fix COOTensor/CSRTenosr checks and docs
Merge pull request !31869 from wangrao124/pr_coo_csr
This commit is contained in:
commit
d0dc332656
|
@ -5,9 +5,9 @@ mindspore.COOTensor
|
|||
|
||||
用来表示某一张量在给定索引上非零元素的集合,其中索引(indices)指示了每一个非零元素的位置。
|
||||
|
||||
对一个稠密Tensor `dense` 来说,它对应的COOTensor(indices, values, dense_shape),满足 `dense[indices[i]] = values[i]` 。
|
||||
对一个稠密Tensor `dense` 来说,它对应的COOTensor(indices, values, shape),满足 `dense[indices[i]] = values[i]` 。
|
||||
|
||||
如果 `indices` 是[[0, 1], [1, 2]], `values` 是[1, 2], `dense_shape` 是(3, 4),那么它对应的稠密Tensor如下:
|
||||
如果 `indices` 是[[0, 1], [1, 2]], `values` 是[1, 2], `shape` 是(3, 4),那么它对应的稠密Tensor如下:
|
||||
|
||||
.. code-block::
|
||||
|
||||
|
|
|
@ -389,7 +389,8 @@ AbstractBasePtr InferImplMakeCOOTensor(const AnalysisEnginePtr &, const Primitiv
|
|||
}
|
||||
constexpr int64_t kDimTwo = 2;
|
||||
if (indices_shp[kIndexOne] != kDimTwo) {
|
||||
MS_EXCEPTION(ValueError) << "COOTensor only support " << kDimTwo << " dimensions, but got " << indices_shp[1];
|
||||
MS_EXCEPTION(ValueError) << "Indices must be a 2 dimensional tensor, and the second dimension must be 2, but got "
|
||||
<< indices_shp[kIndexOne];
|
||||
}
|
||||
|
||||
for (const auto &elem_type : dense_shape->ElementsType()) {
|
||||
|
@ -545,11 +546,14 @@ AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const Primitive
|
|||
int64_t axis_value = GetValue<int64_t>(axis->BuildValue());
|
||||
int64_t dim = static_cast<int64_t>(sparse_shape.size());
|
||||
if (axis_value < -dim || axis_value >= dim) {
|
||||
MS_LOG(EXCEPTION) << "axis should be in [" << -dim << ", " << dim << "). But got axis = " << axis_value;
|
||||
MS_EXCEPTION(ValueError) << "axis should be in [" << -dim << ", " << dim << "). But got axis = " << axis_value;
|
||||
}
|
||||
if (axis_value >= -dim && axis_value < 0) {
|
||||
axis_value += dim;
|
||||
}
|
||||
if (axis_value != 1) {
|
||||
MS_EXCEPTION(ValueError) << "axis must be 1, but got " << axis_value;
|
||||
}
|
||||
out_shape[LongToSize(axis_value)] = 1;
|
||||
primitive->set_attr(kCSRAxis, MakeValue(axis_value));
|
||||
} else {
|
||||
|
|
|
@ -2427,10 +2427,10 @@ class COOTensor(COOTensor_):
|
|||
"""
|
||||
A sparse representation of a set of nonzero elements from a tensor at given indices.
|
||||
|
||||
For a tensor dense, its COOTensor(indices, values, dense_shape) has
|
||||
For a tensor dense, its COOTensor(indices, values, shape) has
|
||||
`dense[indices[i]] = values[i]`.
|
||||
|
||||
For example, if indices is [[0, 1], [1, 2]], values is [1, 2], dense_shape is
|
||||
For example, if indices is [[0, 1], [1, 2]], values is [1, 2], shape is
|
||||
(3, 4), then the dense representation of the sparse tensor will be:
|
||||
|
||||
.. code-block::
|
||||
|
@ -2840,7 +2840,7 @@ class CSRTensor(CSRTensor_):
|
|||
axis (int) - The dimensions to reduce.
|
||||
|
||||
Returns:
|
||||
Tensor, the dtype is the same as `sparse_tensor.values`.
|
||||
Tensor, the dtype is the same as `CSRTensor.values`.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU`` ``CPU``
|
||||
|
|
Loading…
Reference in New Issue