support rowtensor in pynative

This commit is contained in:
yanglf1121 2022-04-21 09:50:59 +08:00
parent c29d6bb764
commit 61e1027527
16 changed files with 285 additions and 77 deletions

View File

@ -39,8 +39,9 @@ BuiltInTypeMap &GetMethodMap() {
{"__bool__", std::string("none_bool")} // C.none_bool
}},
{kObjectTypeFunction,
{{"__bool__", std::string("func_bool")}, // C.str_bool
{"__is_csr_func__", prim::kPrimIsCSRFunc}}},
{
{"__bool__", std::string("func_bool")} // C.str_bool
}},
{kNumberTypeBool,
{
{"__and__", prim::kPrimBoolAnd}, // P.bool_and

View File

@ -653,8 +653,9 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
}));
}));
py::tuple CSRTensorPy::GetPyTupleShape(const CSRTensor &csr_tensor) {
auto &shape = csr_tensor.shape();
template <typename T>
py::tuple GetSparseTensorShape(const T &sparse_tensor) {
auto &shape = sparse_tensor.shape();
py::tuple dims(shape.size());
for (size_t i = 0; i < dims.size(); ++i) {
dims[i] = py::int_(shape[i]);
@ -662,6 +663,8 @@ py::tuple CSRTensorPy::GetPyTupleShape(const CSRTensor &csr_tensor) {
return dims;
}
py::tuple CSRTensorPy::GetPyTupleShape(const CSRTensor &csr_tensor) { return GetSparseTensorShape(csr_tensor); }
REGISTER_PYBIND_DEFINE(
CSRTensor, ([](const py::module *m) {
// Define python CSRTensor class.
@ -682,14 +685,7 @@ REGISTER_PYBIND_DEFINE(
.def("__repr__", &CSRTensor::ToString);
}));
py::tuple COOTensorPy::GetPyTupleShape(const COOTensor &coo_tensor) {
auto &shape = coo_tensor.shape();
py::tuple dims(shape.size());
for (size_t i = 0; i < dims.size(); ++i) {
dims[i] = py::int_(shape[i]);
}
return dims;
}
py::tuple COOTensorPy::GetPyTupleShape(const COOTensor &coo_tensor) { return GetSparseTensorShape(coo_tensor); }
REGISTER_PYBIND_DEFINE(COOTensor, ([](const py::module *m) {
// Define python COOTensor class.
@ -710,5 +706,27 @@ REGISTER_PYBIND_DEFINE(COOTensor, ([](const py::module *m) {
.def("__str__", &COOTensor::ToString)
.def("__repr__", &COOTensor::ToString);
}));
py::tuple RowTensorPy::GetPyTupleShape(const RowTensor &row_tensor) { return GetSparseTensorShape(row_tensor); }
REGISTER_PYBIND_DEFINE(RowTensor, ([](const py::module *m) {
// Define python RowTensor class.
(void)py::class_<RowTensor, std::shared_ptr<RowTensor>>(*m, "RowTensor")
.def(py::init([](const Tensor &indices, const Tensor &values, const py::tuple &shape) {
return std::make_shared<RowTensor>(std::make_shared<Tensor>(indices),
std::make_shared<Tensor>(values),
GetShapeFromTuple(shape));
}),
py::arg("indices"), py::arg("values"), py::arg("shape"))
.def(py::init(
[](const RowTensor &row_tensor) { return std::make_shared<RowTensor>(row_tensor); }),
py::arg("input"))
.def_property_readonly("_shape", RowTensorPy::GetPyTupleShape)
.def_property_readonly("_dtype", &RowTensor::Dtype)
.def_property_readonly("_indices", &RowTensor::GetIndices)
.def_property_readonly("_values", &RowTensor::GetValues)
.def("__str__", &RowTensor::ToString)
.def("__repr__", &RowTensor::ToString);
}));
} // namespace tensor
} // namespace mindspore

View File

@ -130,6 +130,12 @@ class COOTensorPy {
public:
static py::tuple GetPyTupleShape(const COOTensor &coo_tensor);
};
// RowTensor python wrapper and adapter class.
class RowTensorPy {
public:
static py::tuple GetPyTupleShape(const RowTensor &row_tensor);
};
} // namespace tensor
} // namespace mindspore

View File

@ -169,8 +169,6 @@ AbstractBasePtr InferImplCSR2COO(const AnalysisEnginePtr &, const PrimitivePtr &
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCOO2CSR(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplIsCSRFunc(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -477,8 +477,8 @@ AbstractBasePtr InferImplCSRMV(const AnalysisEnginePtr &, const PrimitivePtr &pr
auto sparse_shape = sparse->shape()->shape();
auto dense_shape = dense->shape()->shape();
if (sparse_shape.size() != kCSRMVShapeSize || dense_shape.size() != kCSRMVShapeSize) {
MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRMVShapeSize << "-D inputs!"
<< "but sparse tensor has " << sparse_shape.size() << " dimensions, "
MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRMVShapeSize << "-D inputs! "
<< "But csr tensor has " << sparse_shape.size() << " dimensions, "
<< "and dense tensor has " << dense_shape.size() << " dimensions, ";
}
if (dense_shape[kIndexZero] != sparse_shape[kIndexOne] || dense_shape[kIndexOne] != 1) {
@ -541,7 +541,7 @@ AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const Primitive
primitive->set_attr(kCSRAxis, MakeValue(axis_value));
} else {
MS_EXCEPTION(TypeError) << "For CSRReduceSum, `axis` should be int32 or int64, but got "
<< axis->BuildValue()->ToString();
<< axis->BuildType()->ToString();
}
MS_EXCEPTION_IF_NULL(sparse->values()->element());

View File

@ -247,22 +247,5 @@ AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePt
ValuePtr v = args_spec_list[0]->BuildValue();
return std::make_shared<AbstractScalar>(!v->isa<AnyValue>());
}
AbstractBasePtr InferImplIsCSRFunc(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Statement: x.__is_csr_func__()
// Inputs: x
auto func = CheckArg<AbstractFunction>(primitive->name(), args_spec_list, 0);
MS_EXCEPTION_IF_NULL(func);
auto prim_func = dyn_cast<PrimitiveAbstractClosure>(func);
if (prim_func != nullptr) {
PrimitivePtr prim = prim_func->prim();
std::string name = prim->name();
if (name == "S-Prim-MakeCSRTensor") {
return std::make_shared<AbstractScalar>(1);
}
}
return std::make_shared<AbstractScalar>(0);
}
} // namespace abstract
} // namespace mindspore

View File

@ -154,7 +154,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimInDict, R{InferImplInDict, nullptr, true}},
{prim::kPrimNotInDict, R{InferImplNotInDict, nullptr, true}},
{prim::kPrimIsConsant, R{InferImplIsConstant, nullptr, true}},
{prim::kPrimIsCSRFunc, R{InferImplIsCSRFunc, nullptr, true}},
// Maths
{prim::kPrimMatMul, R{InferImplMatMul, nullptr, true}},
{prim::kPrimBatchMatMul, R{InferImplBatchMatMul, nullptr, true}},

View File

@ -110,10 +110,6 @@ MetaSparseTensor::MetaSparseTensor(TypeId data_type, const ShapeVector &shape) :
MetaSparseTensor::MetaSparseTensor(const MetaSparseTensor &meta_sparse_tensor)
: Value(meta_sparse_tensor), data_type_(meta_sparse_tensor.data_type()), shape_(meta_sparse_tensor.shape()) {}
bool MetaSparseTensor::operator==(const MetaSparseTensor &meta_sparse_tensor) const {
return data_type_ == meta_sparse_tensor.data_type() && shape_ == meta_sparse_tensor.shape();
}
TypePtr MetaSparseTensor::Dtype() const { return TypeIdToType(data_type_); }
} // namespace tensor
} // namespace mindspore

View File

@ -269,7 +269,9 @@ class MS_CORE_API MetaSparseTensor : public Value {
///
/// \param[in] meta_sparse_tensor The MetaSparseTensor object to be compared.
/// \return Return true if having same type and shape, otherwise return false.
virtual bool operator==(const MetaSparseTensor &meta_sparse_tensor) const;
virtual bool operator==(const MetaSparseTensor &meta_sparse_tensor) const {
return data_type_ == meta_sparse_tensor.data_type() && shape_ == meta_sparse_tensor.shape();
}
/// \brief Get the data type of the sparse tensor.
/// All the types are defined in "ir/dtype.h".
@ -297,6 +299,11 @@ class MS_CORE_API MetaSparseTensor : public Value {
/// \param[in] shape The shape of the tensor.
void set_shape(const ShapeVector &shape) { this->shape_ = shape; }
/// \brief Get display information of this Tensor.
///
/// \return The display information of this Tensor.
virtual std::string ToString() const = 0;
protected:
// Data type of the sparsetensor.
TypeId data_type_;

View File

@ -930,8 +930,6 @@ TensorPtrList Tensor::GetFlattenedTensors(const TensorPtrList &tensors) {
CSRTensor::CSRTensor(const TensorPtr indptr, const TensorPtr indices, const TensorPtr values, const ShapeVector &shape)
: MetaSparseTensor(values->data_type(), shape), indptr_(indptr), indices_(indices), values_(values) {}
bool CSRTensor::operator==(const CSRTensor &csr_tensor) const { return (&csr_tensor == this); }
std::string CSRTensor::ToString() const {
std::ostringstream buf;
MS_EXCEPTION_IF_NULL(values_);
@ -1009,5 +1007,34 @@ abstract::AbstractBasePtr COOTensor::ToAbstract() {
return abs_sparse_tensor;
}
std::string RowTensor::ToString() const {
std::ostringstream buf;
MS_EXCEPTION_IF_NULL(indices_);
MS_EXCEPTION_IF_NULL(values_);
auto dtype = values_->Dtype();
buf << "RowTensor(shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString()
<< ", indices=" << indices_->ToString() << ", values=" << values_->ToString() << ")";
return buf.str();
}
abstract::AbstractBasePtr RowTensor::ToAbstract() {
auto dtype = values_->Dtype();
if (!IsSubType(dtype, kNumber) && !IsSubType(dtype, kString) && !IsSubType(dtype, kTensorType)) {
MS_LOG(EXCEPTION) << "Expect tensor type kNumber or kString or kTensor but got: " << dtype->ToString() << ".";
}
auto abs_sparse_tensor = std::make_shared<abstract::AbstractRowTensor>(dtype, shape_);
abs_sparse_tensor->set_indices(indices_->ToAbstract()->cast<abstract::AbstractTensorPtr>());
abs_sparse_tensor->set_values(values_->ToAbstract()->cast<abstract::AbstractTensorPtr>());
std::vector<abstract::AbstractBasePtr> abstract_shape;
(void)std::transform(
shape_.begin(), shape_.end(), std::back_inserter(abstract_shape),
[](auto shp) -> abstract::AbstractScalarPtr { return std::make_shared<abstract::AbstractScalar>(shp); });
abs_sparse_tensor->set_dense_shape(std::make_shared<abstract::AbstractTuple>(abstract_shape));
return abs_sparse_tensor;
}
} // namespace tensor
} // namespace mindspore

View File

@ -652,11 +652,11 @@ class MS_CORE_API CSRTensor : public MetaSparseTensor {
/// \return [TensorPtr] The values.
TensorPtr GetValues() { return values_; }
/// \brief Compare two tensor objects to see if they have same data type, shape and data address.
/// \brief Compare two csrtensor objects to see if they have same data address.
///
/// \param[in] tensor The Tensor object to be compared.
/// \return True if having same type, shape and data address, otherwise false.
bool operator==(const CSRTensor &csr_tensor) const;
/// \param[in] csr_tensor The csrtensor object to be compared.
/// \return True if having same data address, otherwise false.
bool operator==(const CSRTensor &csr_tensor) const { return &csr_tensor == this; }
bool operator==(const Value &other) const override {
if (other.isa<CSRTensor>()) {
@ -710,11 +710,11 @@ class MS_CORE_API COOTensor : public MetaSparseTensor {
/// \return [TensorPtr] The values.
TensorPtr GetValues() { return values_; }
/// \brief Compare two tensor objects to see if they have same data type, shape and data address.
/// \brief Compare two cootensor objects to see if they have same address.
///
/// \param[in] tensor The Tensor object to be compared.
/// \return True if having same type, shape and data address, otherwise false.
bool operator==(const COOTensor &sparse_tensor) const { return &sparse_tensor == this; }
/// \param[in] coo_tensor The cootensor object to be compared.
/// \return True if having same data address, otherwise false.
bool operator==(const COOTensor &coo_tensor) const { return &coo_tensor == this; }
bool operator==(const Value &other) const override {
if (other.isa<COOTensor>()) {
@ -734,6 +734,57 @@ class MS_CORE_API COOTensor : public MetaSparseTensor {
TensorPtr values_;
};
using COOTensorPtr = std::shared_ptr<COOTensor>;
// RowTensor entity class
class MS_CORE_API RowTensor : public MetaSparseTensor {
public:
abstract::AbstractBasePtr ToAbstract() override;
/// \brief Create RowTensor with given data type from another tensor.
///
/// \param[in] indices [Tensor] The indices.
/// \param[in] values [Tensor] The values.
/// \param[in] shape The shape represented by ShapeVector of the RowTensor.
RowTensor(const TensorPtr indices, const TensorPtr values, const ShapeVector &shape)
: MetaSparseTensor(values->data_type(), shape), indices_(indices), values_(values) {}
/// Destructor of RowTensor.
~RowTensor() override = default;
/// \brief Gets RowTensor's indices.
///
/// \return [TensorPtr] The indices.
TensorPtr GetIndices() { return indices_; }
/// \brief Gets RowTensor's values.
///
/// \return [TensorPtr] The values.
TensorPtr GetValues() { return values_; }
/// \brief Compare two rowtensor objects to see if they have same address.
///
/// \param[in] coo_tensor The rowtensor object to be compared.
/// \return True if having same data address, otherwise false.
bool operator==(const RowTensor &row_tensor) const { return &row_tensor == this; }
bool operator==(const Value &other) const override {
if (other.isa<RowTensor>()) {
auto &other_ = static_cast<const RowTensor &>(other);
return *this == other_;
}
return false;
}
/// \brief Get display information of this Tensor.
///
/// \return The display information of this Tensor.
std::string ToString() const override;
private:
TensorPtr indices_;
TensorPtr values_;
};
using RowTensorPtr = std::shared_ptr<RowTensor>;
} // namespace tensor
} // namespace mindspore

View File

@ -652,6 +652,7 @@ GVAR_DEF(PrimitivePtr, kPrimFakeQuantWithMinMaxVarsPerChannel,
std::make_shared<Primitive>("FakeQuantWithMinMaxVarsPerChannel"));
// Control ops
GVAR_DEF(PrimitivePtr, kPrimMerge, std::make_shared<Primitive>("Merge"));
// RowTensor
GVAR_DEF(PrimitivePtr, kPrimMakeRowTensor, std::make_shared<Primitive>("MakeRowTensor"));
GVAR_DEF(PrimitivePtr, kPrimRowTensorGetValues, std::make_shared<Primitive>("RowTensorGetValues"));
@ -671,7 +672,6 @@ GVAR_DEF(PrimitivePtr, kPrimCSRTensorGetValues, std::make_shared<Primitive>(kCSR
GVAR_DEF(PrimitivePtr, kPrimCSRTensorGetIndptr, std::make_shared<Primitive>(kCSRTensorGetIndptr));
GVAR_DEF(PrimitivePtr, kPrimCSRTensorGetIndices, std::make_shared<Primitive>(kCSRTensorGetIndices));
GVAR_DEF(PrimitivePtr, kPrimCSRTensorGetDenseShape, std::make_shared<Primitive>(kCSRTensorGetDenseShape));
GVAR_DEF(PrimitivePtr, kPrimIsCSRFunc, std::make_shared<Primitive>(kIsCSRFunc));
// Sparse ops
GVAR_DEF(PrimitivePtr, kPrimSparseTensorDenseMatmul, std::make_shared<Primitive>(kSparseTensorDenseMatmul));

View File

@ -952,10 +952,10 @@ class Validator:
raise TypeError(f"For COOTensor, the element type of 'shape' must be int, but got {type(sh)}")
if len(indices_shp) != 2:
raise ValueError(f"For COOTensor, 'indices' must be a 2-dimensional tensor, but got a {len(indices_shp)}"
f"dimension tensor.")
f"-dimensional tensor.")
if len(values_shp) != 1:
raise ValueError(f"For COOTensor, 'values' must be a 1-dimensional tensor, but got a {len(values_shp)}"
f"dimension tensor.")
f"-dimensional tensor.")
if indices_shp[0] != values_shp[0]:
raise ValueError(f"For COOTensor, 'indices.shape[0]' must be euqal to 'values.shape[0]', but got "
f"'indices.shape[0]' = {indices_shp[0]} and 'values.shape[0]' = {values_shp[0]}.")

View File

@ -32,8 +32,9 @@ from mindspore._extends.remote import kernel_build_server
from .tensor import Tensor as MsTensor
from .tensor import CSRTensor as MsCSRTensor
from .tensor import COOTensor as MsCOOTensor
from .tensor import RowTensor as MsRowTensor
from .initializer import initializer
from .._c_expression import GraphExecutor_, Tensor, MetaTensor, CSRTensor, COOTensor, PynativeExecutor_
from .._c_expression import GraphExecutor_, Tensor, MetaTensor, CSRTensor, RowTensor, COOTensor, PynativeExecutor_
from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline
from ..parallel._tensor import _load_tensor_by_layout
from ..parallel._ps_context import _is_role_pserver, _is_role_sched
@ -65,6 +66,8 @@ def _convert_python_data(data):
return MsCSRTensor(csr_tensor=data)
if isinstance(data, COOTensor) and not isinstance(data, MsCOOTensor):
return MsCOOTensor(coo_tensor=data)
if isinstance(data, RowTensor) and not isinstance(data, MsRowTensor):
return MsRowTensor(row_tensor=data)
if isinstance(data, tuple):
return tuple(_convert_python_data(x) for x in data)
if isinstance(data, list):

View File

@ -23,6 +23,7 @@ from . import dtype as mstype
from ._register_for_tensor import tensor_operator_registry
from .._c_expression import COOTensor as COOTensor_
from .._c_expression import CSRTensor as CSRTensor_
from .._c_expression import RowTensor as RowTensor_
from .._c_expression import Tensor as Tensor_
from .._checkparam import Rel
from .._checkparam import Validator as validator
@ -2494,7 +2495,7 @@ class Tensor(Tensor_):
return tensor_operator_registry.get('concatenate')(axis)(repeated_subs)
class RowTensor:
class RowTensor(RowTensor_):
"""
A sparse representation of a set of tensor slices at given indices.
@ -2516,10 +2517,8 @@ class RowTensor:
[0, 0],
[0, 0]]
RowTensor can only be used in the `Cell`'s construct method.
Note:
RowTensor is not supported in pynative mode.
This is an experimental feature and is subjected to change.
Args:
indices (Tensor): A 1-D integer Tensor of shape [D0].
@ -2534,42 +2533,54 @@ class RowTensor:
>>> import mindspore as ms
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, RowTensor
>>> class Net(nn.Cell):
... def __init__(self, dense_shape):
... super(Net, self).__init__()
... self.dense_shape = dense_shape
... def construct(self, indices, values):
... x = RowTensor(indices, values, self.dense_shape)
... return x.values, x.indices, x.dense_shape
>>>
>>> indices = Tensor([0])
>>> values = Tensor([[1, 2]], dtype=ms.float32)
>>> out = Net((3, 2))(indices, values)
>>> print(out[0])
>>> shape = (3, 2)
>>> x = RowTensor(indices, values, shape)
>>> print(x.values)
[[1. 2.]]
>>> print(out[1])
>>> print(x.indices)
[0]
>>> print(out[2])
>>> print(x.shape)
(3, 2)
"""
def __init__(self, indices, values, dense_shape):
"Init RowTensor"
self.__indices = indices
self.__values = values
self.__dense_shape = dense_shape
def __init__(self, indices=None, values=None, shape=None, row_tensor=None):
"""Init RowTensor"""
self.init_finished = False
# Directly init a RowTensor from another RowTensor
if row_tensor is not None:
if not isinstance(row_tensor, (RowTensor, RowTensor)):
raise TypeError(f"Expect input `row_tensor` to be a RowTensor, but got {type(row_tensor)}")
if not (indices is None and values is None and shape is None):
raise TypeError("If input `row_tensor` is provided, `indices`, `values`, `shapes` should all be `None`")
RowTensor_.__init__(self, row_tensor)
# Init a RowTensor from indices, values and shape
else:
RowTensor_.__init__(self, indices, values, shape)
self.init_finished = True
def __repr__(self):
"""Avoid PyTest Segfault when RowTensor is not initialized."""
if self.init_finished:
return RowTensor_.__repr__(self)
return ''
@property
def indices(self):
return self.__indices
"""Return RowTensor's indices."""
return Tensor(self._indices)
@property
def values(self):
return self.__values
"""Return RowTensor's non-zero values."""
return Tensor(self._values)
@property
def dense_shape(self):
return self.__dense_shape
"""Return RowTensor's shape."""
return self._shape
class SparseTensor(COOTensor_):

108
tests/st/sparse/test_row.py Normal file
View File

@ -0,0 +1,108 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""smoke tests for RowTensor operations"""
import pytest
import numpy as np
from mindspore import Tensor, RowTensor, nn, context
from mindspore.common import dtype as mstype
def compare_row(row1, row2):
assert isinstance(row1, RowTensor)
assert isinstance(row2, RowTensor)
assert (row1.indices.asnumpy() == row1.indices.asnumpy()).all()
assert (row2.values.asnumpy() == row2.values.asnumpy()).all()
assert row1.dense_shape == row2.dense_shape
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_make_row():
"""
Feature: Test RowTensor Constructor in Graph and PyNative.
Description: Test RowTensor(indices, values, shape) and RowTensor(COOTensor)
Expectation: Success.
"""
indices = Tensor([0, 1], dtype=mstype.int32)
values = Tensor([[1, 2], [3, 4]], dtype=mstype.float32)
dense_shape = (3, 2)
def test_pynative():
return RowTensor(indices, values, dense_shape)
row1 = test_pynative()
compare_row(row1, row1)
row2 = RowTensor(row_tensor=row1)
compare_row(row1, row2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_coo_tensor_with_control_if():
"""
Feature: Test COOTensor in if.
Description: Test COOTensor computation in while loop.
Expectation: Success.
"""
class RowTensorValuesDouble(nn.Cell):
def construct(self, x):
indices = x.indices
values = x.values * 2
shape = x.dense_shape
return RowTensor(indices, values, shape)
class RowTensorValuesAdd2(nn.Cell):
def construct(self, x):
indices = x.indices
values = x.values + 2
shape = x.dense_shape
return RowTensor(indices, values, shape)
class RowTensorWithControlIf(nn.Cell):
def __init__(self, shape):
super(RowTensorWithControlIf, self).__init__()
self.op1 = RowTensorValuesDouble()
self.op2 = RowTensorValuesAdd2()
self.shape = shape
def construct(self, a, b, indices, values):
x = RowTensor(indices, values, self.shape)
if a > b:
x = self.op1(x)
else:
x = self.op2(x)
return x.indices, x.values, x.dense_shape
context.set_context(mode=context.PYNATIVE_MODE)
a = Tensor(0, mstype.int32)
b = Tensor(2, mstype.int32)
indices = Tensor([0, 1], dtype=mstype.int32)
values = Tensor([[1, 2], [3, 4]], dtype=mstype.float32)
shape = (3, 2)
net = RowTensorWithControlIf(shape)
out = net(a, b, indices, values)
assert np.allclose(out[0].asnumpy(), indices.asnumpy(), .0, .0)
assert np.allclose(out[1].asnumpy(), values.asnumpy() + 2, .0, .0)
assert out[2] == shape