support sparse tensor frontend

yanglf1121 2021-10-13 15:15:07 +08:00
parent 6bdd38399a
commit 72db8e4d3f
38 changed files with 964 additions and 36 deletions

View File

@ -18,7 +18,7 @@
import ast
import math
from mindspore import RowTensor, SparseTensor
from mindspore import RowTensor, SparseTensor, CSRTensor
from mindspore.ops import functional as F, composite as C
from mindspore.ops.composite import multitype_ops
from mindspore._c_expression import security
@ -140,6 +140,7 @@ convert_object_map = {
# user defined
RowTensor: F.make_row_tensor,
SparseTensor: F.make_sparse_tensor,
CSRTensor: F.make_csr_tensor
}
if not security.enable_security():

View File

@ -32,6 +32,7 @@
#include "backend/optimizer/pass/reduce_sum_optimizer.h"
#include "backend/optimizer/pass/add_dynamic_shape_attr.h"
#include "backend/optimizer/pass/add_akg_kernel_attrs.h"
#include "backend/optimizer/pass/sparse_process.h"
#include "utils/ms_context.h"
#include "debug/anf_ir_dump.h"
@ -51,6 +52,7 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern
#endif
auto optimizer = std::make_shared<GraphOptimizer>();
auto common_pm = std::make_shared<PassManager>("common_pm");
common_pm->AddPass(std::make_shared<SparseProcess>());
common_pm->AddPass(std::make_shared<AddDynamicShapeAttr>());
common_pm->AddPass(std::make_shared<ReduceSumOptimizer>());
common_pm->AddPass(std::make_shared<ConvertConstInputToAttr>());

View File

@ -29,6 +29,7 @@
#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
#include "frontend/operator/ops.h"
#include "utils/ms_utils.h"
#include "utils/convert_utils.h"
#include "runtime/device/kernel_info.h"
#include "utils/ms_context.h"
#include "backend/optimizer/common/const_input_to_attr_registry.h"

View File

@ -0,0 +1,77 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <memory>
#include "backend/optimizer/pass/sparse_process.h"
#include "ir/anf.h"
#include "utils/convert_utils.h"
#include "utils/anf_utils.h"
#include "backend/optimizer/common/helper.h"
#include "backend/optimizer/common/optimizer.h"
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore {
namespace opt {
const AnfNodePtr SparseProcess::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
if (node == nullptr || !node->isa<CNode>() || !AnfUtils::IsRealKernel(node)) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto prim = AnfAlgo::GetCNodePrimitive(cnode);
MS_EXCEPTION_IF_NULL(prim);
std::string prim_name = prim->name();
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
// cnode is a MakeSparse node
if (make_sparse_set.find(prim_name) != make_sparse_set.end()) {
std::vector<AnfNodePtr> inputs;
inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
// Inputs of node should be [make_sparse, indices, values, dense_shape], so offset by 1 to get items;
(void)inputs.insert(inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
auto new_node = cnode->func_graph()->NewCNode(inputs);
auto abs_sparse = dyn_cast<abstract::AbstractCSRTensor>(node->abstract());
std::vector<AbstractBasePtr> abstract_list{abs_sparse->indptr(), abs_sparse->indices(), abs_sparse->values(),
abs_sparse->dense_shape()};
auto abs_res = std::make_shared<abstract::AbstractTuple>(abstract_list);
new_node->set_abstract(abs_res);
new_node->set_scope(cnode->scope());
if (kernel_graph != nullptr) {
kernel_graph->FrontBackendlMapUpdate(cnode, new_node);
}
return new_node;
// cnode is a SparseGetAttr node
} else if (sparse_attr_map.find(prim_name) != sparse_attr_map.end()) {
const auto &inputs = cnode->inputs();
// Inputs should be [sparse_getattr, sparse]
constexpr size_t sparse_index = 1;
AnfNodePtr sparse = inputs[sparse_index];
MS_EXCEPTION_IF_NULL(sparse);
int64_t index = sparse_attr_map.at(prim_name);
auto cons_node = NewValueNode(index);
AbstractBasePtr aptr = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(index));
cons_node->set_abstract(aptr);
auto new_node = NewCNode({NewValueNode(prim::kPrimTupleGetItem), sparse, cons_node}, func_graph);
new_node->set_abstract(node->abstract());
return new_node;
}
return nullptr;
}
} // namespace opt
} // namespace mindspore
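
For context, a minimal sketch of the Python front end this pass targets, assuming a standard MindSpore build (the graph compiled from it contains MakeCSRTensor and CSRTensorGet* nodes, which the pass above lowers to MakeTuple and TupleGetItem using the index from sparse_attr_map):

import mindspore as ms
from mindspore import Tensor, CSRTensor, ms_function

indptr = Tensor([0, 1, 2])
indices = Tensor([0, 1])
values = Tensor([1, 2], dtype=ms.float32)

@ms_function
def build_and_read():
    # MakeCSRTensor is rewritten to MakeTuple(indptr, indices, values, shape)
    csr = CSRTensor(indptr, indices, values, (2, 4))
    # csr.indptr / csr.values are rewritten to TupleGetItem(csr, 0) / TupleGetItem(csr, 2)
    return csr.indptr, csr.values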

View File

@ -0,0 +1,38 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_SPARSE_PROCESS_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_SPARSE_PROCESS_H_
#include <string>
#include "ir/anf.h"
#include "utils/convert_utils.h"
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class SparseProcess : public PatternProcessPass {
public:
explicit SparseProcess(bool multigraph = true) : PatternProcessPass("sparse_process", multigraph) {}
~SparseProcess() override = default;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
const size_t kAnfPrimitiveIndex = 0;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_SPARSE_PROCESS_H_

View File

@ -334,9 +334,14 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputByCallNode(const K
std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const AnfNodePtr &node) {
std::vector<KernelWithIndex> ret;
std::vector<KernelWithIndex> ret_empty;
// The makeTuple node need expand and recurse.
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
const PrimitiveSet expand_prims{
prim::kPrimMakeTuple,
prim::kPrimMakeCSRTensor,
prim::kPrimMakeSparseTensor,
prim::kPrimMakeRowTensor,
};
// The MakeTuple/MakeSparse node needs to be expanded and recursed.
if (IsOneOfPrimitiveCNode(node, expand_prims)) {
auto make_tuple = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(make_tuple);
for (size_t i = 1; i < make_tuple->inputs().size(); i++) {
@ -385,8 +390,8 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const An
AnfAlgo::VisitKernelWithReturnType(node, i, false, {prim::kPrimMakeTuple, prim::kPrimUpdateState});
MS_EXCEPTION_IF_NULL(output_with_index.first);
// The makeTuple node need recurse.
if (AnfAlgo::CheckPrimitiveType(output_with_index.first, prim::kPrimMakeTuple)) {
// The MakeTuple/MakeSparse node needs to be recursed into.
if (IsOneOfPrimitiveCNode(output_with_index.first, expand_prims)) {
auto output_vector = GetAllOutputWithIndex(output_with_index.first);
(void)std::copy(output_vector.begin(), output_vector.end(), std::back_inserter(ret));
continue;

View File

@ -45,6 +45,7 @@ class ProtoExporter {
std::string GetOpNodeInputId(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const std::map<AnfNodePtr, size_t> &apply_map,
std::map<AnfNodePtr, size_t> *const_map_ptr);
void SetValueToProtoBasicTypes(const ValuePtr &attr_value, irpb::ValueProto *value_proto);
void SetValueToProto(const ValuePtr &attr_value, irpb::ValueProto *value_proto);
void SetScalarToProto(const ScalarPtr &val, irpb::ValueProto *value_proto);
void SetSequenceToProto(const ValueSequeuePtr &val, irpb::ValueProto *value_proto);
@ -112,11 +113,12 @@ void CheckIfValidType(const TypePtr &type) {
MS_EXCEPTION_IF_NULL(type);
if (type->isa<Problem>()) {
MS_LOG(WARNING) << "The type: " << type->type_name();
return;
}
if (!(type->isa<Number>() || type->isa<TensorType>() || type->isa<Tuple>() || type->isa<TypeType>() ||
type->isa<List>() || type->isa<TypeAnything>() || type->isa<RefKeyType>() || type->isa<RefType>() ||
type->isa<Function>() || type->isa<TypeNone>() || type->isa<Problem>() || type->isa<String>() ||
type->isa<RowTensorType>() || type->isa<UndeterminedType>() || type->isa<SparseTensorType>() ||
type->isa<Function>() || type->isa<TypeNone>() || type->isa<String>() || type->isa<RowTensorType>() ||
type->isa<CSRTensorType>() || type->isa<UndeterminedType>() || type->isa<SparseTensorType>() ||
type->isa<SymbolicKeyType>() || type->isa<MonadType>())) {
MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name();
}
@ -126,12 +128,12 @@ void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &s
if (type_proto == nullptr) {
return;
}
if (type != nullptr) {
CheckIfValidType(type);
}
if (type == nullptr) {
type_proto->set_data_type(irpb::DT_UNDEFINED);
} else if (type->isa<Number>()) {
return;
}
CheckIfValidType(type);
if (type->isa<Number>()) {
type_proto->set_data_type(GetNumberDataType(type));
} else if (type->isa<TensorType>()) {
TypePtr elem_type = dyn_cast<TensorType>(type)->element();
@ -179,11 +181,7 @@ void ProtoExporter::SetNodeOutputType(const AnfNodePtr &node, irpb::TypeProto *t
SetNodeOutputType(node->Type(), node->Shape(), type_proto);
}
void ProtoExporter::SetValueToProto(const ValuePtr &val, irpb::ValueProto *value_proto) {
if (val == nullptr || value_proto == nullptr) {
return;
}
void ProtoExporter::SetValueToProtoBasicTypes(const ValuePtr &val, irpb::ValueProto *value_proto) {
if (val->isa<StringImm>()) {
const StringImmPtr &value = dyn_cast<StringImm>(val);
value_proto->set_dtype(irpb::DT_STRING);
@ -202,7 +200,17 @@ void ProtoExporter::SetValueToProto(const ValuePtr &val, irpb::ValueProto *value
} else if (val->isa<Float>()) {
value_proto->set_dtype(irpb::DT_TYPE);
value_proto->mutable_type_val()->set_data_type(irpb::DT_BASE_FLOAT);
} else if (val->isa<ValueSequeue>()) {
}
}
void ProtoExporter::SetValueToProto(const ValuePtr &val, irpb::ValueProto *value_proto) {
if (val == nullptr || value_proto == nullptr) {
return;
}
SetValueToProtoBasicTypes(val, value_proto);
if (val->isa<ValueSequeue>()) {
SetSequenceToProto(dyn_cast<ValueSequeue>(val), value_proto);
} else if (val->isa<None>()) {
value_proto->set_dtype(irpb::DT_NONE);

View File

@ -1138,9 +1138,11 @@ py::object GraphExecutorPy::Run(const py::tuple &args, const py::object &phase_o
MS_LOG(INFO) << "VM loop size " << vm_loop << ", loopsink size " << vm_loop;
py::object ret;
MS_LOG(DEBUG) << "Eval run" << backend;
auto output = execute_info->func_graph->output()->abstract();
MS_EXCEPTION_IF_NULL(output);
for (int64_t i = 0; i < vm_loop; i++) {
BaseRef value = (*run)(execute_info->arg_list);
ret = BaseRefToPyData(value);
ret = BaseRefToPyData(value, output);
}
MS_LOG(DEBUG) << "Run end";
return ret;

View File

@ -241,6 +241,13 @@ BuiltInTypeMap &GetAttrMap() {
{"indices", prim::kPrimSparseTensorGetIndices}, // F.sparse_tensor_get_indices
{"dense_shape", prim::kPrimSparseTensorGetDenseShape}, // F.sparse_tensor_get_dense_shape
}},
{kObjectTypeCSRTensorType,
{
{"indptr", prim::kPrimCSRTensorGetIndptr}, // F.csr_tensor_get_indptr
{"values", prim::kPrimCSRTensorGetValues}, // F.csr_tensor_get_values
{"indices", prim::kPrimCSRTensorGetIndices}, // F.csr_tensor_get_indices
{"shape", prim::kPrimCSRTensorGetDenseShape}, // F.csr_tensor_get_shape
}},
};
return attr_map;
}
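
A hedged sketch of how these attribute entries are exercised from Python in graph mode (mirroring test_csr_attr at the end of this commit); each access resolves to the primitive listed above:

from mindspore import Tensor, CSRTensor, ms_function
from mindspore.common import dtype as mstype

@ms_function
def read_attrs(indptr, indices, values):
    csr = CSRTensor(indptr, indices, values, (2, 6))
    # indptr -> CSRTensorGetIndptr, values -> CSRTensorGetValues,
    # indices -> CSRTensorGetIndices, shape -> CSRTensorGetDenseShape
    return csr.indptr, csr.indices, csr.values, csr.shape

out = read_attrs(Tensor([0, 1, 2]), Tensor([0, 1]), Tensor([1, 2], dtype=mstype.float32))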

View File

@ -462,6 +462,11 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
dic[ATTR_SHAPE] = arg->shape()->shape();
dic[ATTR_DTYPE] = arg->BuildType();
dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
} else if (abs_base->isa<AbstractCSRTensor>()) {
auto arg = dyn_cast<AbstractCSRTensor>(abs_base);
dic[ATTR_SHAPE] = arg->shape()->shape();
dic[ATTR_DTYPE] = arg->BuildType();
dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
} else if (abs_base->isa<AbstractScalar>() || abs_base->isa<AbstractType>() || abs_base->isa<AbstractRefKey>()) {
ShapeVector shape;
dic[ATTR_SHAPE] = shape;

View File

@ -30,6 +30,7 @@ namespace mindspore {
namespace validator {
using mindspore::abstract::AbstractBase;
using mindspore::abstract::AbstractClass;
using mindspore::abstract::AbstractCSRTensor;
using mindspore::abstract::AbstractError;
using mindspore::abstract::AbstractFunction;
using mindspore::abstract::AbstractJTagged;
@ -114,11 +115,12 @@ void ValidateAbstract(const AnfNodePtr &node) {
MS_LOG(DEBUG) << "AbstractError in the graph: " << abstract->ToString();
return;
}
bool is_legal_abstract =
abstract->isa<AbstractType>() || abstract->isa<AbstractFunction>() || abstract->isa<AbstractTuple>() ||
abstract->isa<AbstractList>() || abstract->isa<AbstractTensor>() || abstract->isa<AbstractRowTensor>() ||
abstract->isa<AbstractSparseTensor>() || abstract->isa<abstract::AbstractRefKey>() ||
abstract->isa<AbstractRef>() || abstract->isa<abstract::AbstractNone>() || abstract->isa<abstract::AbstractMonad>();
bool is_legal_abstract = abstract->isa<AbstractType>() || abstract->isa<AbstractFunction>() ||
abstract->isa<AbstractTuple>() || abstract->isa<AbstractList>() ||
abstract->isa<AbstractTensor>() || abstract->isa<AbstractRowTensor>() ||
abstract->isa<AbstractSparseTensor>() || abstract->isa<AbstractCSRTensor>() ||
abstract->isa<abstract::AbstractRefKey>() || abstract->isa<AbstractRef>() ||
abstract->isa<abstract::AbstractNone>() || abstract->isa<abstract::AbstractMonad>();
if (is_legal_abstract) {
return;
}

View File

@ -154,6 +154,7 @@ REGISTER_PYBIND_DEFINE(
(void)py::class_<RowTensorType, Type, std::shared_ptr<RowTensorType>>(m_sub, "RowTensorType").def(py::init());
(void)py::class_<SparseTensorType, Type, std::shared_ptr<SparseTensorType>>(m_sub, "SparseTensorType")
.def(py::init());
(void)py::class_<CSRTensorType, Type, std::shared_ptr<CSRTensorType>>(m_sub, "CSRTensorType").def(py::init());
(void)py::class_<UndeterminedType, Type, std::shared_ptr<UndeterminedType>>(m_sub, "UndeterminedType")
.def(py::init());
(void)py::class_<Function, Type, std::shared_ptr<Function>>(m_sub, "Function")

View File

@ -640,5 +640,34 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
return TensorPy::MakeTensor(t[0].cast<py::array>());
}));
}));
py::tuple CSRTensorPy::GetPyTupleShape(const CSRTensor &csr_tensor) {
auto &shape = csr_tensor.shape();
py::tuple dims(shape.size());
for (size_t i = 0; i < dims.size(); ++i) {
dims[i] = py::int_(shape[i]);
}
return dims;
}
REGISTER_PYBIND_DEFINE(
CSRTensor, ([](const py::module *m) {
// Define python CSRTensor class.
(void)py::class_<CSRTensor, std::shared_ptr<CSRTensor>>(*m, "CSRTensor")
.def(py::init([](const Tensor &indptr, const Tensor &indices, const Tensor &values, const py::tuple &shape) {
return std::make_shared<CSRTensor>(std::make_shared<Tensor>(indptr), std::make_shared<Tensor>(indices),
std::make_shared<Tensor>(values), GetShapeFromTuple(shape));
}),
py::arg("indptr"), py::arg("indices"), py::arg("values"), py::arg("shape"))
.def(py::init([](const CSRTensor &csr_tensor) { return std::make_shared<CSRTensor>(csr_tensor); }),
py::arg("input"))
.def_property_readonly("_shape", CSRTensorPy::GetPyTupleShape)
.def_property_readonly("_dtype", &CSRTensor::Dtype)
.def_property_readonly("_indptr", &CSRTensor::GetIndptr)
.def_property_readonly("_indices", &CSRTensor::GetIndices)
.def_property_readonly("_values", &CSRTensor::GetValues)
.def("__str__", &CSRTensor::ToString)
.def("__repr__", &CSRTensor::ToString);
}));
} // namespace tensor
} // namespace mindspore

View File

@ -118,6 +118,12 @@ class TensorPy {
static void FlushFromCache(const Tensor &tensor);
};
// CSRTensor python wrapper and adapter class.
class CSRTensorPy {
public:
static py::tuple GetPyTupleShape(const CSRTensor &csr_tensor);
};
} // namespace tensor
} // namespace mindspore

View File

@ -21,6 +21,7 @@
#include <memory>
#include <utility>
#include <stack>
#include <string>
#include <vector>
#include <unordered_map>
#include <unordered_set>
@ -28,6 +29,7 @@
#include "utils/convert_utils_base.h"
#include "utils/any.h"
#include "base/base_ref.h"
#include "base/core_ops.h"
#include "base/base.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
@ -77,6 +79,19 @@ std::vector<T> TensorValueToVector(const tensor::TensorPtr &tensor) {
void TensorValueToTensor(const ValuePtr &value, std::vector<tensor::TensorPtr> *tensors);
size_t CountValueNum(const ValueTuplePtr &value_tuple);
// sparse_attr_map converts CNode{kPrimSparseGetAttr, SparseTensor}
// to CNode{kPrimTupleGetItem, SparseTensor, int64_t(index)}, used
// in backend common optimization pass: sparse_process.cc
const std::unordered_map<std::string, int64_t> sparse_attr_map = {{prim::kPrimCSRTensorGetIndptr->name(), 0},
{prim::kPrimCSRTensorGetIndices->name(), 1},
{prim::kPrimCSRTensorGetValues->name(), 2},
{prim::kPrimCSRTensorGetDenseShape->name(), 3}};
// make_sparse_set records all make_sparse primitives, which are replaced
// with make_tuple in the backend common optimization pass:
// sparse_process.cc
const std::unordered_set<std::string> make_sparse_set = {
{prim::kPrimMakeCSRTensor->name()}, {prim::kPrimMakeSparseTensor->name()}, {prim::kPrimMakeRowTensor->name()}};
} // namespace mindspore
#endif // MINDSPORE_CCSRC_UTILS_CONVERT_UTILS_H_

View File

@ -30,17 +30,23 @@
#include "pipeline/jit/parse/parse_base.h"
#include "pipeline/jit/parse/resolve.h"
#include "ir/value.h"
#include "ir/anf.h"
#include "ir/tensor.h"
#include "ir/param_info.h"
#include "pybind_api/ir/base_ref_py.h"
#include "ir/dtype/tensor_type.h"
#include "utils/ms_context.h"
#include "utils/convert_utils.h"
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore {
py::object BuiltinsToPyData(const Any &value);
py::object BuiltinsToPyData(const BaseRef &value);
py::object VectorToPyData(const Any &value);
py::object VectorRefToPyData(const VectorRef &value);
py::object VectorRefToPyData(const VectorRef &value_list);
py::object VectorRefToPyData(const VectorRef &value_list, const AbstractBasePtr &output);
// Wrap VectorRef to CSRTensor
py::object MakeCSRTensor(const VectorRef &value_list);
py::object TensorToPyData(const tensor::TensorPtr &tensor) {
MS_EXCEPTION_IF_NULL(tensor);
if (tensor->NeedWait()) {
@ -276,6 +282,19 @@ py::object AnyToPyData(const Any &value) {
return ret;
}
py::object BaseRefToPyData(const BaseRef &value, const AbstractBasePtr &output) {
py::object ret;
// If output value is a tuple, check if abstract is a SparseTensor in funcgraph output
if (utils::isa<VectorRef>(value)) {
MS_LOG(DEBUG) << "BaseRefToPyData, value is tuple: " << value.ToString();
auto vec_ref = utils::cast<VectorRef>(value);
ret = VectorRefToPyData(vec_ref, output);
} else {
ret = BaseRefToPyData(value);
}
return ret;
}
py::object BaseRefToPyData(const BaseRef &value) {
py::object ret;
MS_LOG(DEBUG) << "BaseRefToPyData " << value.ToString();
@ -383,6 +402,29 @@ py::object VectorRefToPyData(const VectorRef &value_list) {
return ret;
}
py::object VectorRefToPyData(const VectorRef &value_list, const AbstractBasePtr &output) {
MS_LOG(DEBUG) << "vector_ref";
// Current VectorRef reflects a CSRTensor type
if (output->isa<abstract::AbstractCSRTensor>()) {
return MakeCSRTensor(value_list);
}
py::object ret;
size_t value_size = value_list.size();
auto ref_tuple = py::tuple(value_size);
abstract::AbstractTuplePtr tuple_output = output->cast<abstract::AbstractTuplePtr>();
bool is_abstract_tuple = tuple_output != nullptr;
for (size_t i = 0; i < value_size; i++) {
if (!is_abstract_tuple || i >= tuple_output->size()) {
// Fall back to original process
ref_tuple[i] = BaseRefToPyData(value_list[i]);
} else {
ref_tuple[i] = BaseRefToPyData(value_list[i], (*tuple_output)[i]);
}
}
ret = ref_tuple;
return ret;
}
void SetValueRange(const AbstractBasePtr &tensor, const py::object &output) {
if (output.is_none()) {
return;
@ -551,4 +593,45 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple
}
return false;
}
py::object MakeCSRTensor(const VectorRef &value_list) {
constexpr size_t kCSRTensorInputSize{4};
if (value_list.size() != kCSRTensorInputSize) {
MS_LOG(EXCEPTION) << "CSRTensor must have 4 inputs.";
}
using TensorPtr = tensor::TensorPtr;
using CSRTensor = tensor::CSRTensor;
TensorPtr indptr = utils::cast<TensorPtr>(value_list[0]);
TensorPtr indices = utils::cast<TensorPtr>(value_list[1]);
TensorPtr values = utils::cast<TensorPtr>(value_list[2]);
ValuePtr shape_ptr = utils::cast<ValuePtr>(value_list[3]);
ValueTuplePtr shape_tuple = shape_ptr->cast<ValueTuplePtr>();
ShapeVector shape{};
// CSRTensor shape is a tuple on GPU and CPU
if (shape_tuple) {
for (const auto &v : shape_tuple->value()) {
MS_EXCEPTION_IF_NULL(v);
ScalarPtr scalar = v->cast<ScalarPtr>();
MS_EXCEPTION_IF_NULL(scalar);
shape.push_back(GetValue<int64_t>(scalar));
}
// CSRTensor shape is a VectorRef(TensorPtr, TensorPtr) on Ascend
} else {
auto shape_ref = utils::cast<VectorRef>(value_list[3]);
MS_EXCEPTION_IF_NULL(shape_ref);
for (const auto &v : shape_ref) {
MS_EXCEPTION_IF_NULL(v);
auto tensorptr = utils::cast<TensorPtr>(v);
MS_EXCEPTION_IF_NULL(tensorptr);
if (tensorptr->DataDim() != 0) {
MS_LOG(EXCEPTION) << "Element in CSRTensor's shape must be scalar!";
}
shape.push_back(*(static_cast<int64_t *>(tensorptr->data_c())));
}
}
auto ref = py::tuple(1);
auto csr_tensor_ptr = std::make_shared<CSRTensor>(indptr, indices, values, shape);
ref[0] = csr_tensor_ptr;
return ref[0];
}
} // namespace mindspore
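
An end-to-end sketch of what MakeCSRTensor enables, assuming a GPU/CPU backend where the shape arrives as a ValueTuple: when the graph output abstract is AbstractCSRTensor, the backend's VectorRef of (indptr, indices, values, shape) is wrapped back into a Python CSRTensor instead of a plain tuple:

import mindspore as ms
from mindspore import Tensor, CSRTensor, ms_function

@ms_function
def make_csr(indptr, indices, values):
    return CSRTensor(indptr, indices, values, (2, 4))

out = make_csr(Tensor([0, 1, 2]), Tensor([0, 1]), Tensor([1, 2], dtype=ms.float32))
print(isinstance(out, CSRTensor))   # expected: True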

View File

@ -31,6 +31,7 @@ namespace py = pybind11;
namespace mindspore {
py::object AnyToPyData(const Any &value);
py::object BaseRefToPyData(const BaseRef &value);
py::object BaseRefToPyData(const BaseRef &value, const AbstractBasePtr &output);
py::object ValueToPyData(const ValuePtr &value);
bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,

View File

@ -922,8 +922,14 @@ void MindRTBackend::ConstructOutputs(const AnfNodePtr &output_node,
MS_EXCEPTION_IF_NULL(output_node);
MS_EXCEPTION_IF_NULL(outputs);
MS_EXCEPTION_IF_NULL(output_position);
// The makeTuple node need expand and recurse.
if (AnfAlgo::CheckPrimitiveType(output_node, prim::kPrimMakeTuple)) {
const PrimitiveSet expand_prims{
prim::kPrimMakeTuple,
prim::kPrimMakeCSRTensor,
prim::kPrimMakeSparseTensor,
prim::kPrimMakeRowTensor,
};
// The MakeTuple/MakeSparse node needs to be expanded and recursed.
if (IsOneOfPrimitiveCNode(output_node, expand_prims)) {
auto make_tuple = output_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(make_tuple);
VectorRef make_tuple_output;

View File

@ -24,7 +24,7 @@ from .dtype import Type, int8, byte, int16, short, int32, intc, int64, intp, \
from .dump import set_dump
from .parameter import Parameter, ParameterTuple
from .seed import set_seed, get_seed
from .tensor import Tensor, RowTensor, SparseTensor
from .tensor import Tensor, RowTensor, SparseTensor, CSRTensor
# symbols from dtype
__all__ = [
@ -53,7 +53,7 @@ __all__ = [
]
__all__.extend([
"Tensor", "RowTensor", "SparseTensor", # tensor
"Tensor", "RowTensor", "SparseTensor", "CSRTensor", # tensor
'ms_function', # api
'Parameter', 'ParameterTuple', # parameter
"dtype",

View File

@ -26,7 +26,8 @@ from mindspore import context
from mindspore import log as logger
from mindspore._extends.remote import kernel_build_server
from .tensor import Tensor as MsTensor
from .._c_expression import generate_arguments_key, GraphExecutor_, Tensor, MetaTensor, PynativeExecutor_
from .tensor import CSRTensor as MsCSRTensor
from .._c_expression import generate_arguments_key, GraphExecutor_, Tensor, MetaTensor, CSRTensor, PynativeExecutor_
from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline
from ..parallel._ps_context import _is_role_pserver
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, _to_full_tensor, \
@ -80,6 +81,8 @@ def _wrap_func(fn):
def _convert_data(data):
if isinstance(data, Tensor) and not isinstance(data, MsTensor):
return MsTensor(data)
if isinstance(data, CSRTensor) and not isinstance(data, MsCSRTensor):
return MsCSRTensor(csr_tensor=data)
if isinstance(data, tuple):
return tuple(_convert_data(x) for x in data)
if isinstance(data, list):

View File

@ -93,6 +93,7 @@ type_none = typing.TypeNone()
tensor = typing.TensorType()
index_slices = typing.RowTensorType()
sparse_tensor = typing.SparseTensorType()
csr_tensor = typing.CSRTensorType()
undetermined = typing.UndeterminedType()
function = typing.Function()
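A tiny sketch of the new dtype handle, assuming this file is mindspore/common/dtype.py and that Type objects print via their ToString (see CSRTensorType::ToString added later in this commit):

from mindspore.common import dtype as mstype
print(mstype.csr_tensor)   # expected: CSRTensor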

View File

@ -21,10 +21,11 @@ from mindspore.communication.management import get_rank, get_group_size
from . import dtype as mstype
from ._register_for_tensor import tensor_operator_registry
from .._c_expression import Tensor as Tensor_
from .._c_expression import CSRTensor as CSRTensor_
from .._c_expression import PynativeExecutor_
from .._checkparam import Validator as validator
__all__ = ['Tensor', 'RowTensor', 'SparseTensor']
__all__ = ['Tensor', 'RowTensor', 'SparseTensor', 'CSRTensor']
np_types = (np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
np.float32, np.float64, np.bool_, np.complex64, np.complex128)
@ -2240,6 +2241,97 @@ class SparseTensor:
def dense_shape(self):
return self.__dense_shape
class CSRTensor(CSRTensor_):
"""
Constructs a sparse tensor in CSR (Compressed Sparse Row) format, with specified
values indicated by `values` and row and column positions indicated by `indptr`
and `indices`.
Alternatively, CSRTensor can be initialized by passing another CSRTensor as input.
Currently this constructor is only supported in PyNative mode.
Note:
This is an experimental feature and is subject to change.
Args:
indptr (Tensor): 1-D Tensor of size `shape[0] + 1`, which indicates the
start and end point for `values` in each row. Default: None. If provided,
must be :class:`mindspore.int16`, :class:`mindspore.int32` or :class:`mindspore.int64`.
indices (Tensor): 1-D Tensor, which has the same length as `values`. `indices`
indicates which column each element of `values` is placed in. Default: None. If provided,
must be :class:`mindspore.int16`, :class:`mindspore.int32` or :class:`mindspore.int64`.
values (Tensor): 1-D Tensor, which has the same length as `indices`. `values`
stores the data for CSRTensor. Default: None.
shape (Tuple): A tuple indicating the shape of the CSRTensor. Its length must
be `2`, as only 2-D CSRTensor is currently supported, and `shape[0]` must
equal `indptr.shape[0] - 1`, which is the number of rows of the CSRTensor.
csr_tensor (CSRTensor): A CSRTensor object.
Outputs:
CSRTensor, with shape defined by `shape`, and dtype inferred from `values`.
Examples:
>>> import mindspore as ms
>>> from mindspore import Tensor, CSRTensor
>>> # initialize a csr_tensor with indptr, indices, values and shape
>>> indptr = Tensor([0, 1, 2])
>>> indices = Tensor([0, 1])
>>> values = Tensor([1, 2], dtype=ms.float32)
>>> shape = (3, 4)
>>> csr_tensor = CSRTensor(indptr, indices, values, shape)
>>> # initialize a csr_tensor from another csr_tensor
>>> csr_tensor_2 = CSRTensor(csr_tensor=csr_tensor)
>>> # access a data member of CSRTensor
>>> print(indptr == csr_tensor.indptr)
[ True  True  True]
"""
def __init__(self, indptr=None, indices=None, values=None, shape=None, csr_tensor=None):
self.init_finished = False
# Case 1: directly init a CSRTensor from another CSRTensor
if indptr is None and indices is None and values is None and shape is None:
if not isinstance(csr_tensor, (CSRTensor, CSRTensor_)):
raise TypeError("If only one input provided, it must be a CSRTensor.")
CSRTensor_.__init__(self, csr_tensor)
# Case 2: init a CSRTensor from indptr, indices, values and shape
else:
if (indptr is None or indices is None or values is None or shape is None):
raise TypeError("Inputs must follow: CSRTensor(indptr, indices, values, shape).")
if not (isinstance(indptr, Tensor) and isinstance(indices, Tensor) \
and isinstance(values, Tensor) and isinstance(shape, tuple)):
raise TypeError("Inputs must follow: CSRTensor(tensor, tensor, tensor, tuple).")
if len(shape) != 2 or shape[0] + 1 != indptr.shape[0] or shape[1] <= 0:
raise ValueError("Shape length should be 2, shape[0] should equal to indptr.shape[0] - 1")
if indptr.dtype not in (mstype.int16, mstype.int32, mstype.int64):
raise TypeError("indptr must have integer data type.")
if indices.dtype not in (mstype.int16, mstype.int32, mstype.int64):
raise TypeError("indices must have integer data type.")
CSRTensor_.__init__(self, indptr, indices, values, shape)
self.init_finished = True
def __repr__(self):
"""Avoid PyTest Segfault when CSRTensor is not initialized."""
if self.init_finished:
return CSRTensor_.__repr__(self)
return ''
@property
def indptr(self):
return Tensor(self._indptr)
@property
def indices(self):
return Tensor(self._indices)
@property
def values(self):
return Tensor(self._values)
@property
def shape(self):
return self._shape
def to_tuple(self):
return self.indptr, self.indices, self.values, self.shape
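A short usage sketch for the accessors and to_tuple above (an illustration only; to_tuple returns the four components in constructor order, so a CSRTensor can be rebuilt with argument unpacking, as the smoke test at the end of this commit does):

from mindspore import Tensor, CSRTensor
from mindspore.common import dtype as mstype

csr = CSRTensor(Tensor([0, 1, 2]), Tensor([0, 1]), Tensor([1, 2], dtype=mstype.float32), (2, 4))
indptr, indices, values, shape = csr.to_tuple()
rebuilt = CSRTensor(*csr.to_tuple())   # same components, new CSRTensor object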
def _vm_compare(*args):
"""Implement `vm_compare` for tensor."""

View File

@ -1443,6 +1443,113 @@ std::string AbstractSparseTensor::ToString() const {
return buffer.str();
}
// CSRTensor
TypePtr AbstractCSRTensor::BuildType() const {
MS_EXCEPTION_IF_NULL(element());
TypePtr element_type = element()->BuildType();
return std::make_shared<CSRTensorType>(element_type);
}
AbstractBasePtr AbstractCSRTensor::Clone() const {
MS_EXCEPTION_IF_NULL(element());
auto clone = std::make_shared<AbstractCSRTensor>(element()->Clone());
ShapePtr shp = shape();
MS_EXCEPTION_IF_NULL(shp);
MS_EXCEPTION_IF_NULL(indptr_);
MS_EXCEPTION_IF_NULL(indices_);
MS_EXCEPTION_IF_NULL(values_);
MS_EXCEPTION_IF_NULL(dense_shape_);
auto indptr_clone = indptr_->Clone();
auto indices_clone = indices_->Clone();
auto value_clone = values_->Clone();
auto dense_clone = dense_shape_->Clone();
MS_EXCEPTION_IF_NULL(indptr_clone);
MS_EXCEPTION_IF_NULL(indices_clone);
MS_EXCEPTION_IF_NULL(value_clone);
MS_EXCEPTION_IF_NULL(dense_clone);
clone->set_shape(shp->Clone());
clone->set_value(GetValueTrack());
clone->set_indptr(indptr_clone->cast<AbstractTensorPtr>());
clone->set_indices(indices_clone->cast<AbstractTensorPtr>());
clone->set_values(value_clone->cast<AbstractTensorPtr>());
clone->set_dense_shape(dense_clone->cast<AbstractTuplePtr>());
return clone;
}
AbstractBasePtr AbstractCSRTensor::Broaden() const {
MS_EXCEPTION_IF_NULL(element());
auto broaden = std::make_shared<AbstractCSRTensor>(element()->Broaden());
auto shp = shape();
MS_EXCEPTION_IF_NULL(shp);
MS_EXCEPTION_IF_NULL(indptr_);
MS_EXCEPTION_IF_NULL(indices_);
MS_EXCEPTION_IF_NULL(values_);
MS_EXCEPTION_IF_NULL(dense_shape_);
auto indptr_clone = indptr_->Clone();
auto indices_clone = indices_->Clone();
auto value_clone = values_->Clone();
auto dense_clone = dense_shape_->Clone();
MS_EXCEPTION_IF_NULL(indptr_clone);
MS_EXCEPTION_IF_NULL(indices_clone);
MS_EXCEPTION_IF_NULL(value_clone);
MS_EXCEPTION_IF_NULL(dense_clone);
broaden->set_shape(shp->Clone());
broaden->set_value(kAnyValue);
broaden->set_indptr(indptr_clone->cast<AbstractTensorPtr>());
broaden->set_indices(indices_clone->cast<AbstractTensorPtr>());
broaden->set_values(value_clone->cast<AbstractTensorPtr>());
broaden->set_dense_shape(dense_clone->cast<AbstractTuplePtr>());
return broaden;
}
AbstractBasePtr AbstractCSRTensor::BroadenWithShape() const {
MS_EXCEPTION_IF_NULL(element());
auto broaden = std::make_shared<AbstractCSRTensor>(element()->Broaden());
auto this_shape = shape();
MS_EXCEPTION_IF_NULL(this_shape);
auto shp = this_shape->Clone();
MS_EXCEPTION_IF_NULL(shp);
shp->Broaden();
broaden->set_shape(shp);
broaden->set_value(kAnyValue);
MS_EXCEPTION_IF_NULL(indptr_);
MS_EXCEPTION_IF_NULL(indices_);
MS_EXCEPTION_IF_NULL(values_);
MS_EXCEPTION_IF_NULL(dense_shape_);
auto indptr_clone = indptr_->Clone();
auto indices_clone = indices_->Clone();
auto value_clone = values_->Clone();
auto dense_clone = dense_shape_->Clone();
MS_EXCEPTION_IF_NULL(indptr_clone);
MS_EXCEPTION_IF_NULL(indices_clone);
MS_EXCEPTION_IF_NULL(value_clone);
MS_EXCEPTION_IF_NULL(dense_clone);
broaden->set_indptr(indptr_clone->cast<AbstractTensorPtr>());
broaden->set_indices(indices_clone->cast<AbstractTensorPtr>());
broaden->set_values(value_clone->cast<AbstractTensorPtr>());
broaden->set_dense_shape(dense_clone->cast<AbstractTuplePtr>());
return broaden;
}
std::string AbstractCSRTensor::ToString() const {
std::ostringstream buffer;
BaseShapePtr shape_track = GetShapeTrack();
MS_EXCEPTION_IF_NULL(shape_track);
MS_EXCEPTION_IF_NULL(element());
auto value_track = GetValueTrack();
MS_EXCEPTION_IF_NULL(value_track);
MS_EXCEPTION_IF_NULL(indptr_);
MS_EXCEPTION_IF_NULL(indices_);
MS_EXCEPTION_IF_NULL(values_);
MS_EXCEPTION_IF_NULL(dense_shape_);
buffer << type_name() << "("
<< "shape: " << shape_track->ToString() << ", element: " << element()->ToString()
<< ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")"
<< ", indptr: " << indptr_->ToString() << ", indices: " << indices_->ToString() << ", values"
<< values_->ToString() << ", dense_shape: " << dense_shape_->ToString();
return buffer.str();
}
AbstractBasePtr AbstractUMonad::Join(const AbstractBasePtr &other) {
MS_EXCEPTION_IF_NULL(other);
if (!other->isa<AbstractUMonad>()) {

View File

@ -1485,6 +1485,38 @@ class MS_CORE_API AbstractSparseTensor final : public AbstractUndetermined {
AbstractTuplePtr dense_shape_;
};
// CSRTensor
class MS_CORE_API AbstractCSRTensor : public AbstractUndetermined {
public:
explicit AbstractCSRTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
: AbstractUndetermined(element, shape) {}
AbstractCSRTensor(const TypePtr &element_type, const ShapeVector &shape)
: AbstractUndetermined(element_type, shape) {}
~AbstractCSRTensor() override = default;
MS_DECLARE_PARENT(AbstractCSRTensor, AbstractUndetermined)
const AbstractTensorPtr indptr() const { return indptr_; }
void set_indptr(const AbstractTensorPtr &indptr) { indptr_ = indptr; }
const AbstractTensorPtr indices() const { return indices_; }
void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; }
const AbstractTensorPtr values() const { return values_; }
void set_values(const AbstractTensorPtr &values) { values_ = values; }
const AbstractTuplePtr dense_shape() const { return dense_shape_; }
void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
TypePtr BuildType() const override;
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden() const override;
AbstractBasePtr BroadenWithShape() const;
std::string ToString() const override;
private:
AbstractTensorPtr indptr_;
AbstractTensorPtr indices_;
AbstractTensorPtr values_;
AbstractTuplePtr dense_shape_;
};
class AbstractMonad : public AbstractBase {
public:
~AbstractMonad() override = default;

View File

@ -149,6 +149,8 @@ AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitiveP
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
template <typename T>
std::shared_ptr<T> InferSparseAttr(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
@ -169,6 +171,17 @@ AbstractBasePtr InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &, const
AbstractBasePtr InferImplRowTensorAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeCSRTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCSRTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCSRTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCSRTensorGetIndptr(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCSRTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -83,6 +83,7 @@ ABSTRACT_REPORT_NAME_TRAITS(KeywordArg)
ABSTRACT_REPORT_NAME_TRAITS(Class)
ABSTRACT_REPORT_NAME_TRAITS(RowTensor)
ABSTRACT_REPORT_NAME_TRAITS(SparseTensor)
ABSTRACT_REPORT_NAME_TRAITS(CSRTensor)
ABSTRACT_REPORT_NAME_TRAITS(Sequeue)
template <typename T>

View File

@ -395,6 +395,105 @@ AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, co
return sparse_tensor->dense_shape();
}
AbstractBasePtr InferImplMakeCSRTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: three tensors and a tuple.
constexpr auto kMakeCSRInputNum = 4;
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, kMakeCSRInputNum);
auto indptr = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);
auto shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 3);
auto indices_dtype = indices->element()->BuildType();
if (!indices_dtype->isa<Int>()) {
MS_EXCEPTION(TypeError) << "The dtype of indices must be a Int, but got " << indices_dtype->ToString();
}
auto indptr_shp = indptr->shape()->shape();
if (indptr_shp.size() != 1) {
MS_EXCEPTION(ValueError) << "Indptr must be a 1 dimension tensor, but got a " << indptr_shp.size()
<< " dimension tensor";
}
auto indices_shp = indices->shape()->shape();
if (indices_shp.size() != 1) {
MS_EXCEPTION(ValueError) << "Indices must be a 1 dimension tensor, but got a " << indices_shp.size()
<< " dimension tensor";
}
auto values_shp = values->shape()->shape();
if (values_shp.size() != 1) {
MS_EXCEPTION(ValueError) << "Values must be a 1 dimension tensor, but got a " << values_shp.size()
<< " dimension tensor";
}
if (indices_shp[0] != values_shp[0]) {
MS_EXCEPTION(ValueError) << "indices and values must have same size, but got: values length: " << values_shp[0]
<< ", indices length " << indices_shp[0];
}
for (const auto &elem_type : shape->ElementsType()) {
if (!elem_type->isa<Int>()) {
MS_EXCEPTION(TypeError) << "The element type of shape must be Int, but got " << elem_type->ToString();
}
}
auto shape_value = shape->BuildValue()->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(shape_value);
auto shp = shape_value->value();
ShapeVector shape_vec;
(void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(shape_vec), [](const ValuePtr &e) -> int64_t {
auto elem = GetValue<int64_t>(e);
return elem;
});
for (auto shape_elem : shape_vec) {
if (shape_elem < 0) {
MS_EXCEPTION(TypeError) << "The element of shape must be positive, but got " << shape_value->ToString();
}
}
if (shape_vec[0] + 1 != indptr_shp[0]) {
MS_EXCEPTION(ValueError) << "indptr must have length (1 + shape[0]), but got: " << indptr_shp[0];
}
auto ret = std::make_shared<AbstractCSRTensor>(values->element()->BuildType(), shape_vec);
ret->set_indptr(indptr);
ret->set_indices(indices);
ret->set_values(values);
ret->set_dense_shape(shape);
return ret;
}
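
The same consistency rules are visible from the Python constructor; a small sketch of inputs that satisfy them (indices and values share a length, indptr has length shape[0] + 1, and all shape elements are positive):

from mindspore import Tensor, CSRTensor
import mindspore as ms

indptr = Tensor([0, 1, 2])                 # length 3, so shape[0] must be 2
indices = Tensor([0, 1])                   # same length as values
values = Tensor([1, 2], dtype=ms.float32)
csr = CSRTensor(indptr, indices, values, (2, 4))      # passes the checks above
# CSRTensor(indptr, indices, values, (3, 4))          # rejected: 3 + 1 != len(indptr)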
template <typename T>
std::shared_ptr<T> InferSparseAttr(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
return CheckArg<T>(op_name, args_spec_list, 0);
}
AbstractBasePtr InferImplCSRTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
auto csr_tensor = InferSparseAttr<AbstractCSRTensor>(primitive, args_spec_list);
MS_EXCEPTION_IF_NULL(csr_tensor->values());
return csr_tensor->values();
}
AbstractBasePtr InferImplCSRTensorGetIndptr(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
auto csr_tensor = InferSparseAttr<AbstractCSRTensor>(primitive, args_spec_list);
MS_EXCEPTION_IF_NULL(csr_tensor->indptr());
return csr_tensor->indptr();
}
AbstractBasePtr InferImplCSRTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
auto csr_tensor = InferSparseAttr<AbstractCSRTensor>(primitive, args_spec_list);
MS_EXCEPTION_IF_NULL(csr_tensor->indices());
return csr_tensor->indices();
}
AbstractBasePtr InferImplCSRTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
auto csr_tensor = InferSparseAttr<AbstractCSRTensor>(primitive, args_spec_list);
MS_EXCEPTION_IF_NULL(csr_tensor->dense_shape());
return csr_tensor->dense_shape();
}
AbstractBasePtr InferImplAllSwap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();

View File

@ -215,11 +215,16 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimSparseTensorGetDenseShape, {InferImplSparseTensorGetDenseShape, nullptr, true}},
// RowTensor
{prim::kPrimMakeRowTensor, {InferImplMakeRowTensor, nullptr, true}},
{prim::kPrimRowTensorGetValues, {InferImplRowTensorGetValues, nullptr, true}},
{prim::kPrimRowTensorGetIndices, {InferImplRowTensorGetIndices, nullptr, true}},
{prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, nullptr, true}},
{prim::kPrimRowTensorAdd, {InferImplRowTensorAdd, nullptr, false}},
// CSRTensor
{prim::kPrimMakeCSRTensor, {InferImplMakeCSRTensor, nullptr, true}},
{prim::kPrimCSRTensorGetValues, {InferImplCSRTensorGetValues, nullptr, true}},
{prim::kPrimCSRTensorGetIndptr, {InferImplCSRTensorGetIndptr, nullptr, true}},
{prim::kPrimCSRTensorGetIndices, {InferImplCSRTensorGetIndices, nullptr, true}},
{prim::kPrimCSRTensorGetDenseShape, {InferImplCSRTensorGetDenseShape, nullptr, true}},
// Comm Ops
{prim::kPrimAllSwap, {InferImplAllSwap, nullptr, true}},
{prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, nullptr, true}},

View File

@ -474,6 +474,13 @@ inline const PrimitivePtr kPrimSparseTensorGetValues = std::make_shared<Primitiv
inline const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared<Primitive>("SparseTensorGetIndices");
inline const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape");
// CSRTensor
inline const PrimitivePtr kPrimMakeCSRTensor = std::make_shared<Primitive>("MakeCSRTensor");
inline const PrimitivePtr kPrimCSRTensorGetValues = std::make_shared<Primitive>("CSRTensorGetValues");
inline const PrimitivePtr kPrimCSRTensorGetIndptr = std::make_shared<Primitive>("CSRTensorGetIndptr");
inline const PrimitivePtr kPrimCSRTensorGetIndices = std::make_shared<Primitive>("CSRTensorGetIndices");
inline const PrimitivePtr kPrimCSRTensorGetDenseShape = std::make_shared<Primitive>("CSRTensorGetDenseShape");
// TensorList
inline const PrimitivePtr kPrimTensorListFromTensor = std::make_shared<Primitive>("TensorListFromTensor");
inline const PrimitivePtr kPrimTensorListReserve = std::make_shared<Primitive>("TensorListReserve");

View File

@ -191,4 +191,46 @@ bool SparseTensorType::operator==(const Type &other) const {
}
return *element_type_ == *other_elem_type;
}
TypePtr CSRTensorType::DeepCopy() const {
MS_EXCEPTION_IF_NULL(element_type_);
if (IsGeneric()) {
return std::make_shared<CSRTensorType>();
}
return std::make_shared<CSRTensorType>(element_type_->DeepCopy());
}
std::string CSRTensorType::ToReprString() const {
if (element_type_ == nullptr) {
return "CSRTensor";
}
return "CSRTensor[" + element_type_->ToReprString() + "]";
}
std::string CSRTensorType::ToString() const {
if (element_type_ == nullptr) {
return "CSRTensor";
}
return "CSRTensor[" + element_type_->ToString() + "]";
}
std::string CSRTensorType::DumpText() const {
if (element_type_ == nullptr) {
return "CSRTensor";
}
return "CSRTensor[" + element_type_->DumpText() + "]";
}
bool CSRTensorType::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
auto other_elem_type = static_cast<const CSRTensorType &>(other).element_type_;
if (element_type_ == nullptr && other_elem_type == nullptr) {
return true;
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
return false;
}
return *element_type_ == *other_elem_type;
}
} // namespace mindspore

View File

@ -190,6 +190,45 @@ class MS_CORE_API SparseTensorType final : public Object {
TypePtr element_type_;
};
using SparseTensorTypePtr = std::shared_ptr<SparseTensorType>;
/// \brief CSRTensorType defines the interface for the CSR sparse tensor data type.
class MS_CORE_API CSRTensorType : public Object {
public:
/// \brief Default constructor for CSRTensorType.
CSRTensorType() : Object(kObjectTypeCSRTensorType, kObjectTypeUndeterminedType) {}
/// \brief Constructor for CSRTensorType.
///
/// \param[in] ele The element of CSRTensorType.
explicit CSRTensorType(const TypePtr &ele)
: Object(kObjectTypeCSRTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
/// \brief Destructor of CSRTensorType.
~CSRTensorType() override = default;
MS_DECLARE_PARENT(CSRTensorType, Object)
TypeId generic_type_id() const override { return kObjectTypeCSRTensorType; }
/// \brief Get the element of CSRTensorType object.
///
/// \return The element of CSRTensorType object.
const TypePtr element() const { return element_type_; }
/// \brief Set the element of CSRTensorType object.
///
/// \param[in] element_type Define the element type to be set.
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string ToReprString() const override;
std::string DumpText() const override;
bool operator==(const Type &other) const override;
private:
TypePtr element_type_;
};
using CSRTensorTypePtr = std::shared_ptr<CSRTensorType>;
} // namespace mindspore
#endif // MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_

View File

@ -50,6 +50,7 @@ static std::unordered_map<TypeId, std::string> g_type_2_lable{
{kObjectTypeTensorType, MS_TYPE2LABLE(kObjectTypeTensorType)},
{kObjectTypeRowTensorType, MS_TYPE2LABLE(kObjectTypeRowTensorType)},
{kObjectTypeSparseTensorType, MS_TYPE2LABLE(kObjectTypeSparseTensorType)},
{kObjectTypeCSRTensorType, MS_TYPE2LABLE(kObjectTypeCSRTensorType)},
{kObjectTypeUndeterminedType, MS_TYPE2LABLE(kObjectTypeUndeterminedType)},
{kObjectTypeClass, MS_TYPE2LABLE(kObjectTypeClass)},
{kObjectTypeDictionary, MS_TYPE2LABLE(kObjectTypeDictionary)},

View File

@ -86,13 +86,19 @@ enum TypeId : int {
//
// Monad Types
//
// Monad types are placed at the end of the enum,
// in order to stay compatible with the types of existing models on the lite side.
kMonadTypeBegin = kNumberTypeEnd,
kObjectTypeMonad,
kObjectTypeUMonad,
kObjectTypeIOMonad,
kMonadTypeEnd
kMonadTypeEnd,
//
// Sparse Types
//
// Sparse types are placed at the end of the enum,
// in order to stay compatible with the types of existing models on the lite side.
kSparseTypeBegin = kMonadTypeEnd,
kObjectTypeCSRTensorType,
kSparseTypeEnd
};
} // namespace mindspore
#endif // MINDSPORE_CORE_IR_DTYPE_TYPE_ID_H_

View File

@ -191,6 +191,23 @@ TypePtr SparseTensorStrToType(const std::string &type_name) {
return std::make_shared<SparseTensorType>(element_type);
}
TypePtr CSRTensorStrToType(const std::string &type_name) {
if (type_name == "CSRTensor") {
return std::make_shared<CSRTensorType>();
}
auto start = type_name.find_first_of('[') + 1;
auto end = type_name.find_last_of(']');
if (start >= type_name.size()) {
return nullptr;
}
auto element_str = type_name.substr(start, end - start);
auto element_type = StringToType(element_str);
if (element_type == nullptr) {
return nullptr;
}
return std::make_shared<CSRTensorType>(element_type);
}
TypePtr UndeterminedStrToType(const std::string &type_name) {
if (type_name == "Undetermined") {
return std::make_shared<UndeterminedType>();
@ -330,6 +347,7 @@ TypePtr GetTypeByStringStarts(const std::string &type_name) {
{"Undetermined", [](const std::string &type_name) -> TypePtr { return UndeterminedStrToType(type_name); }},
{"RowTensor", [](const std::string &type_name) -> TypePtr { return RowTensorStrToType(type_name); }},
{"SparseTensor", [](const std::string &type_name) -> TypePtr { return SparseTensorStrToType(type_name); }},
{"CSRTensor", [](const std::string &type_name) -> TypePtr { return CSRTensorStrToType(type_name); }},
{"List", [](const std::string &type_name) -> TypePtr { return ListStrToType(type_name); }},
{"Tuple", [](const std::string &type_name) -> TypePtr { return TupleStrToType(type_name); }},
{"Function", [](const std::string &type_name) -> TypePtr { return FunctionStrToType(type_name); }}};

View File

@ -673,5 +673,31 @@ TypeId Tensor::set_data_type(const TypeId data_type) {
}
return data_type;
}
CSRTensor::CSRTensor(const TensorPtr indptr, const TensorPtr indices, const TensorPtr values, const ShapeVector &shape)
: indptr_(indptr), indices_(indices), values_(values), shape_(shape) {}
bool CSRTensor::operator==(const CSRTensor &csr_tensor) const { return (&csr_tensor == this); }
std::string CSRTensor::ToString() const {
std::ostringstream buf;
MS_EXCEPTION_IF_NULL(values_);
MS_EXCEPTION_IF_NULL(indices_);
MS_EXCEPTION_IF_NULL(indptr_);
auto dtype = values_->Dtype();
buf << "CSRTensor(shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString() << ", indptr=";
buf << indptr_->ToString() << ", indices=" << indices_->ToString() << ", values=";
buf << values_->ToString() << ")";
return buf.str();
}
abstract::AbstractBasePtr CSRTensor::ToAbstract() {
auto dtype = values_->Dtype();
if (!IsSubType(dtype, kNumber) && !IsSubType(dtype, kString) && !IsSubType(dtype, kTensorType)) {
MS_LOG(EXCEPTION) << "Expect tensor type kNumber or kString or kTensor but got: " << dtype->ToString() << ".";
}
auto abs_csr_tensor = std::make_shared<abstract::AbstractCSRTensor>(dtype, shape_);
return abs_csr_tensor;
}
} // namespace tensor
} // namespace mindspore

View File

@ -547,6 +547,63 @@ class MS_CORE_API Tensor final : public MetaTensor {
};
using TensorPtr = std::shared_ptr<Tensor>;
using TensorPtrList = std::vector<std::shared_ptr<Tensor>>;
// CSRTensor entity class
class MS_CORE_API CSRTensor : public MetaTensor {
public:
abstract::AbstractBasePtr ToAbstract() override;
/// \brief Create CSRTensor from the given indptr, indices, values and shape.
///
/// \param[in] indptr [Tensor] The indices pointer.
/// \param[in] indices [Tensor] The indices.
/// \param[in] values [Tensor] The values.
/// \param[in] shape The shape of the CSRTensor, represented by a ShapeVector.
CSRTensor(const TensorPtr indptr, const TensorPtr indices, const TensorPtr values, const ShapeVector &shape);
/// Destructor of CSRTensor.
~CSRTensor() override = default;
/// \brief Gets CSRTensor's indptr.
///
/// \return [TensorPtr] The indices pointer.
TensorPtr GetIndptr() { return indptr_; }
/// \brief Gets CSRTensor's indices.
///
/// \return [TensorPtr] The indices.
TensorPtr GetIndices() { return indices_; }
/// \brief Gets CSRTensor's values.
///
/// \return [TensorPtr] The values.
TensorPtr GetValues() { return values_; }
/// \brief Gets CSRTensor's shape.
///
/// \return [ShapeVector] The shape of the tensor.
const ShapeVector &shape() const { return shape_; }
/// \brief Compare two CSRTensor objects to see if they have same data type, shape and data address.
///
/// \param[in] csr_tensor The CSRTensor object to be compared.
/// \return True if having same type, shape and data address, otherwise false.
bool operator==(const CSRTensor &csr_tensor) const;
/// \brief Get display information of this Tensor.
///
/// \return The display information of this Tensor.
std::string ToString() const override;
TypePtr Dtype() const { return values_->Dtype(); }
private:
TensorPtr indptr_;
TensorPtr indices_;
TensorPtr values_;
ShapeVector shape_{};
};
using CSRTensorPtr = std::shared_ptr<CSRTensor>;
} // namespace tensor
} // namespace mindspore

View File

@ -394,6 +394,12 @@ sparse_tensor_get_values = Primitive('SparseTensorGetValues')
sparse_tensor_get_indices = Primitive('SparseTensorGetIndices')
sparse_tensor_get_dense_shape = Primitive('SparseTensorGetDenseShape')
make_csr_tensor = Primitive('MakeCSRTensor')
csr_tensor_get_values = Primitive('CSRTensorGetValues')
csr_tensor_get_indices = Primitive('CSRTensorGetIndices')
csr_tensor_get_indptr = Primitive('CSRTensorGetIndptr')
csr_tensor_get_shape = Primitive('CSRTensorGetDenseShape')
tensor_operator_registry.register('all', P.ReduceAll)
tensor_operator_registry.register('any', P.ReduceAny)
tensor_operator_registry.register('abs', P.Abs)

View File

@ -0,0 +1,84 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""smoke tests for CSR operations"""
import pytest
from mindspore import Tensor, CSRTensor, ms_function
from mindspore.common import dtype as mstype
def compare_csr(csr1, csr2):
assert isinstance(csr1, CSRTensor)
assert isinstance(csr2, CSRTensor)
assert (csr1.indptr.asnumpy() == csr2.indptr.asnumpy()).all()
assert (csr1.indices.asnumpy() == csr2.indices.asnumpy()).all()
assert (csr1.values.asnumpy() == csr2.values.asnumpy()).all()
assert csr1.shape == csr2.shape
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_make_csr():
"""
Feature: Test CSRTensor Constructor in Graph and PyNative.
Description: Test CSRTensor(indptr, indices, values, shape) and CSRTensor(CSRTensor)
Expectation: Success.
"""
indptr = Tensor([0, 1, 2])
indices = Tensor([0, 1])
values = Tensor([1, 2], dtype=mstype.float32)
shape = (2, 6)
def test_pynative():
return CSRTensor(indptr, indices, values, shape)
test_graph = ms_function(test_pynative)
csr1 = test_pynative()
csr2 = test_graph()
compare_csr(csr1, csr2)
csr3 = CSRTensor(csr_tensor=csr2)
compare_csr(csr3, csr2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_csr_attr():
"""
Feature: Test CSRTensor GetAttr in Graph and PyNative.
Description: Test CSRTensor.indptr, CSRTensor.indices, CSRTensor.values, CSRTensor.shape.
Expectation: Success.
"""
indptr = Tensor([0, 1, 2])
indices = Tensor([0, 1])
values = Tensor([1, 2], dtype=mstype.float32)
shape = (2, 6)
def test_pynative():
csr = CSRTensor(indptr, indices, values, shape)
return csr.indptr, csr.indices, csr.values, csr.shape
test_graph = ms_function(test_pynative)
csr1_tuple = test_pynative()
csr2_tuple = test_graph()
csr1 = CSRTensor(*csr1_tuple)
csr2 = CSRTensor(*csr2_tuple)
compare_csr(csr1, csr2)