From 5e5415f6176043b9cff4be3b021022184e5f5aee Mon Sep 17 00:00:00 2001 From: wangrao124 Date: Wed, 16 Mar 2022 17:54:28 +0800 Subject: [PATCH] fix csr/coo --- .../api_python/mindspore/mindspore.COOTensor.rst | 6 +++--- .../api_python/mindspore/mindspore.CSRTensor.rst | 12 ++++++------ mindspore/core/abstract/prim_others.cc | 6 ++++++ mindspore/python/mindspore/_checkparam.py | 4 ++++ mindspore/python/mindspore/common/tensor.py | 15 +++++++++------ tests/ut/python/ir/test_sparse_tensor.py | 2 +- 6 files changed, 29 insertions(+), 16 deletions(-) diff --git a/docs/api/api_python/mindspore/mindspore.COOTensor.rst b/docs/api/api_python/mindspore/mindspore.COOTensor.rst index 19956c1864b..dba5f5b1b34 100644 --- a/docs/api/api_python/mindspore/mindspore.COOTensor.rst +++ b/docs/api/api_python/mindspore/mindspore.COOTensor.rst @@ -16,11 +16,11 @@ mindspore.COOTensor [0, 0, 0, 0]] .. note:: - 这是一个实验特性,在未来可能会发生API的变化。 + 这是一个实验特性,在未来可能会发生API的变化。目前COOTensor中相同索引的值不会进行合并。 **参数:** - - **indices** (Tensor) - 形状为 `[N, ndims]` 的二维整数张量,其中N和ndims分别表示稀疏张量中 `values` 的数量和COOTensor维度的数量。 + - **indices** (Tensor) - 形状为 `[N, ndims]` 的二维整数张量,其中N和ndims分别表示稀疏张量中 `values` 的数量和COOTensor维度的数量。 请确保indices的值在所给shape范围内。 - **values** (Tensor) - 形状为 `[N]` 的一维张量,用来给 `indices` 中的每个元素提供数值。 - **shape** (tuple(int)) - 形状为ndims的整数元组,用来指定稀疏矩阵的稠密形状。 - **coo_tensor** (COOTensor) - COOTensor对象,用来初始化新的COOTensor。 @@ -35,7 +35,7 @@ mindspore.COOTensor **返回:** - CSRTensor。 + COOTensor。 .. 
py:method:: astype(dtype) diff --git a/docs/api/api_python/mindspore/mindspore.CSRTensor.rst b/docs/api/api_python/mindspore/mindspore.CSRTensor.rst index 1670a275d3e..27489da509a 100644 --- a/docs/api/api_python/mindspore/mindspore.CSRTensor.rst +++ b/docs/api/api_python/mindspore/mindspore.CSRTensor.rst @@ -11,15 +11,15 @@ mindspore.CSRTensor **参数:** - - **indptr** (Tensor) - 形状为 `[M]` 的一维整数张量,其中M等于 `shape[0] + 1` , 表示每行非零元素的在 `values` 中存储的起止位置。 - - **indices** (Tensor) - 形状为 `[N]` 的一维整数张量,其中N等于非零元素数量,表示每个元素的列索引值。 - - **values** (Tensor) - 形状为 `[N]` 的一维张量,用来表示索引对应的数值。 - - **shape** (tuple(int)) - 形状为ndims的整数元组,用来指定稀疏矩阵的稠密形状。 - - **csr_tensor** (CSRTensor) - CSRTensor对象,用来初始化新的CSRTensor。 + - **indptr** (Tensor) - 形状为 `[M]` 的一维整数张量,其中M等于 `shape[0] + 1` , 表示每行非零元素在 `values` 中存储的起止位置。默认值:None。支持的数据类型为 `int16` , `int32` 和 `int64` 。 + - **indices** (Tensor) - 形状为 `[N]` 的一维整数张量,其中N等于非零元素数量,表示每个元素的列索引值。默认值:None。支持的数据类型为 `int16` , `int32` 和 `int64` 。 + - **values** (Tensor) - 形状为 `[N]` 的一维张量,用来表示索引对应的数值。默认值:None。 + - **shape** (tuple(int)) - 形状为ndims的整数元组,用来指定稀疏矩阵的稠密形状。目前只支持2维CSRTensor,所以 `shape` 长度只能为2。`shape[0]` 表示行数,因此必须和 `indptr.shape[0] - 1` 值相等。默认值:None。 + - **csr_tensor** (CSRTensor) - CSRTensor对象,用来初始化新的CSRTensor。默认值:None。 **输出:** - CSRTensor,由 `indptr` 、 `indices` 、 `values` 和 `shape` 组成。 + CSRTensor,稠密形状取决于传入的 `shape` ,数据类型由 `values` 决定。 ..
py:method:: abs() diff --git a/mindspore/core/abstract/prim_others.cc b/mindspore/core/abstract/prim_others.cc index 1562134b4a7..84f548c8451 100644 --- a/mindspore/core/abstract/prim_others.cc +++ b/mindspore/core/abstract/prim_others.cc @@ -387,6 +387,9 @@ AbstractBasePtr InferImplMakeCOOTensor(const AnalysisEnginePtr &, const Primitiv MS_EXCEPTION(TypeError) << "The first dimension of indices must be the same with the first dimension of values " << values_shp[0] << ", but got " << indices_shp[0]; } + if (indices_shp[1] != kSizeTwo) { + MS_EXCEPTION(ValueError) << "COOTensor only support " << kSizeTwo << " dimensions, but got " << indices_shp[1]; + } for (const auto &elem_type : dense_shape->ElementsType()) { if (!elem_type->isa()) { @@ -396,6 +399,9 @@ AbstractBasePtr InferImplMakeCOOTensor(const AnalysisEnginePtr &, const Primitiv auto dense_shape_value = dense_shape->BuildValue()->cast(); MS_EXCEPTION_IF_NULL(dense_shape_value); auto shp = dense_shape_value->value(); + if (*std::min_element(std::begin(shp), std::end(shp)) <= 0) { + MS_EXCEPTION(ValueError) << "The element of dense_shape must positive integer."; + } ShapeVector dense_shape_vec; (void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec), [](const ValuePtr &e) -> int64_t { diff --git a/mindspore/python/mindspore/_checkparam.py b/mindspore/python/mindspore/_checkparam.py index 3f6a47daae7..6e67ccfa76a 100644 --- a/mindspore/python/mindspore/_checkparam.py +++ b/mindspore/python/mindspore/_checkparam.py @@ -892,6 +892,10 @@ class Validator: raise ValueError(f"Values must be a 1-dimensional tensor, but got a {len(values_shp)} dimension tensor.") if indices_shp[0] != values_shp[0]: raise ValueError(f"Indices.shape must be (N, 2), where N equals to number of nonzero values in coo tensor.") + if indices_shp[1] != 2: + raise ValueError(f"Indices.shape must be (N, 2), where N equals to number of nonzero values in coo tensor.") + if min(coo_shp) <= 0: + raise 
ValueError(f"Dense shape must be a tuple of positive integers.") @staticmethod def check_coo_tensor_dtype(indices_dtype): diff --git a/mindspore/python/mindspore/common/tensor.py b/mindspore/python/mindspore/common/tensor.py index 26bdf750e85..bb017018738 100644 --- a/mindspore/python/mindspore/common/tensor.py +++ b/mindspore/python/mindspore/common/tensor.py @@ -2434,11 +2434,12 @@ class COOTensor(COOTensor_): Note: This is an experimental feature and is subjected to change. + Currently, duplicate coordinates in the indices will not be coalesced. Args: indices (Tensor): A 2-D integer Tensor of shape `[N, ndims]`, where N and ndims are the number of `values` and number of dimensions in - the COOTensor, respectively. + the COOTensor, respectively. Please make sure that the indices are in range of the given shape. values (Tensor): A 1-D tensor of any type and shape `[N]`, which supplies the values for each element in `indices`. shape (tuple(int)): A integer tuple of size `ndims`, @@ -2476,9 +2477,11 @@ class COOTensor(COOTensor_): # Init a COOTensor from indices, values and shape else: if not (isinstance(indices, Tensor) and isinstance(values, Tensor) and isinstance(shape, tuple)): - raise TypeError("Inputs must follow: COOTensor(indices, values, shape).") + raise TypeError("Inputs must follow: COOTensor(Tensor, Tensor, tuple).") validator.check_coo_tensor_shape(indices.shape, values.shape, shape) validator.check_coo_tensor_dtype(indices.dtype) + if not (indices < Tensor(shape)).all() or (indices < 0).any(): + raise ValueError("All the indices should be non-negative integer and in range of the given shape!") COOTensor_.__init__(self, indices, values, shape) self.init_finished = True @@ -2630,8 +2633,8 @@ class CSRTensor(CSRTensor_): stores the data for CSRTensor. Default: None. 
shape (Tuple): A tuple indicates the shape of the CSRTensor, its length must be `2`, as only 2-D CSRTensor is currently supported, and `shape[0]` must - equal to `indptr[0] - 1`, which all equal to number of rows of the CSRTensor. - csr_tensor (CSRTensor): A CSRTensor object. + equal to `indptr.shape[0] - 1`, which is the number of rows of the CSRTensor. Default: None. + csr_tensor (CSRTensor): A CSRTensor object. Default: None. Outputs: CSRTensor, with shape defined by `shape`, and dtype inferred from `value`. @@ -2834,8 +2837,8 @@ class CSRTensor(CSRTensor_): Examples: >>> from mindspore import Tensor, CSRTensor >>> from mindspore import dtype as mstype - >>> indptr = Tensor([0, 1, 2], dtype=ms.int32) - >>> indices = Tensor([0, 1], dtype=ms.int32) + >>> indptr = Tensor([0, 1, 2], dtype=mstype.int32) + >>> indices = Tensor([0, 1], dtype=mstype.int32) >>> values = Tensor([2, 1], dtype=mstype.float32) >>> dense_shape = (2, 4) >>> csr_tensor = CSRTensor(indptr, indices, values, dense_shape) diff --git a/tests/ut/python/ir/test_sparse_tensor.py b/tests/ut/python/ir/test_sparse_tensor.py index 0ccd530ee28..e558f4cea87 100644 --- a/tests/ut/python/ir/test_sparse_tensor.py +++ b/tests/ut/python/ir/test_sparse_tensor.py @@ -84,7 +84,7 @@ def test_sparse_tensor_indices_dim_greater_than_dense_shape_dim(): indices = Tensor(np.array([[0, 0, 0], [0, 0, 1]], dtype=np.int32)) values = Tensor(np.array([100, 200], dtype=np.float32)) dense_shape = (2, 2) - with pytest.raises(TypeError): + with pytest.raises(ValueError): MakeSparseTensor(dense_shape)(indices, values)