forked from mindspore-Ecosystem/mindspore
support rowtensor in pynative
This commit is contained in:
parent
c29d6bb764
commit
61e1027527
|
@ -39,8 +39,9 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"__bool__", std::string("none_bool")} // C.none_bool
|
||||
}},
|
||||
{kObjectTypeFunction,
|
||||
{{"__bool__", std::string("func_bool")}, // C.str_bool
|
||||
{"__is_csr_func__", prim::kPrimIsCSRFunc}}},
|
||||
{
|
||||
{"__bool__", std::string("func_bool")} // C.str_bool
|
||||
}},
|
||||
{kNumberTypeBool,
|
||||
{
|
||||
{"__and__", prim::kPrimBoolAnd}, // P.bool_and
|
||||
|
|
|
@ -653,8 +653,9 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
|
|||
}));
|
||||
}));
|
||||
|
||||
py::tuple CSRTensorPy::GetPyTupleShape(const CSRTensor &csr_tensor) {
|
||||
auto &shape = csr_tensor.shape();
|
||||
template <typename T>
|
||||
py::tuple GetSparseTensorShape(const T &sparse_tensor) {
|
||||
auto &shape = sparse_tensor.shape();
|
||||
py::tuple dims(shape.size());
|
||||
for (size_t i = 0; i < dims.size(); ++i) {
|
||||
dims[i] = py::int_(shape[i]);
|
||||
|
@ -662,6 +663,8 @@ py::tuple CSRTensorPy::GetPyTupleShape(const CSRTensor &csr_tensor) {
|
|||
return dims;
|
||||
}
|
||||
|
||||
py::tuple CSRTensorPy::GetPyTupleShape(const CSRTensor &csr_tensor) { return GetSparseTensorShape(csr_tensor); }
|
||||
|
||||
REGISTER_PYBIND_DEFINE(
|
||||
CSRTensor, ([](const py::module *m) {
|
||||
// Define python CSRTensor class.
|
||||
|
@ -682,14 +685,7 @@ REGISTER_PYBIND_DEFINE(
|
|||
.def("__repr__", &CSRTensor::ToString);
|
||||
}));
|
||||
|
||||
py::tuple COOTensorPy::GetPyTupleShape(const COOTensor &coo_tensor) {
|
||||
auto &shape = coo_tensor.shape();
|
||||
py::tuple dims(shape.size());
|
||||
for (size_t i = 0; i < dims.size(); ++i) {
|
||||
dims[i] = py::int_(shape[i]);
|
||||
}
|
||||
return dims;
|
||||
}
|
||||
py::tuple COOTensorPy::GetPyTupleShape(const COOTensor &coo_tensor) { return GetSparseTensorShape(coo_tensor); }
|
||||
|
||||
REGISTER_PYBIND_DEFINE(COOTensor, ([](const py::module *m) {
|
||||
// Define python COOTensor class.
|
||||
|
@ -710,5 +706,27 @@ REGISTER_PYBIND_DEFINE(COOTensor, ([](const py::module *m) {
|
|||
.def("__str__", &COOTensor::ToString)
|
||||
.def("__repr__", &COOTensor::ToString);
|
||||
}));
|
||||
|
||||
py::tuple RowTensorPy::GetPyTupleShape(const RowTensor &row_tensor) { return GetSparseTensorShape(row_tensor); }
|
||||
|
||||
REGISTER_PYBIND_DEFINE(RowTensor, ([](const py::module *m) {
|
||||
// Define python RowTensor class.
|
||||
(void)py::class_<RowTensor, std::shared_ptr<RowTensor>>(*m, "RowTensor")
|
||||
.def(py::init([](const Tensor &indices, const Tensor &values, const py::tuple &shape) {
|
||||
return std::make_shared<RowTensor>(std::make_shared<Tensor>(indices),
|
||||
std::make_shared<Tensor>(values),
|
||||
GetShapeFromTuple(shape));
|
||||
}),
|
||||
py::arg("indices"), py::arg("values"), py::arg("shape"))
|
||||
.def(py::init(
|
||||
[](const RowTensor &row_tensor) { return std::make_shared<RowTensor>(row_tensor); }),
|
||||
py::arg("input"))
|
||||
.def_property_readonly("_shape", RowTensorPy::GetPyTupleShape)
|
||||
.def_property_readonly("_dtype", &RowTensor::Dtype)
|
||||
.def_property_readonly("_indices", &RowTensor::GetIndices)
|
||||
.def_property_readonly("_values", &RowTensor::GetValues)
|
||||
.def("__str__", &RowTensor::ToString)
|
||||
.def("__repr__", &RowTensor::ToString);
|
||||
}));
|
||||
} // namespace tensor
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -130,6 +130,12 @@ class COOTensorPy {
|
|||
public:
|
||||
static py::tuple GetPyTupleShape(const COOTensor &coo_tensor);
|
||||
};
|
||||
|
||||
// RowTensor python wrapper and adapter class.
|
||||
class RowTensorPy {
|
||||
public:
|
||||
static py::tuple GetPyTupleShape(const RowTensor &row_tensor);
|
||||
};
|
||||
} // namespace tensor
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -169,8 +169,6 @@ AbstractBasePtr InferImplCSR2COO(const AnalysisEnginePtr &, const PrimitivePtr &
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplCOO2CSR(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplIsCSRFunc(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -477,8 +477,8 @@ AbstractBasePtr InferImplCSRMV(const AnalysisEnginePtr &, const PrimitivePtr &pr
|
|||
auto sparse_shape = sparse->shape()->shape();
|
||||
auto dense_shape = dense->shape()->shape();
|
||||
if (sparse_shape.size() != kCSRMVShapeSize || dense_shape.size() != kCSRMVShapeSize) {
|
||||
MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRMVShapeSize << "-D inputs!"
|
||||
<< "but sparse tensor has " << sparse_shape.size() << " dimensions, "
|
||||
MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRMVShapeSize << "-D inputs! "
|
||||
<< "But csr tensor has " << sparse_shape.size() << " dimensions, "
|
||||
<< "and dense tensor has " << dense_shape.size() << " dimensions, ";
|
||||
}
|
||||
if (dense_shape[kIndexZero] != sparse_shape[kIndexOne] || dense_shape[kIndexOne] != 1) {
|
||||
|
@ -541,7 +541,7 @@ AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const Primitive
|
|||
primitive->set_attr(kCSRAxis, MakeValue(axis_value));
|
||||
} else {
|
||||
MS_EXCEPTION(TypeError) << "For CSRReduceSum, `axis` should be int32 or int64, but got "
|
||||
<< axis->BuildValue()->ToString();
|
||||
<< axis->BuildType()->ToString();
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(sparse->values()->element());
|
||||
|
|
|
@ -247,22 +247,5 @@ AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePt
|
|||
ValuePtr v = args_spec_list[0]->BuildValue();
|
||||
return std::make_shared<AbstractScalar>(!v->isa<AnyValue>());
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplIsCSRFunc(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Statement: x.__is_csr_func__()
|
||||
// Inputs: x
|
||||
auto func = CheckArg<AbstractFunction>(primitive->name(), args_spec_list, 0);
|
||||
MS_EXCEPTION_IF_NULL(func);
|
||||
auto prim_func = dyn_cast<PrimitiveAbstractClosure>(func);
|
||||
if (prim_func != nullptr) {
|
||||
PrimitivePtr prim = prim_func->prim();
|
||||
std::string name = prim->name();
|
||||
if (name == "S-Prim-MakeCSRTensor") {
|
||||
return std::make_shared<AbstractScalar>(1);
|
||||
}
|
||||
}
|
||||
return std::make_shared<AbstractScalar>(0);
|
||||
}
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -154,7 +154,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimInDict, R{InferImplInDict, nullptr, true}},
|
||||
{prim::kPrimNotInDict, R{InferImplNotInDict, nullptr, true}},
|
||||
{prim::kPrimIsConsant, R{InferImplIsConstant, nullptr, true}},
|
||||
{prim::kPrimIsCSRFunc, R{InferImplIsCSRFunc, nullptr, true}},
|
||||
// Maths
|
||||
{prim::kPrimMatMul, R{InferImplMatMul, nullptr, true}},
|
||||
{prim::kPrimBatchMatMul, R{InferImplBatchMatMul, nullptr, true}},
|
||||
|
|
|
@ -110,10 +110,6 @@ MetaSparseTensor::MetaSparseTensor(TypeId data_type, const ShapeVector &shape) :
|
|||
MetaSparseTensor::MetaSparseTensor(const MetaSparseTensor &meta_sparse_tensor)
|
||||
: Value(meta_sparse_tensor), data_type_(meta_sparse_tensor.data_type()), shape_(meta_sparse_tensor.shape()) {}
|
||||
|
||||
bool MetaSparseTensor::operator==(const MetaSparseTensor &meta_sparse_tensor) const {
|
||||
return data_type_ == meta_sparse_tensor.data_type() && shape_ == meta_sparse_tensor.shape();
|
||||
}
|
||||
|
||||
TypePtr MetaSparseTensor::Dtype() const { return TypeIdToType(data_type_); }
|
||||
} // namespace tensor
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -269,7 +269,9 @@ class MS_CORE_API MetaSparseTensor : public Value {
|
|||
///
|
||||
/// \param[in] meta_sparse_tensor The MetaSparseTensor object to be compared.
|
||||
/// \return Return true if having same type and shape, otherwise return false.
|
||||
virtual bool operator==(const MetaSparseTensor &meta_sparse_tensor) const;
|
||||
virtual bool operator==(const MetaSparseTensor &meta_sparse_tensor) const {
|
||||
return data_type_ == meta_sparse_tensor.data_type() && shape_ == meta_sparse_tensor.shape();
|
||||
}
|
||||
|
||||
/// \brief Get the data type of the sparse tensor.
|
||||
/// All the types are defined in "ir/dtype.h".
|
||||
|
@ -297,6 +299,11 @@ class MS_CORE_API MetaSparseTensor : public Value {
|
|||
/// \param[in] shape The shape of the tensor.
|
||||
void set_shape(const ShapeVector &shape) { this->shape_ = shape; }
|
||||
|
||||
/// \brief Get display information of this Tensor.
|
||||
///
|
||||
/// \return The display information of this Tensor.
|
||||
virtual std::string ToString() const = 0;
|
||||
|
||||
protected:
|
||||
// Data type of the sparsetensor.
|
||||
TypeId data_type_;
|
||||
|
|
|
@ -930,8 +930,6 @@ TensorPtrList Tensor::GetFlattenedTensors(const TensorPtrList &tensors) {
|
|||
CSRTensor::CSRTensor(const TensorPtr indptr, const TensorPtr indices, const TensorPtr values, const ShapeVector &shape)
|
||||
: MetaSparseTensor(values->data_type(), shape), indptr_(indptr), indices_(indices), values_(values) {}
|
||||
|
||||
bool CSRTensor::operator==(const CSRTensor &csr_tensor) const { return (&csr_tensor == this); }
|
||||
|
||||
std::string CSRTensor::ToString() const {
|
||||
std::ostringstream buf;
|
||||
MS_EXCEPTION_IF_NULL(values_);
|
||||
|
@ -1009,5 +1007,34 @@ abstract::AbstractBasePtr COOTensor::ToAbstract() {
|
|||
|
||||
return abs_sparse_tensor;
|
||||
}
|
||||
|
||||
std::string RowTensor::ToString() const {
|
||||
std::ostringstream buf;
|
||||
MS_EXCEPTION_IF_NULL(indices_);
|
||||
MS_EXCEPTION_IF_NULL(values_);
|
||||
auto dtype = values_->Dtype();
|
||||
buf << "RowTensor(shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString()
|
||||
<< ", indices=" << indices_->ToString() << ", values=" << values_->ToString() << ")";
|
||||
return buf.str();
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtr RowTensor::ToAbstract() {
|
||||
auto dtype = values_->Dtype();
|
||||
if (!IsSubType(dtype, kNumber) && !IsSubType(dtype, kString) && !IsSubType(dtype, kTensorType)) {
|
||||
MS_LOG(EXCEPTION) << "Expect tensor type kNumber or kString or kTensor but got: " << dtype->ToString() << ".";
|
||||
}
|
||||
auto abs_sparse_tensor = std::make_shared<abstract::AbstractRowTensor>(dtype, shape_);
|
||||
|
||||
abs_sparse_tensor->set_indices(indices_->ToAbstract()->cast<abstract::AbstractTensorPtr>());
|
||||
abs_sparse_tensor->set_values(values_->ToAbstract()->cast<abstract::AbstractTensorPtr>());
|
||||
|
||||
std::vector<abstract::AbstractBasePtr> abstract_shape;
|
||||
(void)std::transform(
|
||||
shape_.begin(), shape_.end(), std::back_inserter(abstract_shape),
|
||||
[](auto shp) -> abstract::AbstractScalarPtr { return std::make_shared<abstract::AbstractScalar>(shp); });
|
||||
abs_sparse_tensor->set_dense_shape(std::make_shared<abstract::AbstractTuple>(abstract_shape));
|
||||
|
||||
return abs_sparse_tensor;
|
||||
}
|
||||
} // namespace tensor
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -652,11 +652,11 @@ class MS_CORE_API CSRTensor : public MetaSparseTensor {
|
|||
/// \return [TensorPtr] The values.
|
||||
TensorPtr GetValues() { return values_; }
|
||||
|
||||
/// \brief Compare two tensor objects to see if they have same data type, shape and data address.
|
||||
/// \brief Compare two csrtensor objects to see if they have same data address.
|
||||
///
|
||||
/// \param[in] tensor The Tensor object to be compared.
|
||||
/// \return True if having same type, shape and data address, otherwise false.
|
||||
bool operator==(const CSRTensor &csr_tensor) const;
|
||||
/// \param[in] csr_tensor The csrtensor object to be compared.
|
||||
/// \return True if having same data address, otherwise false.
|
||||
bool operator==(const CSRTensor &csr_tensor) const { return &csr_tensor == this; }
|
||||
|
||||
bool operator==(const Value &other) const override {
|
||||
if (other.isa<CSRTensor>()) {
|
||||
|
@ -710,11 +710,11 @@ class MS_CORE_API COOTensor : public MetaSparseTensor {
|
|||
/// \return [TensorPtr] The values.
|
||||
TensorPtr GetValues() { return values_; }
|
||||
|
||||
/// \brief Compare two tensor objects to see if they have same data type, shape and data address.
|
||||
/// \brief Compare two cootensor objects to see if they have same address.
|
||||
///
|
||||
/// \param[in] tensor The Tensor object to be compared.
|
||||
/// \return True if having same type, shape and data address, otherwise false.
|
||||
bool operator==(const COOTensor &sparse_tensor) const { return &sparse_tensor == this; }
|
||||
/// \param[in] coo_tensor The cootensor object to be compared.
|
||||
/// \return True if having same data address, otherwise false.
|
||||
bool operator==(const COOTensor &coo_tensor) const { return &coo_tensor == this; }
|
||||
|
||||
bool operator==(const Value &other) const override {
|
||||
if (other.isa<COOTensor>()) {
|
||||
|
@ -734,6 +734,57 @@ class MS_CORE_API COOTensor : public MetaSparseTensor {
|
|||
TensorPtr values_;
|
||||
};
|
||||
using COOTensorPtr = std::shared_ptr<COOTensor>;
|
||||
|
||||
// RowTensor entity class
|
||||
class MS_CORE_API RowTensor : public MetaSparseTensor {
|
||||
public:
|
||||
abstract::AbstractBasePtr ToAbstract() override;
|
||||
|
||||
/// \brief Create RowTensor with given data type from another tensor.
|
||||
///
|
||||
/// \param[in] indices [Tensor] The indices.
|
||||
/// \param[in] values [Tensor] The values.
|
||||
/// \param[in] shape The shape represented by ShapeVector of the RowTensor.
|
||||
RowTensor(const TensorPtr indices, const TensorPtr values, const ShapeVector &shape)
|
||||
: MetaSparseTensor(values->data_type(), shape), indices_(indices), values_(values) {}
|
||||
|
||||
/// Destructor of RowTensor.
|
||||
~RowTensor() override = default;
|
||||
|
||||
/// \brief Gets RowTensor's indices.
|
||||
///
|
||||
/// \return [TensorPtr] The indices.
|
||||
TensorPtr GetIndices() { return indices_; }
|
||||
|
||||
/// \brief Gets RowTensor's values.
|
||||
///
|
||||
/// \return [TensorPtr] The values.
|
||||
TensorPtr GetValues() { return values_; }
|
||||
|
||||
/// \brief Compare two rowtensor objects to see if they have same address.
|
||||
///
|
||||
/// \param[in] coo_tensor The rowtensor object to be compared.
|
||||
/// \return True if having same data address, otherwise false.
|
||||
bool operator==(const RowTensor &row_tensor) const { return &row_tensor == this; }
|
||||
|
||||
bool operator==(const Value &other) const override {
|
||||
if (other.isa<RowTensor>()) {
|
||||
auto &other_ = static_cast<const RowTensor &>(other);
|
||||
return *this == other_;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/// \brief Get display information of this Tensor.
|
||||
///
|
||||
/// \return The display information of this Tensor.
|
||||
std::string ToString() const override;
|
||||
|
||||
private:
|
||||
TensorPtr indices_;
|
||||
TensorPtr values_;
|
||||
};
|
||||
using RowTensorPtr = std::shared_ptr<RowTensor>;
|
||||
} // namespace tensor
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -652,6 +652,7 @@ GVAR_DEF(PrimitivePtr, kPrimFakeQuantWithMinMaxVarsPerChannel,
|
|||
std::make_shared<Primitive>("FakeQuantWithMinMaxVarsPerChannel"));
|
||||
// Control ops
|
||||
GVAR_DEF(PrimitivePtr, kPrimMerge, std::make_shared<Primitive>("Merge"));
|
||||
|
||||
// RowTensor
|
||||
GVAR_DEF(PrimitivePtr, kPrimMakeRowTensor, std::make_shared<Primitive>("MakeRowTensor"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimRowTensorGetValues, std::make_shared<Primitive>("RowTensorGetValues"));
|
||||
|
@ -671,7 +672,6 @@ GVAR_DEF(PrimitivePtr, kPrimCSRTensorGetValues, std::make_shared<Primitive>(kCSR
|
|||
GVAR_DEF(PrimitivePtr, kPrimCSRTensorGetIndptr, std::make_shared<Primitive>(kCSRTensorGetIndptr));
|
||||
GVAR_DEF(PrimitivePtr, kPrimCSRTensorGetIndices, std::make_shared<Primitive>(kCSRTensorGetIndices));
|
||||
GVAR_DEF(PrimitivePtr, kPrimCSRTensorGetDenseShape, std::make_shared<Primitive>(kCSRTensorGetDenseShape));
|
||||
GVAR_DEF(PrimitivePtr, kPrimIsCSRFunc, std::make_shared<Primitive>(kIsCSRFunc));
|
||||
|
||||
// Sparse ops
|
||||
GVAR_DEF(PrimitivePtr, kPrimSparseTensorDenseMatmul, std::make_shared<Primitive>(kSparseTensorDenseMatmul));
|
||||
|
|
|
@ -952,10 +952,10 @@ class Validator:
|
|||
raise TypeError(f"For COOTensor, the element type of 'shape' must be int, but got {type(sh)}")
|
||||
if len(indices_shp) != 2:
|
||||
raise ValueError(f"For COOTensor, 'indices' must be a 2-dimensional tensor, but got a {len(indices_shp)}"
|
||||
f"dimension tensor.")
|
||||
f"-dimensional tensor.")
|
||||
if len(values_shp) != 1:
|
||||
raise ValueError(f"For COOTensor, 'values' must be a 1-dimensional tensor, but got a {len(values_shp)}"
|
||||
f"dimension tensor.")
|
||||
f"-dimensional tensor.")
|
||||
if indices_shp[0] != values_shp[0]:
|
||||
raise ValueError(f"For COOTensor, 'indices.shape[0]' must be euqal to 'values.shape[0]', but got "
|
||||
f"'indices.shape[0]' = {indices_shp[0]} and 'values.shape[0]' = {values_shp[0]}.")
|
||||
|
|
|
@ -32,8 +32,9 @@ from mindspore._extends.remote import kernel_build_server
|
|||
from .tensor import Tensor as MsTensor
|
||||
from .tensor import CSRTensor as MsCSRTensor
|
||||
from .tensor import COOTensor as MsCOOTensor
|
||||
from .tensor import RowTensor as MsRowTensor
|
||||
from .initializer import initializer
|
||||
from .._c_expression import GraphExecutor_, Tensor, MetaTensor, CSRTensor, COOTensor, PynativeExecutor_
|
||||
from .._c_expression import GraphExecutor_, Tensor, MetaTensor, CSRTensor, RowTensor, COOTensor, PynativeExecutor_
|
||||
from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline
|
||||
from ..parallel._tensor import _load_tensor_by_layout
|
||||
from ..parallel._ps_context import _is_role_pserver, _is_role_sched
|
||||
|
@ -65,6 +66,8 @@ def _convert_python_data(data):
|
|||
return MsCSRTensor(csr_tensor=data)
|
||||
if isinstance(data, COOTensor) and not isinstance(data, MsCOOTensor):
|
||||
return MsCOOTensor(coo_tensor=data)
|
||||
if isinstance(data, RowTensor) and not isinstance(data, MsRowTensor):
|
||||
return MsRowTensor(row_tensor=data)
|
||||
if isinstance(data, tuple):
|
||||
return tuple(_convert_python_data(x) for x in data)
|
||||
if isinstance(data, list):
|
||||
|
|
|
@ -23,6 +23,7 @@ from . import dtype as mstype
|
|||
from ._register_for_tensor import tensor_operator_registry
|
||||
from .._c_expression import COOTensor as COOTensor_
|
||||
from .._c_expression import CSRTensor as CSRTensor_
|
||||
from .._c_expression import RowTensor as RowTensor_
|
||||
from .._c_expression import Tensor as Tensor_
|
||||
from .._checkparam import Rel
|
||||
from .._checkparam import Validator as validator
|
||||
|
@ -2494,7 +2495,7 @@ class Tensor(Tensor_):
|
|||
return tensor_operator_registry.get('concatenate')(axis)(repeated_subs)
|
||||
|
||||
|
||||
class RowTensor:
|
||||
class RowTensor(RowTensor_):
|
||||
"""
|
||||
A sparse representation of a set of tensor slices at given indices.
|
||||
|
||||
|
@ -2516,10 +2517,8 @@ class RowTensor:
|
|||
[0, 0],
|
||||
[0, 0]]
|
||||
|
||||
RowTensor can only be used in the `Cell`'s construct method.
|
||||
|
||||
Note:
|
||||
RowTensor is not supported in pynative mode.
|
||||
This is an experimental feature and is subjected to change.
|
||||
|
||||
Args:
|
||||
indices (Tensor): A 1-D integer Tensor of shape [D0].
|
||||
|
@ -2534,42 +2533,54 @@ class RowTensor:
|
|||
>>> import mindspore as ms
|
||||
>>> import mindspore.nn as nn
|
||||
>>> from mindspore import Tensor, RowTensor
|
||||
>>> class Net(nn.Cell):
|
||||
... def __init__(self, dense_shape):
|
||||
... super(Net, self).__init__()
|
||||
... self.dense_shape = dense_shape
|
||||
... def construct(self, indices, values):
|
||||
... x = RowTensor(indices, values, self.dense_shape)
|
||||
... return x.values, x.indices, x.dense_shape
|
||||
>>>
|
||||
>>> indices = Tensor([0])
|
||||
>>> values = Tensor([[1, 2]], dtype=ms.float32)
|
||||
>>> out = Net((3, 2))(indices, values)
|
||||
>>> print(out[0])
|
||||
>>> shape = (3, 2)
|
||||
>>> x = RowTensor(indices, values, shape)
|
||||
>>> print(x.values)
|
||||
[[1. 2.]]
|
||||
>>> print(out[1])
|
||||
>>> print(x.indices)
|
||||
[0]
|
||||
>>> print(out[2])
|
||||
>>> print(x.shape)
|
||||
(3, 2)
|
||||
"""
|
||||
|
||||
def __init__(self, indices, values, dense_shape):
|
||||
"Init RowTensor"
|
||||
self.__indices = indices
|
||||
self.__values = values
|
||||
self.__dense_shape = dense_shape
|
||||
def __init__(self, indices=None, values=None, shape=None, row_tensor=None):
|
||||
"""Init RowTensor"""
|
||||
self.init_finished = False
|
||||
# Directly init a RowTensor from another RowTensor
|
||||
if row_tensor is not None:
|
||||
if not isinstance(row_tensor, (RowTensor, RowTensor)):
|
||||
raise TypeError(f"Expect input `row_tensor` to be a RowTensor, but got {type(row_tensor)}")
|
||||
if not (indices is None and values is None and shape is None):
|
||||
raise TypeError("If input `row_tensor` is provided, `indices`, `values`, `shapes` should all be `None`")
|
||||
RowTensor_.__init__(self, row_tensor)
|
||||
# Init a RowTensor from indices, values and shape
|
||||
else:
|
||||
RowTensor_.__init__(self, indices, values, shape)
|
||||
self.init_finished = True
|
||||
|
||||
def __repr__(self):
|
||||
"""Avoid PyTest Segfault when RowTensor is not initialized."""
|
||||
if self.init_finished:
|
||||
return RowTensor_.__repr__(self)
|
||||
return ''
|
||||
|
||||
@property
|
||||
def indices(self):
|
||||
return self.__indices
|
||||
"""Return RowTensor's indices."""
|
||||
return Tensor(self._indices)
|
||||
|
||||
@property
|
||||
def values(self):
|
||||
return self.__values
|
||||
"""Return RowTensor's non-zero values."""
|
||||
return Tensor(self._values)
|
||||
|
||||
@property
|
||||
def dense_shape(self):
|
||||
return self.__dense_shape
|
||||
"""Return RowTensor's shape."""
|
||||
return self._shape
|
||||
|
||||
|
||||
class SparseTensor(COOTensor_):
|
||||
|
|
|
@ -0,0 +1,108 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""smoke tests for RowTensor operations"""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor, RowTensor, nn, context
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
|
||||
def compare_row(row1, row2):
|
||||
assert isinstance(row1, RowTensor)
|
||||
assert isinstance(row2, RowTensor)
|
||||
assert (row1.indices.asnumpy() == row1.indices.asnumpy()).all()
|
||||
assert (row2.values.asnumpy() == row2.values.asnumpy()).all()
|
||||
assert row1.dense_shape == row2.dense_shape
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_make_row():
|
||||
"""
|
||||
Feature: Test RowTensor Constructor in Graph and PyNative.
|
||||
Description: Test RowTensor(indices, values, shape) and RowTensor(COOTensor)
|
||||
Expectation: Success.
|
||||
"""
|
||||
indices = Tensor([0, 1], dtype=mstype.int32)
|
||||
values = Tensor([[1, 2], [3, 4]], dtype=mstype.float32)
|
||||
dense_shape = (3, 2)
|
||||
|
||||
def test_pynative():
|
||||
return RowTensor(indices, values, dense_shape)
|
||||
|
||||
row1 = test_pynative()
|
||||
compare_row(row1, row1)
|
||||
row2 = RowTensor(row_tensor=row1)
|
||||
compare_row(row1, row2)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_coo_tensor_with_control_if():
|
||||
"""
|
||||
Feature: Test COOTensor in if.
|
||||
Description: Test COOTensor computation in while loop.
|
||||
Expectation: Success.
|
||||
"""
|
||||
class RowTensorValuesDouble(nn.Cell):
|
||||
|
||||
def construct(self, x):
|
||||
indices = x.indices
|
||||
values = x.values * 2
|
||||
shape = x.dense_shape
|
||||
return RowTensor(indices, values, shape)
|
||||
|
||||
class RowTensorValuesAdd2(nn.Cell):
|
||||
|
||||
def construct(self, x):
|
||||
indices = x.indices
|
||||
values = x.values + 2
|
||||
shape = x.dense_shape
|
||||
return RowTensor(indices, values, shape)
|
||||
|
||||
class RowTensorWithControlIf(nn.Cell):
|
||||
def __init__(self, shape):
|
||||
super(RowTensorWithControlIf, self).__init__()
|
||||
self.op1 = RowTensorValuesDouble()
|
||||
self.op2 = RowTensorValuesAdd2()
|
||||
self.shape = shape
|
||||
|
||||
def construct(self, a, b, indices, values):
|
||||
x = RowTensor(indices, values, self.shape)
|
||||
if a > b:
|
||||
x = self.op1(x)
|
||||
else:
|
||||
x = self.op2(x)
|
||||
return x.indices, x.values, x.dense_shape
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
a = Tensor(0, mstype.int32)
|
||||
b = Tensor(2, mstype.int32)
|
||||
indices = Tensor([0, 1], dtype=mstype.int32)
|
||||
values = Tensor([[1, 2], [3, 4]], dtype=mstype.float32)
|
||||
shape = (3, 2)
|
||||
net = RowTensorWithControlIf(shape)
|
||||
out = net(a, b, indices, values)
|
||||
assert np.allclose(out[0].asnumpy(), indices.asnumpy(), .0, .0)
|
||||
assert np.allclose(out[1].asnumpy(), values.asnumpy() + 2, .0, .0)
|
||||
assert out[2] == shape
|
Loading…
Reference in New Issue