fix csr/coo

This commit is contained in:
wangrao124 2022-03-16 17:54:28 +08:00
parent a22321220a
commit 5e5415f617
6 changed files with 29 additions and 16 deletions

View File

@ -16,11 +16,11 @@ mindspore.COOTensor
[0, 0, 0, 0]]
.. note::
这是一个实验特性在未来可能会发生API的变化。
这是一个实验特性在未来可能会发生API的变化。目前COOTensor中相同索引的值不会进行合并。
**参数:**
- **indices** (Tensor) - 形状为 `[N, ndims]` 的二维整数张量其中N和ndims分别表示稀疏张量中 `values` 的数量和COOTensor维度的数量。
- **indices** (Tensor) - 形状为 `[N, ndims]` 的二维整数张量其中N和ndims分别表示稀疏张量中 `values` 的数量和COOTensor维度的数量。 请确保indices的值在所给shape范围内。
- **values** (Tensor) - 形状为 `[N]` 的一维张量,用来给 `indices` 中的每个元素提供数值。
- **shape** (tuple(int)) - 形状为ndims的整数元组用来指定稀疏矩阵的稠密形状。
- **coo_tensor** (COOTensor) - COOTensor对象用来初始化新的COOTensor。
@ -35,7 +35,7 @@ mindspore.COOTensor
**返回:**
CSRTensor。
COOTensor。
.. py:method:: astype(dtype)

View File

@ -11,15 +11,15 @@ mindspore.CSRTensor
**参数:**
- **indptr** (Tensor) - 形状为 `[M]` 的一维整数张量其中M等于 `shape[0] + 1` , 表示每行非零元素的在 `values` 中存储的起止位置。
- **indices** (Tensor) - 形状为 `[N]` 的一维整数张量其中N等于非零元素数量表示每个元素的列索引值。
- **values** (Tensor) - 形状为 `[N]` 的一维张量,用来表示索引对应的数值。
- **shape** (tuple(int)) - 形状为ndims的整数元组用来指定稀疏矩阵的稠密形状。
- **csr_tensor** (CSRTensor) - CSRTensor对象用来初始化新的CSRTensor。
- **indptr** (Tensor) - 形状为 `[M]` 的一维整数张量其中M等于 `shape[0] + 1` , 表示每行非零元素的在 `values` 中存储的起止位置。默认值None。支持的数据类型为 `int16` `int32``int64`
- **indices** (Tensor) - 形状为 `[N]` 的一维整数张量其中N等于非零元素数量表示每个元素的列索引值。默认值None。支持的数据类型为 `int16` `int32``int64`
- **values** (Tensor) - 形状为 `[N]` 的一维张量,用来表示索引对应的数值。默认值None。
- **shape** (tuple(int)) - 形状为ndims的整数元组用来指定稀疏矩阵的稠密形状。目前只支持2维CSRTensor所以 `shape` 长度只能为2。`shape[0]` 表示行数,因此必须和 `indptr[0] - 1` 值相等。默认值None。
- **csr_tensor** (CSRTensor) - CSRTensor对象用来初始化新的CSRTensor。默认值None。
**输出:**
CSRTensor`indptr``indices``values``shape` 组成
CSRTensor稠密形状取决于传入的 `shape` ,数据类型由 `values` 决定
.. py:method:: abs()

View File

@ -387,6 +387,9 @@ AbstractBasePtr InferImplMakeCOOTensor(const AnalysisEnginePtr &, const Primitiv
MS_EXCEPTION(TypeError) << "The first dimension of indices must be the same with the first dimension of values "
<< values_shp[0] << ", but got " << indices_shp[0];
}
if (indices_shp[1] != kSizeTwo) {
MS_EXCEPTION(ValueError) << "COOTensor only support " << kSizeTwo << " dimensions, but got " << indices_shp[1];
}
for (const auto &elem_type : dense_shape->ElementsType()) {
if (!elem_type->isa<Int>()) {
@ -396,6 +399,9 @@ AbstractBasePtr InferImplMakeCOOTensor(const AnalysisEnginePtr &, const Primitiv
auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(dense_shape_value);
auto shp = dense_shape_value->value();
if (*std::min_element(std::begin(shp), std::end(shp)) <= 0) {
MS_EXCEPTION(ValueError) << "The element of dense_shape must positive integer.";
}
ShapeVector dense_shape_vec;
(void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec),
[](const ValuePtr &e) -> int64_t {

View File

@ -892,6 +892,10 @@ class Validator:
raise ValueError(f"Values must be a 1-dimensional tensor, but got a {len(values_shp)} dimension tensor.")
if indices_shp[0] != values_shp[0]:
raise ValueError(f"Indices.shape must be (N, 2), where N equals to number of nonzero values in coo tensor.")
if indices_shp[1] != 2:
raise ValueError(f"Indices.shape must be (N, 2), where N equals to number of nonzero values in coo tensor.")
if min(coo_shp) <= 0:
raise ValueError(f"Dense shape must be a tuple of positive integers.")
@staticmethod
def check_coo_tensor_dtype(indices_dtype):

View File

@ -2434,11 +2434,12 @@ class COOTensor(COOTensor_):
Note:
This is an experimental feature and is subjected to change.
Currently, duplicate coordinates in the indices will not be coalesced.
Args:
indices (Tensor): A 2-D integer Tensor of shape `[N, ndims]`,
where N and ndims are the number of `values` and number of dimensions in
the COOTensor, respectively.
the COOTensor, respectively. Please make sure that the indices are in range of the given shape.
values (Tensor): A 1-D tensor of any type and shape `[N]`, which
supplies the values for each element in `indices`.
shape (tuple(int)): A integer tuple of size `ndims`,
@ -2476,9 +2477,11 @@ class COOTensor(COOTensor_):
# Init a COOTensor from indices, values and shape
else:
if not (isinstance(indices, Tensor) and isinstance(values, Tensor) and isinstance(shape, tuple)):
raise TypeError("Inputs must follow: COOTensor(indices, values, shape).")
raise TypeError("Inputs must follow: COOTensor(Tensor, Tensor, tuple).")
validator.check_coo_tensor_shape(indices.shape, values.shape, shape)
validator.check_coo_tensor_dtype(indices.dtype)
if not (indices < Tensor(shape)).all() or (indices < 0).any():
raise ValueError("All the indices should be non-negative integer and in range of the given shape!")
COOTensor_.__init__(self, indices, values, shape)
self.init_finished = True
@ -2630,8 +2633,8 @@ class CSRTensor(CSRTensor_):
stores the data for CSRTensor. Default: None.
shape (Tuple): A tuple indicates the shape of the CSRTensor, its length must
be `2`, as only 2-D CSRTensor is currently supported, and `shape[0]` must
equal to `indptr[0] - 1`, which all equal to number of rows of the CSRTensor.
csr_tensor (CSRTensor): A CSRTensor object.
equal to `indptr[0] - 1`, which all equal to number of rows of the CSRTensor. Default: None.
csr_tensor (CSRTensor): A CSRTensor object. Default: None.
Outputs:
CSRTensor, with shape defined by `shape`, and dtype inferred from `value`.
@ -2834,8 +2837,8 @@ class CSRTensor(CSRTensor_):
Examples:
>>> from mindspore import Tensor, CSRTensor
>>> from mindspore import dtype as mstype
>>> indptr = Tensor([0, 1, 2], dtype=ms.int32)
>>> indices = Tensor([0, 1], dtype=ms.int32)
>>> indptr = Tensor([0, 1, 2], dtype=mstype.int32)
>>> indices = Tensor([0, 1], dtype=mstype.int32)
>>> values = Tensor([2, 1], dtype=mstype.float32)
>>> dense_shape = (2, 4)
>>> csr_tensor = CSRTensor(indptr, indices, values, dense_shape)

View File

@ -84,7 +84,7 @@ def test_sparse_tensor_indices_dim_greater_than_dense_shape_dim():
indices = Tensor(np.array([[0, 0, 0], [0, 0, 1]], dtype=np.int32))
values = Tensor(np.array([100, 200], dtype=np.float32))
dense_shape = (2, 2)
with pytest.raises(TypeError):
with pytest.raises(ValueError):
MakeSparseTensor(dense_shape)(indices, values)