From 61e102752719e4b46cec2fbc60e39986f13fdea7 Mon Sep 17 00:00:00 2001 From: yanglf1121 Date: Thu, 21 Apr 2022 09:50:59 +0800 Subject: [PATCH] support rowtensor in pynative --- mindspore/ccsrc/pipeline/jit/resource.cc | 5 +- mindspore/ccsrc/pybind_api/ir/tensor_py.cc | 38 ++++-- mindspore/ccsrc/pybind_api/ir/tensor_py.h | 6 + mindspore/core/abstract/ops/infer_functions.h | 2 - mindspore/core/abstract/ops/prim_others.cc | 6 +- mindspore/core/abstract/ops/prim_statement.cc | 17 --- .../core/abstract/ops/primitive_infer_map.cc | 1 - mindspore/core/ir/meta_tensor.cc | 4 - mindspore/core/ir/meta_tensor.h | 9 +- mindspore/core/ir/tensor.cc | 31 ++++- mindspore/core/ir/tensor.h | 67 +++++++++-- mindspore/core/ops/core_ops.h | 2 +- mindspore/python/mindspore/_checkparam.py | 4 +- mindspore/python/mindspore/common/api.py | 5 +- mindspore/python/mindspore/common/tensor.py | 57 +++++---- tests/st/sparse/test_row.py | 108 ++++++++++++++++++ 16 files changed, 285 insertions(+), 77 deletions(-) create mode 100644 tests/st/sparse/test_row.py diff --git a/mindspore/ccsrc/pipeline/jit/resource.cc b/mindspore/ccsrc/pipeline/jit/resource.cc index 52b8823e4ea..64302941c20 100644 --- a/mindspore/ccsrc/pipeline/jit/resource.cc +++ b/mindspore/ccsrc/pipeline/jit/resource.cc @@ -39,8 +39,9 @@ BuiltInTypeMap &GetMethodMap() { {"__bool__", std::string("none_bool")} // C.none_bool }}, {kObjectTypeFunction, - {{"__bool__", std::string("func_bool")}, // C.str_bool - {"__is_csr_func__", prim::kPrimIsCSRFunc}}}, + { + {"__bool__", std::string("func_bool")} // C.str_bool + }}, {kNumberTypeBool, { {"__and__", prim::kPrimBoolAnd}, // P.bool_and diff --git a/mindspore/ccsrc/pybind_api/ir/tensor_py.cc b/mindspore/ccsrc/pybind_api/ir/tensor_py.cc index 8cb1c8dcc3c..8ee3a0e8388 100644 --- a/mindspore/ccsrc/pybind_api/ir/tensor_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/tensor_py.cc @@ -653,8 +653,9 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { })); })); -py::tuple CSRTensorPy::GetPyTupleShape(const CSRTensor &csr_tensor) { - auto &shape = csr_tensor.shape(); +template +py::tuple GetSparseTensorShape(const T &sparse_tensor) { + auto &shape = sparse_tensor.shape(); py::tuple dims(shape.size()); for (size_t i = 0; i < dims.size(); ++i) { dims[i] = py::int_(shape[i]); @@ -662,6 +663,8 @@ py::tuple CSRTensorPy::GetPyTupleShape(const CSRTensor &csr_tensor) { return dims; } +py::tuple CSRTensorPy::GetPyTupleShape(const CSRTensor &csr_tensor) { return GetSparseTensorShape(csr_tensor); } + REGISTER_PYBIND_DEFINE( CSRTensor, ([](const py::module *m) { // Define python CSRTensor class. @@ -682,14 +685,7 @@ REGISTER_PYBIND_DEFINE( .def("__repr__", &CSRTensor::ToString); })); -py::tuple COOTensorPy::GetPyTupleShape(const COOTensor &coo_tensor) { - auto &shape = coo_tensor.shape(); - py::tuple dims(shape.size()); - for (size_t i = 0; i < dims.size(); ++i) { - dims[i] = py::int_(shape[i]); - } - return dims; -} +py::tuple COOTensorPy::GetPyTupleShape(const COOTensor &coo_tensor) { return GetSparseTensorShape(coo_tensor); } REGISTER_PYBIND_DEFINE(COOTensor, ([](const py::module *m) { // Define python COOTensor class. @@ -710,5 +706,27 @@ REGISTER_PYBIND_DEFINE(COOTensor, ([](const py::module *m) { .def("__str__", &COOTensor::ToString) .def("__repr__", &COOTensor::ToString); })); + +py::tuple RowTensorPy::GetPyTupleShape(const RowTensor &row_tensor) { return GetSparseTensorShape(row_tensor); } + +REGISTER_PYBIND_DEFINE(RowTensor, ([](const py::module *m) { + // Define python RowTensor class. + (void)py::class_>(*m, "RowTensor") + .def(py::init([](const Tensor &indices, const Tensor &values, const py::tuple &shape) { + return std::make_shared(std::make_shared(indices), + std::make_shared(values), + GetShapeFromTuple(shape)); + }), + py::arg("indices"), py::arg("values"), py::arg("shape")) + .def(py::init( + [](const RowTensor &row_tensor) { return std::make_shared(row_tensor); }), + py::arg("input")) + .def_property_readonly("_shape", RowTensorPy::GetPyTupleShape) + .def_property_readonly("_dtype", &RowTensor::Dtype) + .def_property_readonly("_indices", &RowTensor::GetIndices) + .def_property_readonly("_values", &RowTensor::GetValues) + .def("__str__", &RowTensor::ToString) + .def("__repr__", &RowTensor::ToString); + })); } // namespace tensor } // namespace mindspore diff --git a/mindspore/ccsrc/pybind_api/ir/tensor_py.h b/mindspore/ccsrc/pybind_api/ir/tensor_py.h index fa5bd2a97fb..cd582723296 100644 --- a/mindspore/ccsrc/pybind_api/ir/tensor_py.h +++ b/mindspore/ccsrc/pybind_api/ir/tensor_py.h @@ -130,6 +130,12 @@ class COOTensorPy { public: static py::tuple GetPyTupleShape(const COOTensor &coo_tensor); }; + +// RowTensor python wrapper and adapter class. +class RowTensorPy { + public: + static py::tuple GetPyTupleShape(const RowTensor &row_tensor); +}; } // namespace tensor } // namespace mindspore diff --git a/mindspore/core/abstract/ops/infer_functions.h b/mindspore/core/abstract/ops/infer_functions.h index 9a86e83dbeb..c0bcb9a6d72 100644 --- a/mindspore/core/abstract/ops/infer_functions.h +++ b/mindspore/core/abstract/ops/infer_functions.h @@ -169,8 +169,6 @@ AbstractBasePtr InferImplCSR2COO(const AnalysisEnginePtr &, const PrimitivePtr & const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplCOO2CSR(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplIsCSRFunc(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, diff --git a/mindspore/core/abstract/ops/prim_others.cc b/mindspore/core/abstract/ops/prim_others.cc index 05a390a5b31..20491bc7a28 100644 --- a/mindspore/core/abstract/ops/prim_others.cc +++ b/mindspore/core/abstract/ops/prim_others.cc @@ -477,8 +477,8 @@ AbstractBasePtr InferImplCSRMV(const AnalysisEnginePtr &, const PrimitivePtr &pr auto sparse_shape = sparse->shape()->shape(); auto dense_shape = dense->shape()->shape(); if (sparse_shape.size() != kCSRMVShapeSize || dense_shape.size() != kCSRMVShapeSize) { - MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRMVShapeSize << "-D inputs!" - << "but sparse tensor has " << sparse_shape.size() << " dimensions, " + MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRMVShapeSize << "-D inputs! " + << "But csr tensor has " << sparse_shape.size() << " dimensions, " << "and dense tensor has " << dense_shape.size() << " dimensions, "; } if (dense_shape[kIndexZero] != sparse_shape[kIndexOne] || dense_shape[kIndexOne] != 1) { @@ -541,7 +541,7 @@ AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const Primitive primitive->set_attr(kCSRAxis, MakeValue(axis_value)); } else { MS_EXCEPTION(TypeError) << "For CSRReduceSum, `axis` should be int32 or int64, but got " - << axis->BuildValue()->ToString(); + << axis->BuildType()->ToString(); } MS_EXCEPTION_IF_NULL(sparse->values()->element()); diff --git a/mindspore/core/abstract/ops/prim_statement.cc b/mindspore/core/abstract/ops/prim_statement.cc index f463d1930d1..5d10b3f0022 100644 --- a/mindspore/core/abstract/ops/prim_statement.cc +++ b/mindspore/core/abstract/ops/prim_statement.cc @@ -247,22 +247,5 @@ AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePt ValuePtr v = args_spec_list[0]->BuildValue(); return std::make_shared(!v->isa()); } - -AbstractBasePtr InferImplIsCSRFunc(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Statement: x.__is_csr_func__() - // Inputs: x - auto func = CheckArg(primitive->name(), args_spec_list, 0); - MS_EXCEPTION_IF_NULL(func); - auto prim_func = dyn_cast(func); - if (prim_func != nullptr) { - PrimitivePtr prim = prim_func->prim(); - std::string name = prim->name(); - if (name == "S-Prim-MakeCSRTensor") { - return std::make_shared(1); - } - } - return std::make_shared(0); -} } // namespace abstract } // namespace mindspore diff --git a/mindspore/core/abstract/ops/primitive_infer_map.cc b/mindspore/core/abstract/ops/primitive_infer_map.cc index 9229a3ae4d2..b5e9a413303 100644 --- a/mindspore/core/abstract/ops/primitive_infer_map.cc +++ b/mindspore/core/abstract/ops/primitive_infer_map.cc @@ -154,7 +154,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimInDict, R{InferImplInDict, nullptr, true}}, {prim::kPrimNotInDict, R{InferImplNotInDict, nullptr, true}}, {prim::kPrimIsConsant, R{InferImplIsConstant, nullptr, true}}, - {prim::kPrimIsCSRFunc, R{InferImplIsCSRFunc, nullptr, true}}, // Maths {prim::kPrimMatMul, R{InferImplMatMul, nullptr, true}}, {prim::kPrimBatchMatMul, R{InferImplBatchMatMul, nullptr, true}}, diff --git a/mindspore/core/ir/meta_tensor.cc b/mindspore/core/ir/meta_tensor.cc index d5eb1901926..80a9e98d226 100644 --- a/mindspore/core/ir/meta_tensor.cc +++ b/mindspore/core/ir/meta_tensor.cc @@ -110,10 +110,6 @@ MetaSparseTensor::MetaSparseTensor(TypeId data_type, const ShapeVector &shape) : MetaSparseTensor::MetaSparseTensor(const MetaSparseTensor &meta_sparse_tensor) : Value(meta_sparse_tensor), data_type_(meta_sparse_tensor.data_type()), shape_(meta_sparse_tensor.shape()) {} -bool MetaSparseTensor::operator==(const MetaSparseTensor &meta_sparse_tensor) const { - return data_type_ == meta_sparse_tensor.data_type() && shape_ == meta_sparse_tensor.shape(); -} - TypePtr MetaSparseTensor::Dtype() const { return TypeIdToType(data_type_); } } // namespace tensor } // namespace mindspore diff --git a/mindspore/core/ir/meta_tensor.h b/mindspore/core/ir/meta_tensor.h index 8f6fcd13ff7..939f15d1897 100644 --- a/mindspore/core/ir/meta_tensor.h +++ b/mindspore/core/ir/meta_tensor.h @@ -269,7 +269,9 @@ class MS_CORE_API MetaSparseTensor : public Value { /// /// \param[in] meta_sparse_tensor The MetaSparseTensor object to be compared. /// \return Return true if having same type and shape, otherwise return false. - virtual bool operator==(const MetaSparseTensor &meta_sparse_tensor) const; + virtual bool operator==(const MetaSparseTensor &meta_sparse_tensor) const { + return data_type_ == meta_sparse_tensor.data_type() && shape_ == meta_sparse_tensor.shape(); + } /// \brief Get the data type of the sparse tensor. /// All the types are defined in "ir/dtype.h". @@ -297,6 +299,11 @@ class MS_CORE_API MetaSparseTensor : public Value { /// \param[in] shape The shape of the tensor. void set_shape(const ShapeVector &shape) { this->shape_ = shape; } + /// \brief Get display information of this Tensor. + /// + /// \return The display information of this Tensor. + virtual std::string ToString() const = 0; + protected: // Data type of the sparsetensor. TypeId data_type_; diff --git a/mindspore/core/ir/tensor.cc b/mindspore/core/ir/tensor.cc index c414735740d..c7b8885bcdf 100644 --- a/mindspore/core/ir/tensor.cc +++ b/mindspore/core/ir/tensor.cc @@ -930,8 +930,6 @@ TensorPtrList Tensor::GetFlattenedTensors(const TensorPtrList &tensors) { CSRTensor::CSRTensor(const TensorPtr indptr, const TensorPtr indices, const TensorPtr values, const ShapeVector &shape) : MetaSparseTensor(values->data_type(), shape), indptr_(indptr), indices_(indices), values_(values) {} -bool CSRTensor::operator==(const CSRTensor &csr_tensor) const { return (&csr_tensor == this); } - std::string CSRTensor::ToString() const { std::ostringstream buf; MS_EXCEPTION_IF_NULL(values_); @@ -1009,5 +1007,34 @@ abstract::AbstractBasePtr COOTensor::ToAbstract() { return abs_sparse_tensor; } + +std::string RowTensor::ToString() const { + std::ostringstream buf; + MS_EXCEPTION_IF_NULL(indices_); + MS_EXCEPTION_IF_NULL(values_); + auto dtype = values_->Dtype(); + buf << "RowTensor(shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString() + << ", indices=" << indices_->ToString() << ", values=" << values_->ToString() << ")"; + return buf.str(); +} + +abstract::AbstractBasePtr RowTensor::ToAbstract() { + auto dtype = values_->Dtype(); + if (!IsSubType(dtype, kNumber) && !IsSubType(dtype, kString) && !IsSubType(dtype, kTensorType)) { + MS_LOG(EXCEPTION) << "Expect tensor type kNumber or kString or kTensor but got: " << dtype->ToString() << "."; + } + auto abs_sparse_tensor = std::make_shared(dtype, shape_); + + abs_sparse_tensor->set_indices(indices_->ToAbstract()->cast()); + abs_sparse_tensor->set_values(values_->ToAbstract()->cast()); + + std::vector abstract_shape; + (void)std::transform( + shape_.begin(), shape_.end(), std::back_inserter(abstract_shape), + [](auto shp) -> abstract::AbstractScalarPtr { return std::make_shared(shp); }); + abs_sparse_tensor->set_dense_shape(std::make_shared(abstract_shape)); + + return abs_sparse_tensor; +} } // namespace tensor } // namespace mindspore diff --git a/mindspore/core/ir/tensor.h b/mindspore/core/ir/tensor.h index b4f39629e86..e71edd9fea7 100644 --- a/mindspore/core/ir/tensor.h +++ b/mindspore/core/ir/tensor.h @@ -652,11 +652,11 @@ class MS_CORE_API CSRTensor : public MetaSparseTensor { /// \return [TensorPtr] The values. TensorPtr GetValues() { return values_; } - /// \brief Compare two tensor objects to see if they have same data type, shape and data address. + /// \brief Compare two csrtensor objects to see if they have same data address. /// - /// \param[in] tensor The Tensor object to be compared. - /// \return True if having same type, shape and data address, otherwise false. - bool operator==(const CSRTensor &csr_tensor) const; + /// \param[in] csr_tensor The csrtensor object to be compared. + /// \return True if having same data address, otherwise false. + bool operator==(const CSRTensor &csr_tensor) const { return &csr_tensor == this; } bool operator==(const Value &other) const override { if (other.isa()) { @@ -710,11 +710,11 @@ class MS_CORE_API COOTensor : public MetaSparseTensor { /// \return [TensorPtr] The values. TensorPtr GetValues() { return values_; } - /// \brief Compare two tensor objects to see if they have same data type, shape and data address. + /// \brief Compare two cootensor objects to see if they have same address. /// - /// \param[in] tensor The Tensor object to be compared. - /// \return True if having same type, shape and data address, otherwise false. - bool operator==(const COOTensor &sparse_tensor) const { return &sparse_tensor == this; } + /// \param[in] coo_tensor The cootensor object to be compared. + /// \return True if having same data address, otherwise false. + bool operator==(const COOTensor &coo_tensor) const { return &coo_tensor == this; } bool operator==(const Value &other) const override { if (other.isa()) { @@ -734,6 +734,57 @@ class MS_CORE_API COOTensor : public MetaSparseTensor { TensorPtr values_; }; using COOTensorPtr = std::shared_ptr; + +// RowTensor entity class +class MS_CORE_API RowTensor : public MetaSparseTensor { + public: + abstract::AbstractBasePtr ToAbstract() override; + + /// \brief Create RowTensor with given data type from another tensor. + /// + /// \param[in] indices [Tensor] The indices. + /// \param[in] values [Tensor] The values. + /// \param[in] shape The shape represented by ShapeVector of the RowTensor. + RowTensor(const TensorPtr indices, const TensorPtr values, const ShapeVector &shape) + : MetaSparseTensor(values->data_type(), shape), indices_(indices), values_(values) {} + + /// Destructor of RowTensor. + ~RowTensor() override = default; + + /// \brief Gets RowTensor's indices. + /// + /// \return [TensorPtr] The indices. + TensorPtr GetIndices() { return indices_; } + + /// \brief Gets RowTensor's values. + /// + /// \return [TensorPtr] The values. + TensorPtr GetValues() { return values_; } + + /// \brief Compare two rowtensor objects to see if they have same address. + /// + /// \param[in] coo_tensor The rowtensor object to be compared. + /// \return True if having same data address, otherwise false. + bool operator==(const RowTensor &row_tensor) const { return &row_tensor == this; } + + bool operator==(const Value &other) const override { + if (other.isa()) { + auto &other_ = static_cast(other); + return *this == other_; + } + return false; + } + + /// \brief Get display information of this Tensor. + /// + /// \return The display information of this Tensor. + std::string ToString() const override; + + private: + TensorPtr indices_; + TensorPtr values_; +}; +using RowTensorPtr = std::shared_ptr; } // namespace tensor } // namespace mindspore diff --git a/mindspore/core/ops/core_ops.h b/mindspore/core/ops/core_ops.h index 175058cb690..38770f018f7 100644 --- a/mindspore/core/ops/core_ops.h +++ b/mindspore/core/ops/core_ops.h @@ -652,6 +652,7 @@ GVAR_DEF(PrimitivePtr, kPrimFakeQuantWithMinMaxVarsPerChannel, std::make_shared("FakeQuantWithMinMaxVarsPerChannel")); // Control ops GVAR_DEF(PrimitivePtr, kPrimMerge, std::make_shared("Merge")); + // RowTensor GVAR_DEF(PrimitivePtr, kPrimMakeRowTensor, std::make_shared("MakeRowTensor")); GVAR_DEF(PrimitivePtr, kPrimRowTensorGetValues, std::make_shared("RowTensorGetValues")); @@ -671,7 +672,6 @@ GVAR_DEF(PrimitivePtr, kPrimCSRTensorGetValues, std::make_shared(kCSR GVAR_DEF(PrimitivePtr, kPrimCSRTensorGetIndptr, std::make_shared(kCSRTensorGetIndptr)); GVAR_DEF(PrimitivePtr, kPrimCSRTensorGetIndices, std::make_shared(kCSRTensorGetIndices)); GVAR_DEF(PrimitivePtr, kPrimCSRTensorGetDenseShape, std::make_shared(kCSRTensorGetDenseShape)); -GVAR_DEF(PrimitivePtr, kPrimIsCSRFunc, std::make_shared(kIsCSRFunc)); // Sparse ops GVAR_DEF(PrimitivePtr, kPrimSparseTensorDenseMatmul, std::make_shared(kSparseTensorDenseMatmul)); diff --git a/mindspore/python/mindspore/_checkparam.py b/mindspore/python/mindspore/_checkparam.py index 8497b9f06ec..897af6d81f6 100644 --- a/mindspore/python/mindspore/_checkparam.py +++ b/mindspore/python/mindspore/_checkparam.py @@ -952,10 +952,10 @@ class Validator: raise TypeError(f"For COOTensor, the element type of 'shape' must be int, but got {type(sh)}") if len(indices_shp) != 2: raise ValueError(f"For COOTensor, 'indices' must be a 2-dimensional tensor, but got a {len(indices_shp)}" - f"dimension tensor.") + f"-dimensional tensor.") if len(values_shp) != 1: raise ValueError(f"For COOTensor, 'values' must be a 1-dimensional tensor, but got a {len(values_shp)}" - f"dimension tensor.") + f"-dimensional tensor.") if indices_shp[0] != values_shp[0]: raise ValueError(f"For COOTensor, 'indices.shape[0]' must be euqal to 'values.shape[0]', but got " f"'indices.shape[0]' = {indices_shp[0]} and 'values.shape[0]' = {values_shp[0]}.") diff --git a/mindspore/python/mindspore/common/api.py b/mindspore/python/mindspore/common/api.py index 41d3c0a8eb8..a8adc426b5b 100644 --- a/mindspore/python/mindspore/common/api.py +++ b/mindspore/python/mindspore/common/api.py @@ -32,8 +32,9 @@ from mindspore._extends.remote import kernel_build_server from .tensor import Tensor as MsTensor from .tensor import CSRTensor as MsCSRTensor from .tensor import COOTensor as MsCOOTensor +from .tensor import RowTensor as MsRowTensor from .initializer import initializer -from .._c_expression import GraphExecutor_, Tensor, MetaTensor, CSRTensor, COOTensor, PynativeExecutor_ +from .._c_expression import GraphExecutor_, Tensor, MetaTensor, CSRTensor, RowTensor, COOTensor, PynativeExecutor_ from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline from ..parallel._tensor import _load_tensor_by_layout from ..parallel._ps_context import _is_role_pserver, _is_role_sched @@ -65,6 +66,8 @@ def _convert_python_data(data): return MsCSRTensor(csr_tensor=data) if isinstance(data, COOTensor) and not isinstance(data, MsCOOTensor): return MsCOOTensor(coo_tensor=data) + if isinstance(data, RowTensor) and not isinstance(data, MsRowTensor): + return MsRowTensor(row_tensor=data) if isinstance(data, tuple): return tuple(_convert_python_data(x) for x in data) if isinstance(data, list): diff --git a/mindspore/python/mindspore/common/tensor.py b/mindspore/python/mindspore/common/tensor.py index 21e16e95362..9b97065b658 100644 --- a/mindspore/python/mindspore/common/tensor.py +++ b/mindspore/python/mindspore/common/tensor.py @@ -23,6 +23,7 @@ from . import dtype as mstype from ._register_for_tensor import tensor_operator_registry from .._c_expression import COOTensor as COOTensor_ from .._c_expression import CSRTensor as CSRTensor_ +from .._c_expression import RowTensor as RowTensor_ from .._c_expression import Tensor as Tensor_ from .._checkparam import Rel from .._checkparam import Validator as validator @@ -2494,7 +2495,7 @@ class Tensor(Tensor_): return tensor_operator_registry.get('concatenate')(axis)(repeated_subs) -class RowTensor: +class RowTensor(RowTensor_): """ A sparse representation of a set of tensor slices at given indices. @@ -2516,10 +2517,8 @@ class RowTensor: [0, 0], [0, 0]] - RowTensor can only be used in the `Cell`'s construct method. - Note: - RowTensor is not supported in pynative mode. + This is an experimental feature and is subjected to change. Args: indices (Tensor): A 1-D integer Tensor of shape [D0]. @@ -2534,42 +2533,54 @@ class RowTensor: >>> import mindspore as ms >>> import mindspore.nn as nn >>> from mindspore import Tensor, RowTensor - >>> class Net(nn.Cell): - ... def __init__(self, dense_shape): - ... super(Net, self).__init__() - ... self.dense_shape = dense_shape - ... def construct(self, indices, values): - ... x = RowTensor(indices, values, self.dense_shape) - ... return x.values, x.indices, x.dense_shape - >>> >>> indices = Tensor([0]) >>> values = Tensor([[1, 2]], dtype=ms.float32) >>> out = Net((3, 2))(indices, values) - >>> print(out[0]) + >>> shape = (3, 2) + >>> x = RowTensor(indices, values, shape) + >>> print(x.values) [[1. 2.]] - >>> print(out[1]) + >>> print(x.indices) [0] - >>> print(out[2]) + >>> print(x.shape) (3, 2) """ - def __init__(self, indices, values, dense_shape): - "Init RowTensor" - self.__indices = indices - self.__values = values - self.__dense_shape = dense_shape + def __init__(self, indices=None, values=None, shape=None, row_tensor=None): + """Init RowTensor""" + self.init_finished = False + # Directly init a RowTensor from another RowTensor + if row_tensor is not None: + if not isinstance(row_tensor, (RowTensor, RowTensor)): + raise TypeError(f"Expect input `row_tensor` to be a RowTensor, but got {type(row_tensor)}") + if not (indices is None and values is None and shape is None): + raise TypeError("If input `row_tensor` is provided, `indices`, `values`, `shapes` should all be `None`") + RowTensor_.__init__(self, row_tensor) + # Init a RowTensor from indices, values and shape + else: + RowTensor_.__init__(self, indices, values, shape) + self.init_finished = True + + def __repr__(self): + """Avoid PyTest Segfault when RowTensor is not initialized.""" + if self.init_finished: + return RowTensor_.__repr__(self) + return '' @property def indices(self): - return self.__indices + """Return RowTensor's indices.""" + return Tensor(self._indices) @property def values(self): - return self.__values + """Return RowTensor's non-zero values.""" + return Tensor(self._values) @property def dense_shape(self): - return self.__dense_shape + """Return RowTensor's shape.""" + return self._shape class SparseTensor(COOTensor_): diff --git a/tests/st/sparse/test_row.py b/tests/st/sparse/test_row.py new file mode 100644 index 00000000000..3d93f1e72c0 --- /dev/null +++ b/tests/st/sparse/test_row.py @@ -0,0 +1,108 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""smoke tests for RowTensor operations""" + +import pytest +import numpy as np + +from mindspore import Tensor, RowTensor, nn, context +from mindspore.common import dtype as mstype + + +def compare_row(row1, row2): + assert isinstance(row1, RowTensor) + assert isinstance(row2, RowTensor) + assert (row1.indices.asnumpy() == row1.indices.asnumpy()).all() + assert (row2.values.asnumpy() == row2.values.asnumpy()).all() + assert row1.dense_shape == row2.dense_shape + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_make_row(): + """ + Feature: Test RowTensor Constructor in Graph and PyNative. + Description: Test RowTensor(indices, values, shape) and RowTensor(COOTensor) + Expectation: Success. + """ + indices = Tensor([0, 1], dtype=mstype.int32) + values = Tensor([[1, 2], [3, 4]], dtype=mstype.float32) + dense_shape = (3, 2) + + def test_pynative(): + return RowTensor(indices, values, dense_shape) + + row1 = test_pynative() + compare_row(row1, row1) + row2 = RowTensor(row_tensor=row1) + compare_row(row1, row2) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_coo_tensor_with_control_if(): + """ + Feature: Test COOTensor in if. + Description: Test COOTensor computation in while loop. + Expectation: Success. + """ + class RowTensorValuesDouble(nn.Cell): + + def construct(self, x): + indices = x.indices + values = x.values * 2 + shape = x.dense_shape + return RowTensor(indices, values, shape) + + class RowTensorValuesAdd2(nn.Cell): + + def construct(self, x): + indices = x.indices + values = x.values + 2 + shape = x.dense_shape + return RowTensor(indices, values, shape) + + class RowTensorWithControlIf(nn.Cell): + def __init__(self, shape): + super(RowTensorWithControlIf, self).__init__() + self.op1 = RowTensorValuesDouble() + self.op2 = RowTensorValuesAdd2() + self.shape = shape + + def construct(self, a, b, indices, values): + x = RowTensor(indices, values, self.shape) + if a > b: + x = self.op1(x) + else: + x = self.op2(x) + return x.indices, x.values, x.dense_shape + context.set_context(mode=context.PYNATIVE_MODE) + a = Tensor(0, mstype.int32) + b = Tensor(2, mstype.int32) + indices = Tensor([0, 1], dtype=mstype.int32) + values = Tensor([[1, 2], [3, 4]], dtype=mstype.float32) + shape = (3, 2) + net = RowTensorWithControlIf(shape) + out = net(a, b, indices, values) + assert np.allclose(out[0].asnumpy(), indices.asnumpy(), .0, .0) + assert np.allclose(out[1].asnumpy(), values.asnumpy() + 2, .0, .0) + assert out[2] == shape