forked from mindspore-Ecosystem/mindspore
support sparse tensor frontend
This commit is contained in:
parent
6bdd38399a
commit
72db8e4d3f
|
@ -18,7 +18,7 @@
|
|||
import ast
|
||||
import math
|
||||
|
||||
from mindspore import RowTensor, SparseTensor
|
||||
from mindspore import RowTensor, SparseTensor, CSRTensor
|
||||
from mindspore.ops import functional as F, composite as C
|
||||
from mindspore.ops.composite import multitype_ops
|
||||
from mindspore._c_expression import security
|
||||
|
@ -140,6 +140,7 @@ convert_object_map = {
|
|||
# user defined
|
||||
RowTensor: F.make_row_tensor,
|
||||
SparseTensor: F.make_sparse_tensor,
|
||||
CSRTensor: F.make_csr_tensor
|
||||
}
|
||||
|
||||
if not security.enable_security():
|
||||
|
|
|
@ -32,6 +32,7 @@
|
|||
#include "backend/optimizer/pass/reduce_sum_optimizer.h"
|
||||
#include "backend/optimizer/pass/add_dynamic_shape_attr.h"
|
||||
#include "backend/optimizer/pass/add_akg_kernel_attrs.h"
|
||||
#include "backend/optimizer/pass/sparse_process.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
|
||||
|
@ -51,6 +52,7 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern
|
|||
#endif
|
||||
auto optimizer = std::make_shared<GraphOptimizer>();
|
||||
auto common_pm = std::make_shared<PassManager>("common_pm");
|
||||
common_pm->AddPass(std::make_shared<SparseProcess>());
|
||||
common_pm->AddPass(std::make_shared<AddDynamicShapeAttr>());
|
||||
common_pm->AddPass(std::make_shared<ReduceSumOptimizer>());
|
||||
common_pm->AddPass(std::make_shared<ConvertConstInputToAttr>());
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "utils/convert_utils.h"
|
||||
#include "runtime/device/kernel_info.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "backend/optimizer/common/const_input_to_attr_registry.h"
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "backend/optimizer/pass/sparse_process.h"
|
||||
#include "ir/anf.h"
|
||||
#include "utils/convert_utils.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
const AnfNodePtr SparseProcess::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node == nullptr || !node->isa<CNode>() || !AnfUtils::IsRealKernel(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto prim = AnfAlgo::GetCNodePrimitive(cnode);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::string prim_name = prim->name();
|
||||
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
|
||||
// cnode is a MakeSparse node
|
||||
if (make_sparse_set.find(prim_name) != make_sparse_set.end()) {
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
// Inputs of node should be [make_sparse, indices, values, dense_shape], so offset by 1 to get items;
|
||||
(void)inputs.insert(inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
|
||||
auto new_node = cnode->func_graph()->NewCNode(inputs);
|
||||
auto abs_sparse = dyn_cast<abstract::AbstractCSRTensor>(node->abstract());
|
||||
std::vector<AbstractBasePtr> abstract_list{abs_sparse->indptr(), abs_sparse->indices(), abs_sparse->values(),
|
||||
abs_sparse->dense_shape()};
|
||||
auto abs_res = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
||||
new_node->set_abstract(abs_res);
|
||||
new_node->set_scope(cnode->scope());
|
||||
if (kernel_graph != nullptr) {
|
||||
kernel_graph->FrontBackendlMapUpdate(cnode, new_node);
|
||||
}
|
||||
return new_node;
|
||||
// cnode is a SparseGetAttr node
|
||||
} else if (sparse_attr_map.find(prim_name) != sparse_attr_map.end()) {
|
||||
const auto &inputs = cnode->inputs();
|
||||
// Inputs should be [sparse_getattr, sparse]
|
||||
constexpr size_t sparse_index = 1;
|
||||
AnfNodePtr sparse = inputs[sparse_index];
|
||||
MS_EXCEPTION_IF_NULL(sparse);
|
||||
int64_t index = sparse_attr_map.at(prim_name);
|
||||
auto cons_node = NewValueNode(index);
|
||||
AbstractBasePtr aptr = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(index));
|
||||
cons_node->set_abstract(aptr);
|
||||
auto new_node = NewCNode({NewValueNode(prim::kPrimTupleGetItem), sparse, cons_node}, func_graph);
|
||||
new_node->set_abstract(node->abstract());
|
||||
return new_node;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,38 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_SPARSE_PROCESS_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_SPARSE_PROCESS_H_
|
||||
#include <string>
|
||||
|
||||
#include "ir/anf.h"
|
||||
#include "utils/convert_utils.h"
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class SparseProcess : public PatternProcessPass {
|
||||
public:
|
||||
explicit SparseProcess(bool multigraph = true) : PatternProcessPass("sparse_process", multigraph) {}
|
||||
~SparseProcess() override = default;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
const size_t kAnfPrimitiveIndex = 0;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_SPARSE_PROCESS_H_
|
|
@ -334,9 +334,14 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputByCallNode(const K
|
|||
std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const AnfNodePtr &node) {
|
||||
std::vector<KernelWithIndex> ret;
|
||||
std::vector<KernelWithIndex> ret_empty;
|
||||
|
||||
// The makeTuple node need expand and recurse.
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
|
||||
const PrimitiveSet expand_prims{
|
||||
prim::kPrimMakeTuple,
|
||||
prim::kPrimMakeCSRTensor,
|
||||
prim::kPrimMakeSparseTensor,
|
||||
prim::kPrimMakeRowTensor,
|
||||
};
|
||||
// The MakeTuple/MakeSparse node need expand and recurse.
|
||||
if (IsOneOfPrimitiveCNode(node, expand_prims)) {
|
||||
auto make_tuple = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
for (size_t i = 1; i < make_tuple->inputs().size(); i++) {
|
||||
|
@ -385,8 +390,8 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const An
|
|||
AnfAlgo::VisitKernelWithReturnType(node, i, false, {prim::kPrimMakeTuple, prim::kPrimUpdateState});
|
||||
MS_EXCEPTION_IF_NULL(output_with_index.first);
|
||||
|
||||
// The makeTuple node need recurse.
|
||||
if (AnfAlgo::CheckPrimitiveType(output_with_index.first, prim::kPrimMakeTuple)) {
|
||||
// The MakeTuple/MakeSparse node need recurse.
|
||||
if (IsOneOfPrimitiveCNode(output_with_index.first, expand_prims)) {
|
||||
auto output_vector = GetAllOutputWithIndex(output_with_index.first);
|
||||
(void)std::copy(output_vector.begin(), output_vector.end(), std::back_inserter(ret));
|
||||
continue;
|
||||
|
|
|
@ -45,6 +45,7 @@ class ProtoExporter {
|
|||
std::string GetOpNodeInputId(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const std::map<AnfNodePtr, size_t> &apply_map,
|
||||
std::map<AnfNodePtr, size_t> *const_map_ptr);
|
||||
void SetValueToProtoBasicTypes(const ValuePtr &attr_value, irpb::ValueProto *value_proto);
|
||||
void SetValueToProto(const ValuePtr &attr_value, irpb::ValueProto *value_proto);
|
||||
void SetScalarToProto(const ScalarPtr &val, irpb::ValueProto *value_proto);
|
||||
void SetSequenceToProto(const ValueSequeuePtr &val, irpb::ValueProto *value_proto);
|
||||
|
@ -112,11 +113,12 @@ void CheckIfValidType(const TypePtr &type) {
|
|||
MS_EXCEPTION_IF_NULL(type);
|
||||
if (type->isa<Problem>()) {
|
||||
MS_LOG(WARNING) << "The type: " << type->type_name();
|
||||
return;
|
||||
}
|
||||
if (!(type->isa<Number>() || type->isa<TensorType>() || type->isa<Tuple>() || type->isa<TypeType>() ||
|
||||
type->isa<List>() || type->isa<TypeAnything>() || type->isa<RefKeyType>() || type->isa<RefType>() ||
|
||||
type->isa<Function>() || type->isa<TypeNone>() || type->isa<Problem>() || type->isa<String>() ||
|
||||
type->isa<RowTensorType>() || type->isa<UndeterminedType>() || type->isa<SparseTensorType>() ||
|
||||
type->isa<Function>() || type->isa<TypeNone>() || type->isa<String>() || type->isa<RowTensorType>() ||
|
||||
type->isa<CSRTensorType>() || type->isa<UndeterminedType>() || type->isa<SparseTensorType>() ||
|
||||
type->isa<SymbolicKeyType>() || type->isa<MonadType>())) {
|
||||
MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name();
|
||||
}
|
||||
|
@ -126,12 +128,12 @@ void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &s
|
|||
if (type_proto == nullptr) {
|
||||
return;
|
||||
}
|
||||
if (type != nullptr) {
|
||||
CheckIfValidType(type);
|
||||
}
|
||||
if (type == nullptr) {
|
||||
type_proto->set_data_type(irpb::DT_UNDEFINED);
|
||||
} else if (type->isa<Number>()) {
|
||||
return;
|
||||
}
|
||||
CheckIfValidType(type);
|
||||
if (type->isa<Number>()) {
|
||||
type_proto->set_data_type(GetNumberDataType(type));
|
||||
} else if (type->isa<TensorType>()) {
|
||||
TypePtr elem_type = dyn_cast<TensorType>(type)->element();
|
||||
|
@ -179,11 +181,7 @@ void ProtoExporter::SetNodeOutputType(const AnfNodePtr &node, irpb::TypeProto *t
|
|||
SetNodeOutputType(node->Type(), node->Shape(), type_proto);
|
||||
}
|
||||
|
||||
void ProtoExporter::SetValueToProto(const ValuePtr &val, irpb::ValueProto *value_proto) {
|
||||
if (val == nullptr || value_proto == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
void ProtoExporter::SetValueToProtoBasicTypes(const ValuePtr &val, irpb::ValueProto *value_proto) {
|
||||
if (val->isa<StringImm>()) {
|
||||
const StringImmPtr &value = dyn_cast<StringImm>(val);
|
||||
value_proto->set_dtype(irpb::DT_STRING);
|
||||
|
@ -202,7 +200,17 @@ void ProtoExporter::SetValueToProto(const ValuePtr &val, irpb::ValueProto *value
|
|||
} else if (val->isa<Float>()) {
|
||||
value_proto->set_dtype(irpb::DT_TYPE);
|
||||
value_proto->mutable_type_val()->set_data_type(irpb::DT_BASE_FLOAT);
|
||||
} else if (val->isa<ValueSequeue>()) {
|
||||
}
|
||||
}
|
||||
|
||||
void ProtoExporter::SetValueToProto(const ValuePtr &val, irpb::ValueProto *value_proto) {
|
||||
if (val == nullptr || value_proto == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
SetValueToProtoBasicTypes(val, value_proto);
|
||||
|
||||
if (val->isa<ValueSequeue>()) {
|
||||
SetSequenceToProto(dyn_cast<ValueSequeue>(val), value_proto);
|
||||
} else if (val->isa<None>()) {
|
||||
value_proto->set_dtype(irpb::DT_NONE);
|
||||
|
|
|
@ -1138,9 +1138,11 @@ py::object GraphExecutorPy::Run(const py::tuple &args, const py::object &phase_o
|
|||
MS_LOG(INFO) << "VM loop size " << vm_loop << ", loopsink size " << vm_loop;
|
||||
py::object ret;
|
||||
MS_LOG(DEBUG) << "Eval run" << backend;
|
||||
auto output = execute_info->func_graph->output()->abstract();
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
for (int64_t i = 0; i < vm_loop; i++) {
|
||||
BaseRef value = (*run)(execute_info->arg_list);
|
||||
ret = BaseRefToPyData(value);
|
||||
ret = BaseRefToPyData(value, output);
|
||||
}
|
||||
MS_LOG(DEBUG) << "Run end";
|
||||
return ret;
|
||||
|
|
|
@ -241,6 +241,13 @@ BuiltInTypeMap &GetAttrMap() {
|
|||
{"indices", prim::kPrimSparseTensorGetIndices}, // F.sparse_tensor_get_indices
|
||||
{"dense_shape", prim::kPrimSparseTensorGetDenseShape}, // F.sparse_tensor_get_dense_shape
|
||||
}},
|
||||
{kObjectTypeCSRTensorType,
|
||||
{
|
||||
{"indptr", prim::kPrimCSRTensorGetIndptr}, // F.csr_tensor_get_indptr
|
||||
{"values", prim::kPrimCSRTensorGetValues}, // F.csr_tensor_get_values
|
||||
{"indices", prim::kPrimCSRTensorGetIndices}, // F.csr_tensor_get_indices
|
||||
{"shape", prim::kPrimCSRTensorGetDenseShape}, // F.csr_tensor_get_shape
|
||||
}},
|
||||
};
|
||||
return attr_map;
|
||||
}
|
||||
|
|
|
@ -462,6 +462,11 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
|
|||
dic[ATTR_SHAPE] = arg->shape()->shape();
|
||||
dic[ATTR_DTYPE] = arg->BuildType();
|
||||
dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
|
||||
} else if (abs_base->isa<AbstractCSRTensor>()) {
|
||||
auto arg = dyn_cast<AbstractCSRTensor>(abs_base);
|
||||
dic[ATTR_SHAPE] = arg->shape()->shape();
|
||||
dic[ATTR_DTYPE] = arg->BuildType();
|
||||
dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
|
||||
} else if (abs_base->isa<AbstractScalar>() || abs_base->isa<AbstractType>() || abs_base->isa<AbstractRefKey>()) {
|
||||
ShapeVector shape;
|
||||
dic[ATTR_SHAPE] = shape;
|
||||
|
|
|
@ -30,6 +30,7 @@ namespace mindspore {
|
|||
namespace validator {
|
||||
using mindspore::abstract::AbstractBase;
|
||||
using mindspore::abstract::AbstractClass;
|
||||
using mindspore::abstract::AbstractCSRTensor;
|
||||
using mindspore::abstract::AbstractError;
|
||||
using mindspore::abstract::AbstractFunction;
|
||||
using mindspore::abstract::AbstractJTagged;
|
||||
|
@ -114,11 +115,12 @@ void ValidateAbstract(const AnfNodePtr &node) {
|
|||
MS_LOG(DEBUG) << "AbstractError in the graph: " << abstract->ToString();
|
||||
return;
|
||||
}
|
||||
bool is_legal_abstract =
|
||||
abstract->isa<AbstractType>() || abstract->isa<AbstractFunction>() || abstract->isa<AbstractTuple>() ||
|
||||
abstract->isa<AbstractList>() || abstract->isa<AbstractTensor>() || abstract->isa<AbstractRowTensor>() ||
|
||||
abstract->isa<AbstractSparseTensor>() || abstract->isa<abstract::AbstractRefKey>() ||
|
||||
abstract->isa<AbstractRef>() || abstract->isa<abstract::AbstractNone>() || abstract->isa<abstract::AbstractMonad>();
|
||||
bool is_legal_abstract = abstract->isa<AbstractType>() || abstract->isa<AbstractFunction>() ||
|
||||
abstract->isa<AbstractTuple>() || abstract->isa<AbstractList>() ||
|
||||
abstract->isa<AbstractTensor>() || abstract->isa<AbstractRowTensor>() ||
|
||||
abstract->isa<AbstractSparseTensor>() || abstract->isa<AbstractCSRTensor>() ||
|
||||
abstract->isa<abstract::AbstractRefKey>() || abstract->isa<AbstractRef>() ||
|
||||
abstract->isa<abstract::AbstractNone>() || abstract->isa<abstract::AbstractMonad>();
|
||||
if (is_legal_abstract) {
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -154,6 +154,7 @@ REGISTER_PYBIND_DEFINE(
|
|||
(void)py::class_<RowTensorType, Type, std::shared_ptr<RowTensorType>>(m_sub, "RowTensorType").def(py::init());
|
||||
(void)py::class_<SparseTensorType, Type, std::shared_ptr<SparseTensorType>>(m_sub, "SparseTensorType")
|
||||
.def(py::init());
|
||||
(void)py::class_<CSRTensorType, Type, std::shared_ptr<CSRTensorType>>(m_sub, "CSRTensorType").def(py::init());
|
||||
(void)py::class_<UndeterminedType, Type, std::shared_ptr<UndeterminedType>>(m_sub, "UndeterminedType")
|
||||
.def(py::init());
|
||||
(void)py::class_<Function, Type, std::shared_ptr<Function>>(m_sub, "Function")
|
||||
|
|
|
@ -640,5 +640,34 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
|
|||
return TensorPy::MakeTensor(t[0].cast<py::array>());
|
||||
}));
|
||||
}));
|
||||
|
||||
py::tuple CSRTensorPy::GetPyTupleShape(const CSRTensor &csr_tensor) {
|
||||
auto &shape = csr_tensor.shape();
|
||||
py::tuple dims(shape.size());
|
||||
for (size_t i = 0; i < dims.size(); ++i) {
|
||||
dims[i] = py::int_(shape[i]);
|
||||
}
|
||||
return dims;
|
||||
}
|
||||
|
||||
REGISTER_PYBIND_DEFINE(
|
||||
CSRTensor, ([](const py::module *m) {
|
||||
// Define python CSRTensor class.
|
||||
(void)py::class_<CSRTensor, std::shared_ptr<CSRTensor>>(*m, "CSRTensor")
|
||||
.def(py::init([](const Tensor &indptr, const Tensor &indices, const Tensor &values, const py::tuple &shape) {
|
||||
return std::make_shared<CSRTensor>(std::make_shared<Tensor>(indptr), std::make_shared<Tensor>(indices),
|
||||
std::make_shared<Tensor>(values), GetShapeFromTuple(shape));
|
||||
}),
|
||||
py::arg("indptr"), py::arg("indices"), py::arg("values"), py::arg("shape"))
|
||||
.def(py::init([](const CSRTensor &csr_tensor) { return std::make_shared<CSRTensor>(csr_tensor); }),
|
||||
py::arg("input"))
|
||||
.def_property_readonly("_shape", CSRTensorPy::GetPyTupleShape)
|
||||
.def_property_readonly("_dtype", &CSRTensor::Dtype)
|
||||
.def_property_readonly("_indptr", &CSRTensor::GetIndptr)
|
||||
.def_property_readonly("_indices", &CSRTensor::GetIndices)
|
||||
.def_property_readonly("_values", &CSRTensor::GetValues)
|
||||
.def("__str__", &CSRTensor::ToString)
|
||||
.def("__repr__", &CSRTensor::ToString);
|
||||
}));
|
||||
} // namespace tensor
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -118,6 +118,12 @@ class TensorPy {
|
|||
|
||||
static void FlushFromCache(const Tensor &tensor);
|
||||
};
|
||||
|
||||
// CSRTensor python wrapper and adapter class.
|
||||
class CSRTensorPy {
|
||||
public:
|
||||
static py::tuple GetPyTupleShape(const CSRTensor &csr_tensor);
|
||||
};
|
||||
} // namespace tensor
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <memory>
|
||||
#include <utility>
|
||||
#include <stack>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
@ -28,6 +29,7 @@
|
|||
#include "utils/convert_utils_base.h"
|
||||
#include "utils/any.h"
|
||||
#include "base/base_ref.h"
|
||||
#include "base/core_ops.h"
|
||||
#include "base/base.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
|
@ -77,6 +79,19 @@ std::vector<T> TensorValueToVector(const tensor::TensorPtr &tensor) {
|
|||
void TensorValueToTensor(const ValuePtr &value, std::vector<tensor::TensorPtr> *tensors);
|
||||
|
||||
size_t CountValueNum(const ValueTuplePtr &value_tuple);
|
||||
|
||||
// sparse_attr_map converts CNode{kPrimSparseGetAttr, SparseTensor}
|
||||
// to CNode{kPrimTupleGetItem, SparseTensor, int64_t(index)}, used
|
||||
// in backend common optimization pass: sparse_process.cc
|
||||
const std::unordered_map<std::string, int64_t> sparse_attr_map = {{prim::kPrimCSRTensorGetIndptr->name(), 0},
|
||||
{prim::kPrimCSRTensorGetIndices->name(), 1},
|
||||
{prim::kPrimCSRTensorGetValues->name(), 2},
|
||||
{prim::kPrimCSRTensorGetDenseShape->name(), 3}};
|
||||
// make_sparse_set records all make_sparse primitives, and tries to replace
|
||||
// make_sparse to make_tuple, used in backend common optimization pass:
|
||||
// sparse_process.cc
|
||||
const std::unordered_set<std::string> make_sparse_set = {
|
||||
{prim::kPrimMakeCSRTensor->name()}, {prim::kPrimMakeSparseTensor->name()}, {prim::kPrimMakeRowTensor->name()}};
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_UTILS_CONVERT_UTILS_H_
|
||||
|
|
|
@ -30,17 +30,23 @@
|
|||
#include "pipeline/jit/parse/parse_base.h"
|
||||
#include "pipeline/jit/parse/resolve.h"
|
||||
#include "ir/value.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/tensor.h"
|
||||
#include "ir/param_info.h"
|
||||
#include "pybind_api/ir/base_ref_py.h"
|
||||
#include "ir/dtype/tensor_type.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "utils/convert_utils.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
|
||||
namespace mindspore {
|
||||
py::object BuiltinsToPyData(const Any &value);
|
||||
py::object BuiltinsToPyData(const BaseRef &value);
|
||||
py::object VectorToPyData(const Any &value);
|
||||
py::object VectorRefToPyData(const VectorRef &value);
|
||||
|
||||
py::object VectorRefToPyData(const VectorRef &value_list);
|
||||
py::object VectorRefToPyData(const VectorRef &value_list, const AbstractBasePtr &output);
|
||||
// Wrap VectorRef to CSRTensor
|
||||
py::object MakeCSRTensor(const VectorRef &value_list);
|
||||
py::object TensorToPyData(const tensor::TensorPtr &tensor) {
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
if (tensor->NeedWait()) {
|
||||
|
@ -276,6 +282,19 @@ py::object AnyToPyData(const Any &value) {
|
|||
return ret;
|
||||
}
|
||||
|
||||
py::object BaseRefToPyData(const BaseRef &value, const AbstractBasePtr &output) {
|
||||
py::object ret;
|
||||
// If output value is a tuple, check if abstract is a SparseTensor in funcgraph output
|
||||
if (utils::isa<VectorRef>(value)) {
|
||||
MS_LOG(DEBUG) << "BaseRefToPyData, value is tuple: " << value.ToString();
|
||||
auto vec_ref = utils::cast<VectorRef>(value);
|
||||
ret = VectorRefToPyData(vec_ref, output);
|
||||
} else {
|
||||
ret = BaseRefToPyData(value);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
py::object BaseRefToPyData(const BaseRef &value) {
|
||||
py::object ret;
|
||||
MS_LOG(DEBUG) << "BaseRefToPyData " << value.ToString();
|
||||
|
@ -383,6 +402,29 @@ py::object VectorRefToPyData(const VectorRef &value_list) {
|
|||
return ret;
|
||||
}
|
||||
|
||||
py::object VectorRefToPyData(const VectorRef &value_list, const AbstractBasePtr &output) {
|
||||
MS_LOG(DEBUG) << "vector_ref";
|
||||
// Current VectorRef reflects a SparseTensor type
|
||||
if (output->isa<abstract::AbstractCSRTensor>()) {
|
||||
return MakeCSRTensor(value_list);
|
||||
}
|
||||
py::object ret;
|
||||
size_t value_size = value_list.size();
|
||||
auto ref_tuple = py::tuple(value_size);
|
||||
abstract::AbstractTuplePtr tuple_output = output->cast<abstract::AbstractTuplePtr>();
|
||||
bool is_abstract_tuple = tuple_output != nullptr;
|
||||
for (size_t i = 0; i < value_size; i++) {
|
||||
if (!is_abstract_tuple || i >= tuple_output->size()) {
|
||||
// Fall back to original process
|
||||
ref_tuple[i] = BaseRefToPyData(value_list[i]);
|
||||
} else {
|
||||
ref_tuple[i] = BaseRefToPyData(value_list[i], (*tuple_output)[i]);
|
||||
}
|
||||
}
|
||||
ret = ref_tuple;
|
||||
return ret;
|
||||
}
|
||||
|
||||
void SetValueRange(const AbstractBasePtr &tensor, const py::object &output) {
|
||||
if (output.is_none()) {
|
||||
return;
|
||||
|
@ -551,4 +593,45 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple
|
|||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
py::object MakeCSRTensor(const VectorRef &value_list) {
|
||||
constexpr size_t kCSRTensorInputSize{4};
|
||||
if (value_list.size() != kCSRTensorInputSize) {
|
||||
MS_LOG(EXCEPTION) << "CSRTensor must have 4 inputs.";
|
||||
}
|
||||
using TensorPtr = tensor::TensorPtr;
|
||||
using CSRTensor = tensor::CSRTensor;
|
||||
TensorPtr indptr = utils::cast<TensorPtr>(value_list[0]);
|
||||
TensorPtr indices = utils::cast<TensorPtr>(value_list[1]);
|
||||
TensorPtr values = utils::cast<TensorPtr>(value_list[2]);
|
||||
ValuePtr shape_ptr = utils::cast<ValuePtr>(value_list[3]);
|
||||
ValueTuplePtr shape_tuple = shape_ptr->cast<ValueTuplePtr>();
|
||||
ShapeVector shape{};
|
||||
// CSRTensor shape is a tuple on GPU and CPU
|
||||
if (shape_tuple) {
|
||||
for (const auto &v : shape_tuple->value()) {
|
||||
MS_EXCEPTION_IF_NULL(v);
|
||||
ScalarPtr scalar = v->cast<ScalarPtr>();
|
||||
MS_EXCEPTION_IF_NULL(scalar);
|
||||
shape.push_back(GetValue<int64_t>(scalar));
|
||||
}
|
||||
// CSRTensor shape is a VectorRef(TensorPtr, TensorPtr) on Ascend
|
||||
} else {
|
||||
auto shape_ref = utils::cast<VectorRef>(value_list[3]);
|
||||
MS_EXCEPTION_IF_NULL(shape_ref);
|
||||
for (const auto &v : shape_ref) {
|
||||
MS_EXCEPTION_IF_NULL(v);
|
||||
auto tensorptr = utils::cast<TensorPtr>(v);
|
||||
MS_EXCEPTION_IF_NULL(tensorptr);
|
||||
if (tensorptr->DataDim() != 0) {
|
||||
MS_LOG(EXCEPTION) << "Element in CSRTensor's shape must be scalar!";
|
||||
}
|
||||
shape.push_back(*(static_cast<int64_t *>(tensorptr->data_c())));
|
||||
}
|
||||
}
|
||||
auto ref = py::tuple(1);
|
||||
auto csr_tensor_ptr = std::make_shared<CSRTensor>(indptr, indices, values, shape);
|
||||
ref[0] = csr_tensor_ptr;
|
||||
return ref[0];
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -31,6 +31,7 @@ namespace py = pybind11;
|
|||
namespace mindspore {
|
||||
py::object AnyToPyData(const Any &value);
|
||||
py::object BaseRefToPyData(const BaseRef &value);
|
||||
py::object BaseRefToPyData(const BaseRef &value, const AbstractBasePtr &output);
|
||||
py::object ValueToPyData(const ValuePtr &value);
|
||||
|
||||
bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
|
||||
|
|
|
@ -922,8 +922,14 @@ void MindRTBackend::ConstructOutputs(const AnfNodePtr &output_node,
|
|||
MS_EXCEPTION_IF_NULL(output_node);
|
||||
MS_EXCEPTION_IF_NULL(outputs);
|
||||
MS_EXCEPTION_IF_NULL(output_position);
|
||||
// The makeTuple node need expand and recurse.
|
||||
if (AnfAlgo::CheckPrimitiveType(output_node, prim::kPrimMakeTuple)) {
|
||||
const PrimitiveSet expand_prims{
|
||||
prim::kPrimMakeTuple,
|
||||
prim::kPrimMakeCSRTensor,
|
||||
prim::kPrimMakeSparseTensor,
|
||||
prim::kPrimMakeRowTensor,
|
||||
};
|
||||
// The MakeTuple/MakeSaprse node need expand and recurse.
|
||||
if (IsOneOfPrimitiveCNode(output_node, expand_prims)) {
|
||||
auto make_tuple = output_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
VectorRef make_tuple_output;
|
||||
|
|
|
@ -24,7 +24,7 @@ from .dtype import Type, int8, byte, int16, short, int32, intc, int64, intp, \
|
|||
from .dump import set_dump
|
||||
from .parameter import Parameter, ParameterTuple
|
||||
from .seed import set_seed, get_seed
|
||||
from .tensor import Tensor, RowTensor, SparseTensor
|
||||
from .tensor import Tensor, RowTensor, SparseTensor, CSRTensor
|
||||
|
||||
# symbols from dtype
|
||||
__all__ = [
|
||||
|
@ -53,7 +53,7 @@ __all__ = [
|
|||
]
|
||||
|
||||
__all__.extend([
|
||||
"Tensor", "RowTensor", "SparseTensor", # tensor
|
||||
"Tensor", "RowTensor", "SparseTensor", "CSRTensor", # tensor
|
||||
'ms_function', # api
|
||||
'Parameter', 'ParameterTuple', # parameter
|
||||
"dtype",
|
||||
|
|
|
@ -26,7 +26,8 @@ from mindspore import context
|
|||
from mindspore import log as logger
|
||||
from mindspore._extends.remote import kernel_build_server
|
||||
from .tensor import Tensor as MsTensor
|
||||
from .._c_expression import generate_arguments_key, GraphExecutor_, Tensor, MetaTensor, PynativeExecutor_
|
||||
from .tensor import CSRTensor as MsCSRTensor
|
||||
from .._c_expression import generate_arguments_key, GraphExecutor_, Tensor, MetaTensor, CSRTensor, PynativeExecutor_
|
||||
from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline
|
||||
from ..parallel._ps_context import _is_role_pserver
|
||||
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, _to_full_tensor, \
|
||||
|
@ -80,6 +81,8 @@ def _wrap_func(fn):
|
|||
def _convert_data(data):
|
||||
if isinstance(data, Tensor) and not isinstance(data, MsTensor):
|
||||
return MsTensor(data)
|
||||
if isinstance(data, CSRTensor) and not isinstance(data, MsCSRTensor):
|
||||
return MsCSRTensor(csr_tensor=data)
|
||||
if isinstance(data, tuple):
|
||||
return tuple(_convert_data(x) for x in data)
|
||||
if isinstance(data, list):
|
||||
|
|
|
@ -93,6 +93,7 @@ type_none = typing.TypeNone()
|
|||
tensor = typing.TensorType()
|
||||
index_slices = typing.RowTensorType()
|
||||
sparse_tensor = typing.SparseTensorType()
|
||||
csr_tensor = typing.CSRTensorType()
|
||||
undetermined = typing.UndeterminedType()
|
||||
|
||||
function = typing.Function()
|
||||
|
|
|
@ -21,10 +21,11 @@ from mindspore.communication.management import get_rank, get_group_size
|
|||
from . import dtype as mstype
|
||||
from ._register_for_tensor import tensor_operator_registry
|
||||
from .._c_expression import Tensor as Tensor_
|
||||
from .._c_expression import CSRTensor as CSRTensor_
|
||||
from .._c_expression import PynativeExecutor_
|
||||
from .._checkparam import Validator as validator
|
||||
|
||||
__all__ = ['Tensor', 'RowTensor', 'SparseTensor']
|
||||
__all__ = ['Tensor', 'RowTensor', 'SparseTensor', 'CSRTensor']
|
||||
np_types = (np.int8, np.int16, np.int32, np.int64,
|
||||
np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
|
||||
np.float32, np.float64, np.bool_, np.complex64, np.complex128)
|
||||
|
@ -2240,6 +2241,97 @@ class SparseTensor:
|
|||
def dense_shape(self):
|
||||
return self.__dense_shape
|
||||
|
||||
class CSRTensor(CSRTensor_):
|
||||
"""
|
||||
Constructs a sparse tensor in CSR (Compressed Sparse Row) format, with specified
|
||||
values indicated by `values` and row and column positions indicated by `indptr`
|
||||
and `indices`.
|
||||
|
||||
Alternatively, CSRTensor can be initialized by passing another CSRTensor as input.
|
||||
Currently this constructor can only be supported in PyNative Mode.
|
||||
|
||||
Note:
|
||||
This is an experimental feature and is subjected to change.
|
||||
|
||||
Args:
|
||||
indptr (Tensor): 1-D Tensor of size `shape[0] + 1`, which indicates the
|
||||
start and end point for `values` in each row. Default: None. If provided,
|
||||
must be :class:`mindspore.int16`, :class:`mindspore.int32` or :class:`mindspore.int64`.
|
||||
indices (Tensor): 1-D Tensor, which has the same length as `values`. `indices`
|
||||
indicates the which column `values` should be placed. Default: None. If provided,
|
||||
must be :class:`mindspore.int16`, :class:`mindspore.int32` or :class:`mindspore.int64`.
|
||||
values (Tensor): 1-D Tensor, which has the same length as `indices`. `values`
|
||||
stores the data for CSRTensor. Default: None.
|
||||
shape (Tuple): A tuple indicates the shape of the CSRTensor, its length must
|
||||
be `2`, as only 2-D CSRTensor is currently supported, and `shape[0]` must
|
||||
equal to `indptr[0] - 1`, which all equal to number of rows of the CSRTensor.
|
||||
csr_tensor (CSRTensor): A CSRTensor object.
|
||||
|
||||
Outputs:
|
||||
CSRTensor, with shape defined by `shape`, and dtype inferred from `value`.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore as ms
|
||||
>>> from mindspore import Tensor, CSRTensor
|
||||
>>> # initialize a csr_tensor with indptr, indices, values and shape
|
||||
>>> indptr = Tensor([0, 1, 2])
|
||||
>>> indices = Tensor([0, 1])
|
||||
>>> values = Tensor([1, 2], dtype=ms.float32)
|
||||
>>> shape = (3, 4)
|
||||
>>> csr_tensor = CSRTensor(indptr, indices, values, shape)
|
||||
>>> # initialize a csr_tensor from another csr_tensor
|
||||
>>> csr_tensor_2 = CSRTensor(csr_tensor=csr_tensor)
|
||||
>>> # access a data member of CSRTensor
|
||||
>>> print(indptr == csr_tensor.indptr)
|
||||
>>> [ True True True]
|
||||
"""
|
||||
def __init__(self, indptr=None, indices=None, values=None, shape=None, csr_tensor=None):
|
||||
self.init_finished = False
|
||||
# Case 1: directly init a CSRTensor from another CSRTensor
|
||||
if indptr is None and indices is None and values is None and shape is None:
|
||||
if not isinstance(csr_tensor, (CSRTensor, CSRTensor_)):
|
||||
raise TypeError("If only one input provided, it must be a CSRTensor.")
|
||||
CSRTensor_.__init__(self, csr_tensor)
|
||||
# Case 2: init a CSRTensor from indptr, indices, values and shape
|
||||
else:
|
||||
if (indptr is None or indices is None or values is None or shape is None):
|
||||
raise TypeError("Inputs must follow: CSRTensor(indptr, indices, values, shape).")
|
||||
if not (isinstance(indptr, Tensor) and isinstance(indices, Tensor) \
|
||||
and isinstance(values, Tensor) and isinstance(shape, tuple)):
|
||||
raise TypeError("Inputs must follow: CSRTensor(tensor, tensor, tensor, tuple).")
|
||||
if len(shape) != 2 or shape[0] + 1 != indptr.shape[0] or shape[1] <= 0:
|
||||
raise ValueError("Shape length should be 2, shape[0] should equal to indptr.shape[0] - 1")
|
||||
if indptr.dtype not in (mstype.int16, mstype.int32, mstype.int64):
|
||||
raise TypeError("indptr must have integer data type.")
|
||||
if indices.dtype not in (mstype.int16, mstype.int32, mstype.int64):
|
||||
raise TypeError("indices must have integer data type.")
|
||||
CSRTensor_.__init__(self, indptr, indices, values, shape)
|
||||
self.init_finished = True
|
||||
|
||||
def __repr__(self):
|
||||
"""Avoid PyTest Segfault when CSRTensor is not initialized."""
|
||||
if self.init_finished:
|
||||
return CSRTensor_.__repr__(self)
|
||||
return ''
|
||||
|
||||
@property
|
||||
def indptr(self):
|
||||
return Tensor(self._indptr)
|
||||
|
||||
@property
|
||||
def indices(self):
|
||||
return Tensor(self._indices)
|
||||
|
||||
@property
|
||||
def values(self):
|
||||
return Tensor(self._values)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self._shape
|
||||
|
||||
def to_tuple(self):
|
||||
return self.indptr, self.indices, self.values, self.shape
|
||||
|
||||
def _vm_compare(*args):
|
||||
"""Implement `vm_compare` for tensor."""
|
||||
|
|
|
@ -1443,6 +1443,113 @@ std::string AbstractSparseTensor::ToString() const {
|
|||
return buffer.str();
|
||||
}
|
||||
|
||||
// CSRTensor
|
||||
TypePtr AbstractCSRTensor::BuildType() const {
|
||||
MS_EXCEPTION_IF_NULL(element());
|
||||
TypePtr element_type = element()->BuildType();
|
||||
return std::make_shared<CSRTensorType>(element_type);
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractCSRTensor::Clone() const {
|
||||
MS_EXCEPTION_IF_NULL(element());
|
||||
auto clone = std::make_shared<AbstractCSRTensor>(element()->Clone());
|
||||
ShapePtr shp = shape();
|
||||
MS_EXCEPTION_IF_NULL(shp);
|
||||
MS_EXCEPTION_IF_NULL(indptr_);
|
||||
MS_EXCEPTION_IF_NULL(indices_);
|
||||
MS_EXCEPTION_IF_NULL(values_);
|
||||
MS_EXCEPTION_IF_NULL(dense_shape_);
|
||||
auto indptr_clone = indptr_->Clone();
|
||||
auto indices_clone = indices_->Clone();
|
||||
auto value_clone = values_->Clone();
|
||||
auto dense_clone = dense_shape_->Clone();
|
||||
MS_EXCEPTION_IF_NULL(indptr_clone);
|
||||
MS_EXCEPTION_IF_NULL(indices_clone);
|
||||
MS_EXCEPTION_IF_NULL(value_clone);
|
||||
MS_EXCEPTION_IF_NULL(dense_clone);
|
||||
clone->set_shape(shp->Clone());
|
||||
clone->set_value(GetValueTrack());
|
||||
clone->set_indptr(indptr_clone->cast<AbstractTensorPtr>());
|
||||
clone->set_indices(indices_clone->cast<AbstractTensorPtr>());
|
||||
clone->set_values(value_clone->cast<AbstractTensorPtr>());
|
||||
clone->set_dense_shape(dense_clone->cast<AbstractTuplePtr>());
|
||||
return clone;
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractCSRTensor::Broaden() const {
|
||||
MS_EXCEPTION_IF_NULL(element());
|
||||
auto broaden = std::make_shared<AbstractCSRTensor>(element()->Broaden());
|
||||
auto shp = shape();
|
||||
MS_EXCEPTION_IF_NULL(shp);
|
||||
MS_EXCEPTION_IF_NULL(indptr_);
|
||||
MS_EXCEPTION_IF_NULL(indices_);
|
||||
MS_EXCEPTION_IF_NULL(values_);
|
||||
MS_EXCEPTION_IF_NULL(dense_shape_);
|
||||
auto indptr_clone = indptr_->Clone();
|
||||
auto indices_clone = indices_->Clone();
|
||||
auto value_clone = values_->Clone();
|
||||
auto dense_clone = dense_shape_->Clone();
|
||||
MS_EXCEPTION_IF_NULL(indptr_clone);
|
||||
MS_EXCEPTION_IF_NULL(indices_clone);
|
||||
MS_EXCEPTION_IF_NULL(value_clone);
|
||||
MS_EXCEPTION_IF_NULL(dense_clone);
|
||||
broaden->set_shape(shp->Clone());
|
||||
broaden->set_value(kAnyValue);
|
||||
broaden->set_indptr(indptr_clone->cast<AbstractTensorPtr>());
|
||||
broaden->set_indices(indices_clone->cast<AbstractTensorPtr>());
|
||||
broaden->set_values(value_clone->cast<AbstractTensorPtr>());
|
||||
broaden->set_dense_shape(dense_clone->cast<AbstractTuplePtr>());
|
||||
return broaden;
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractCSRTensor::BroadenWithShape() const {
|
||||
MS_EXCEPTION_IF_NULL(element());
|
||||
auto broaden = std::make_shared<AbstractCSRTensor>(element()->Broaden());
|
||||
auto this_shape = shape();
|
||||
MS_EXCEPTION_IF_NULL(this_shape);
|
||||
auto shp = this_shape->Clone();
|
||||
MS_EXCEPTION_IF_NULL(shp);
|
||||
shp->Broaden();
|
||||
broaden->set_shape(shp);
|
||||
broaden->set_value(kAnyValue);
|
||||
MS_EXCEPTION_IF_NULL(indptr_);
|
||||
MS_EXCEPTION_IF_NULL(indices_);
|
||||
MS_EXCEPTION_IF_NULL(values_);
|
||||
MS_EXCEPTION_IF_NULL(dense_shape_);
|
||||
auto indptr_clone = indptr_->Clone();
|
||||
auto indices_clone = indices_->Clone();
|
||||
auto value_clone = values_->Clone();
|
||||
auto dense_clone = dense_shape_->Clone();
|
||||
MS_EXCEPTION_IF_NULL(indptr_clone);
|
||||
MS_EXCEPTION_IF_NULL(indices_clone);
|
||||
MS_EXCEPTION_IF_NULL(value_clone);
|
||||
MS_EXCEPTION_IF_NULL(dense_clone);
|
||||
broaden->set_indptr(indptr_clone->cast<AbstractTensorPtr>());
|
||||
broaden->set_indices(indices_clone->cast<AbstractTensorPtr>());
|
||||
broaden->set_values(value_clone->cast<AbstractTensorPtr>());
|
||||
broaden->set_dense_shape(dense_clone->cast<AbstractTuplePtr>());
|
||||
return broaden;
|
||||
}
|
||||
|
||||
std::string AbstractCSRTensor::ToString() const {
|
||||
std::ostringstream buffer;
|
||||
BaseShapePtr shape_track = GetShapeTrack();
|
||||
MS_EXCEPTION_IF_NULL(shape_track);
|
||||
MS_EXCEPTION_IF_NULL(element());
|
||||
auto value_track = GetValueTrack();
|
||||
MS_EXCEPTION_IF_NULL(value_track);
|
||||
MS_EXCEPTION_IF_NULL(indptr_);
|
||||
MS_EXCEPTION_IF_NULL(indices_);
|
||||
MS_EXCEPTION_IF_NULL(values_);
|
||||
MS_EXCEPTION_IF_NULL(dense_shape_);
|
||||
buffer << type_name() << "("
|
||||
<< "shape: " << shape_track->ToString() << ", element: " << element()->ToString()
|
||||
<< ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")"
|
||||
<< ", indptr: " << indptr_->ToString() << ", indices: " << indices_->ToString() << ", values"
|
||||
<< values_->ToString() << ", dense_shape: " << dense_shape_->ToString();
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractUMonad::Join(const AbstractBasePtr &other) {
|
||||
MS_EXCEPTION_IF_NULL(other);
|
||||
if (!other->isa<AbstractUMonad>()) {
|
||||
|
|
|
@ -1485,6 +1485,38 @@ class MS_CORE_API AbstractSparseTensor final : public AbstractUndetermined {
|
|||
AbstractTuplePtr dense_shape_;
|
||||
};
|
||||
|
||||
// CSRTensor
|
||||
class MS_CORE_API AbstractCSRTensor : public AbstractUndetermined {
|
||||
public:
|
||||
explicit AbstractCSRTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
|
||||
: AbstractUndetermined(element, shape) {}
|
||||
AbstractCSRTensor(const TypePtr &element_type, const ShapeVector &shape)
|
||||
: AbstractUndetermined(element_type, shape) {}
|
||||
~AbstractCSRTensor() override = default;
|
||||
MS_DECLARE_PARENT(AbstractCSRTensor, AbstractUndetermined)
|
||||
|
||||
const AbstractTensorPtr indptr() const { return indptr_; }
|
||||
void set_indptr(const AbstractTensorPtr &indptr) { indptr_ = indptr; }
|
||||
const AbstractTensorPtr indices() const { return indices_; }
|
||||
void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; }
|
||||
const AbstractTensorPtr values() const { return values_; }
|
||||
void set_values(const AbstractTensorPtr &values) { values_ = values; }
|
||||
const AbstractTuplePtr dense_shape() const { return dense_shape_; }
|
||||
void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
|
||||
TypePtr BuildType() const override;
|
||||
AbstractBasePtr Clone() const override;
|
||||
AbstractBasePtr Broaden() const override;
|
||||
AbstractBasePtr BroadenWithShape() const;
|
||||
|
||||
std::string ToString() const override;
|
||||
|
||||
private:
|
||||
AbstractTensorPtr indptr_;
|
||||
AbstractTensorPtr indices_;
|
||||
AbstractTensorPtr values_;
|
||||
AbstractTuplePtr dense_shape_;
|
||||
};
|
||||
|
||||
class AbstractMonad : public AbstractBase {
|
||||
public:
|
||||
~AbstractMonad() override = default;
|
||||
|
|
|
@ -149,6 +149,8 @@ AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitiveP
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
template <typename T>
|
||||
std::shared_ptr<T> InferSparseAttr(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
@ -169,6 +171,17 @@ AbstractBasePtr InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &, const
|
|||
AbstractBasePtr InferImplRowTensorAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
AbstractBasePtr InferImplMakeCSRTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplCSRTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplCSRTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplCSRTensorGetIndptr(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplCSRTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -83,6 +83,7 @@ ABSTRACT_REPORT_NAME_TRAITS(KeywordArg)
|
|||
ABSTRACT_REPORT_NAME_TRAITS(Class)
|
||||
ABSTRACT_REPORT_NAME_TRAITS(RowTensor)
|
||||
ABSTRACT_REPORT_NAME_TRAITS(SparseTensor)
|
||||
ABSTRACT_REPORT_NAME_TRAITS(CSRTensor)
|
||||
ABSTRACT_REPORT_NAME_TRAITS(Sequeue)
|
||||
|
||||
template <typename T>
|
||||
|
|
|
@ -395,6 +395,105 @@ AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, co
|
|||
return sparse_tensor->dense_shape();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplMakeCSRTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: three tensors and a tuple.
|
||||
constexpr auto kMakeCSRInputNum = 4;
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, kMakeCSRInputNum);
|
||||
auto indptr = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
||||
auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);
|
||||
auto shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 3);
|
||||
|
||||
auto indices_dtype = indices->element()->BuildType();
|
||||
if (!indices_dtype->isa<Int>()) {
|
||||
MS_EXCEPTION(TypeError) << "The dtype of indices must be a Int, but got " << indices_dtype->ToString();
|
||||
}
|
||||
auto indptr_shp = indptr->shape()->shape();
|
||||
if (indptr_shp.size() != 1) {
|
||||
MS_EXCEPTION(ValueError) << "Indptr must be a 1 dimension tensor, but got a " << indptr_shp.size()
|
||||
<< " dimension tensor";
|
||||
}
|
||||
auto indices_shp = indices->shape()->shape();
|
||||
if (indices_shp.size() != 1) {
|
||||
MS_EXCEPTION(ValueError) << "Indices must be a 1 dimension tensor, but got a " << indices_shp.size()
|
||||
<< " dimension tensor";
|
||||
}
|
||||
auto values_shp = values->shape()->shape();
|
||||
if (values_shp.size() != 1) {
|
||||
MS_EXCEPTION(ValueError) << "Values must be a 1 dimension tensor, but got a " << values_shp.size()
|
||||
<< " dimension tensor";
|
||||
}
|
||||
if (indices_shp[0] != values_shp[0]) {
|
||||
MS_EXCEPTION(ValueError) << "indices and values must have same size, but got: values length: " << values_shp[0]
|
||||
<< ", indices length " << indices_shp[0];
|
||||
}
|
||||
for (const auto &elem_type : shape->ElementsType()) {
|
||||
if (!elem_type->isa<Int>()) {
|
||||
MS_EXCEPTION(TypeError) << "The element type of shape must be Int, but got " << elem_type->ToString();
|
||||
}
|
||||
}
|
||||
auto shape_value = shape->BuildValue()->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape_value);
|
||||
auto shp = shape_value->value();
|
||||
ShapeVector shape_vec;
|
||||
(void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(shape_vec), [](const ValuePtr &e) -> int64_t {
|
||||
auto elem = GetValue<int64_t>(e);
|
||||
return elem;
|
||||
});
|
||||
|
||||
for (auto shape_elem : shape_vec) {
|
||||
if (shape_elem < 0) {
|
||||
MS_EXCEPTION(TypeError) << "The element of shape must be positive, but got " << shape_value->ToString();
|
||||
}
|
||||
}
|
||||
if (shape_vec[0] + 1 != indptr_shp[0]) {
|
||||
MS_EXCEPTION(ValueError) << "indptr must have length (1 + shape[0]), but got: " << indptr_shp[0];
|
||||
}
|
||||
auto ret = std::make_shared<AbstractCSRTensor>(values->element()->BuildType(), shape_vec);
|
||||
ret->set_indptr(indptr);
|
||||
ret->set_indices(indices);
|
||||
ret->set_values(values);
|
||||
ret->set_dense_shape(shape);
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::shared_ptr<T> InferSparseAttr(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
return CheckArg<T>(op_name, args_spec_list, 0);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplCSRTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
auto csr_tensor = InferSparseAttr<AbstractCSRTensor>(primitive, args_spec_list);
|
||||
MS_EXCEPTION_IF_NULL(csr_tensor->values());
|
||||
return csr_tensor->values();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplCSRTensorGetIndptr(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
auto csr_tensor = InferSparseAttr<AbstractCSRTensor>(primitive, args_spec_list);
|
||||
MS_EXCEPTION_IF_NULL(csr_tensor->indptr());
|
||||
return csr_tensor->indptr();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplCSRTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
auto csr_tensor = InferSparseAttr<AbstractCSRTensor>(primitive, args_spec_list);
|
||||
MS_EXCEPTION_IF_NULL(csr_tensor->indices());
|
||||
return csr_tensor->indices();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplCSRTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
auto csr_tensor = InferSparseAttr<AbstractCSRTensor>(primitive, args_spec_list);
|
||||
MS_EXCEPTION_IF_NULL(csr_tensor->dense_shape());
|
||||
return csr_tensor->dense_shape();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplAllSwap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string op_name = primitive->name();
|
||||
|
|
|
@ -215,11 +215,16 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimSparseTensorGetDenseShape, {InferImplSparseTensorGetDenseShape, nullptr, true}},
|
||||
// RowTensor
|
||||
{prim::kPrimMakeRowTensor, {InferImplMakeRowTensor, nullptr, true}},
|
||||
|
||||
{prim::kPrimRowTensorGetValues, {InferImplRowTensorGetValues, nullptr, true}},
|
||||
{prim::kPrimRowTensorGetIndices, {InferImplRowTensorGetIndices, nullptr, true}},
|
||||
{prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, nullptr, true}},
|
||||
{prim::kPrimRowTensorAdd, {InferImplRowTensorAdd, nullptr, false}},
|
||||
// CSRTensor
|
||||
{prim::kPrimMakeCSRTensor, {InferImplMakeCSRTensor, nullptr, true}},
|
||||
{prim::kPrimCSRTensorGetValues, {InferImplCSRTensorGetValues, nullptr, true}},
|
||||
{prim::kPrimCSRTensorGetIndptr, {InferImplCSRTensorGetIndptr, nullptr, true}},
|
||||
{prim::kPrimCSRTensorGetIndices, {InferImplCSRTensorGetIndices, nullptr, true}},
|
||||
{prim::kPrimCSRTensorGetDenseShape, {InferImplCSRTensorGetDenseShape, nullptr, true}},
|
||||
// Comm Ops
|
||||
{prim::kPrimAllSwap, {InferImplAllSwap, nullptr, true}},
|
||||
{prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, nullptr, true}},
|
||||
|
|
|
@ -474,6 +474,13 @@ inline const PrimitivePtr kPrimSparseTensorGetValues = std::make_shared<Primitiv
|
|||
inline const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared<Primitive>("SparseTensorGetIndices");
|
||||
inline const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape");
|
||||
|
||||
// CSRTensor
|
||||
inline const PrimitivePtr kPrimMakeCSRTensor = std::make_shared<Primitive>("MakeCSRTensor");
|
||||
inline const PrimitivePtr kPrimCSRTensorGetValues = std::make_shared<Primitive>("CSRTensorGetValues");
|
||||
inline const PrimitivePtr kPrimCSRTensorGetIndptr = std::make_shared<Primitive>("CSRTensorGetIndptr");
|
||||
inline const PrimitivePtr kPrimCSRTensorGetIndices = std::make_shared<Primitive>("CSRTensorGetIndices");
|
||||
inline const PrimitivePtr kPrimCSRTensorGetDenseShape = std::make_shared<Primitive>("CSRTensorGetDenseShape");
|
||||
|
||||
// TensorList
|
||||
inline const PrimitivePtr kPrimTensorListFromTensor = std::make_shared<Primitive>("TensorListFromTensor");
|
||||
inline const PrimitivePtr kPrimTensorListReserve = std::make_shared<Primitive>("TensorListReserve");
|
||||
|
|
|
@ -191,4 +191,46 @@ bool SparseTensorType::operator==(const Type &other) const {
|
|||
}
|
||||
return *element_type_ == *other_elem_type;
|
||||
}
|
||||
|
||||
TypePtr CSRTensorType::DeepCopy() const {
|
||||
MS_EXCEPTION_IF_NULL(element_type_);
|
||||
if (IsGeneric()) {
|
||||
return std::make_shared<CSRTensorType>();
|
||||
}
|
||||
return std::make_shared<CSRTensorType>(element_type_->DeepCopy());
|
||||
}
|
||||
|
||||
std::string CSRTensorType::ToReprString() const {
|
||||
if (element_type_ == nullptr) {
|
||||
return "CSRTensor";
|
||||
}
|
||||
return "CSRTensor[" + element_type_->ToReprString() + "]";
|
||||
}
|
||||
|
||||
std::string CSRTensorType::ToString() const {
|
||||
if (element_type_ == nullptr) {
|
||||
return "CSRTensor";
|
||||
}
|
||||
return "CSRTensor[" + element_type_->ToString() + "]";
|
||||
}
|
||||
|
||||
std::string CSRTensorType::DumpText() const {
|
||||
if (element_type_ == nullptr) {
|
||||
return "CSRTensor";
|
||||
}
|
||||
return "CSRTensor[" + element_type_->DumpText() + "]";
|
||||
}
|
||||
|
||||
bool CSRTensorType::operator==(const Type &other) const {
|
||||
if (!IsSameObjectType(*this, other)) {
|
||||
return false;
|
||||
}
|
||||
auto other_elem_type = static_cast<const CSRTensorType &>(other).element_type_;
|
||||
if (element_type_ == nullptr && other_elem_type == nullptr) {
|
||||
return true;
|
||||
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return *element_type_ == *other_elem_type;
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -190,6 +190,45 @@ class MS_CORE_API SparseTensorType final : public Object {
|
|||
TypePtr element_type_;
|
||||
};
|
||||
using SparseTensorTypePtr = std::shared_ptr<SparseTensorType>;
|
||||
|
||||
/// \brief CSRTensorType defines interface for sparse tensor data type.
|
||||
class MS_CORE_API CSRTensorType : public Object {
|
||||
public:
|
||||
/// \brief Default constructor for CSRTensorType.
|
||||
CSRTensorType() : Object(kObjectTypeCSRTensorType, kObjectTypeUndeterminedType) {}
|
||||
|
||||
/// \brief Constructor for CSRTensorType.
|
||||
///
|
||||
/// \param[in] ele The element of CSRTensorType.
|
||||
explicit CSRTensorType(const TypePtr &ele)
|
||||
: Object(kObjectTypeCSRTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
|
||||
|
||||
/// \brief Destructor of CSRTensorType.
|
||||
~CSRTensorType() override = default;
|
||||
MS_DECLARE_PARENT(CSRTensorType, Object)
|
||||
|
||||
TypeId generic_type_id() const override { return kObjectTypeCSRTensorType; }
|
||||
|
||||
/// \brief Get the element of CSRTensorType object.
|
||||
///
|
||||
/// \return The element of CSRTensorType object.
|
||||
const TypePtr element() const { return element_type_; }
|
||||
|
||||
/// \brief Set the element of CSRTensorType object.
|
||||
///
|
||||
/// \param[in] element_type Define the element type to be set.
|
||||
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
|
||||
|
||||
TypePtr DeepCopy() const override;
|
||||
std::string ToString() const override;
|
||||
std::string ToReprString() const override;
|
||||
std::string DumpText() const override;
|
||||
bool operator==(const Type &other) const override;
|
||||
|
||||
private:
|
||||
TypePtr element_type_;
|
||||
};
|
||||
using CSRTensorTypePtr = std::shared_ptr<CSRTensorType>;
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_
|
||||
|
|
|
@ -50,6 +50,7 @@ static std::unordered_map<TypeId, std::string> g_type_2_lable{
|
|||
{kObjectTypeTensorType, MS_TYPE2LABLE(kObjectTypeTensorType)},
|
||||
{kObjectTypeRowTensorType, MS_TYPE2LABLE(kObjectTypeRowTensorType)},
|
||||
{kObjectTypeSparseTensorType, MS_TYPE2LABLE(kObjectTypeSparseTensorType)},
|
||||
{kObjectTypeCSRTensorType, MS_TYPE2LABLE(kObjectTypeCSRTensorType)},
|
||||
{kObjectTypeUndeterminedType, MS_TYPE2LABLE(kObjectTypeUndeterminedType)},
|
||||
{kObjectTypeClass, MS_TYPE2LABLE(kObjectTypeClass)},
|
||||
{kObjectTypeDictionary, MS_TYPE2LABLE(kObjectTypeDictionary)},
|
||||
|
|
|
@ -86,13 +86,19 @@ enum TypeId : int {
|
|||
//
|
||||
// Monad Types
|
||||
//
|
||||
// Monad types is placed at the end of enum,
|
||||
// in order to keep fit with the type of existing model on the lite side.
|
||||
kMonadTypeBegin = kNumberTypeEnd,
|
||||
kObjectTypeMonad,
|
||||
kObjectTypeUMonad,
|
||||
kObjectTypeIOMonad,
|
||||
kMonadTypeEnd
|
||||
kMonadTypeEnd,
|
||||
//
|
||||
// Sparse Types
|
||||
//
|
||||
// Sparse types is placed at the end of enum,
|
||||
// in order to keep fit with the type of existing model on the lite side.
|
||||
kSparseTypeBegin = kMonadTypeEnd,
|
||||
kObjectTypeCSRTensorType,
|
||||
kSparseTypeEnd
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_IR_DTYPE_TYPE_ID_H_
|
||||
|
|
|
@ -191,6 +191,23 @@ TypePtr SparseTensorStrToType(const std::string &type_name) {
|
|||
return std::make_shared<SparseTensorType>(element_type);
|
||||
}
|
||||
|
||||
TypePtr CSRTensorStrToType(const std::string &type_name) {
|
||||
if (type_name == "CSRTensor") {
|
||||
return std::make_shared<CSRTensorType>();
|
||||
}
|
||||
auto start = type_name.find_first_of('[') + 1;
|
||||
auto end = type_name.find_last_of(']');
|
||||
if (start >= type_name.size()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto element_str = type_name.substr(start, end - start);
|
||||
auto element_type = StringToType(element_str);
|
||||
if (element_type == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return std::make_shared<CSRTensorType>(element_type);
|
||||
}
|
||||
|
||||
TypePtr UndeterminedStrToType(const std::string &type_name) {
|
||||
if (type_name == "Undetermined") {
|
||||
return std::make_shared<UndeterminedType>();
|
||||
|
@ -330,6 +347,7 @@ TypePtr GetTypeByStringStarts(const std::string &type_name) {
|
|||
{"Undetermined", [](const std::string &type_name) -> TypePtr { return UndeterminedStrToType(type_name); }},
|
||||
{"RowTensor", [](const std::string &type_name) -> TypePtr { return RowTensorStrToType(type_name); }},
|
||||
{"SparseTensor", [](const std::string &type_name) -> TypePtr { return SparseTensorStrToType(type_name); }},
|
||||
{"CSRTensor", [](const std::string &type_name) -> TypePtr { return CSRTensorStrToType(type_name); }},
|
||||
{"List", [](const std::string &type_name) -> TypePtr { return ListStrToType(type_name); }},
|
||||
{"Tuple", [](const std::string &type_name) -> TypePtr { return TupleStrToType(type_name); }},
|
||||
{"Function", [](const std::string &type_name) -> TypePtr { return FunctionStrToType(type_name); }}};
|
||||
|
|
|
@ -673,5 +673,31 @@ TypeId Tensor::set_data_type(const TypeId data_type) {
|
|||
}
|
||||
return data_type;
|
||||
}
|
||||
|
||||
CSRTensor::CSRTensor(const TensorPtr indptr, const TensorPtr indices, const TensorPtr values, const ShapeVector &shape)
|
||||
: indptr_(indptr), indices_(indices), values_(values), shape_(shape) {}
|
||||
|
||||
bool CSRTensor::operator==(const CSRTensor &csr_tensor) const { return (&csr_tensor == this); }
|
||||
|
||||
std::string CSRTensor::ToString() const {
|
||||
std::ostringstream buf;
|
||||
MS_EXCEPTION_IF_NULL(values_);
|
||||
MS_EXCEPTION_IF_NULL(indices_);
|
||||
MS_EXCEPTION_IF_NULL(indptr_);
|
||||
auto dtype = values_->Dtype();
|
||||
buf << "CSRTensor(shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString() << ", indptr=";
|
||||
buf << indptr_->ToString() << ", indices=" << indices_->ToString() << ", values=";
|
||||
buf << values_->ToString() << ")";
|
||||
return buf.str();
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtr CSRTensor::ToAbstract() {
|
||||
auto dtype = values_->Dtype();
|
||||
if (!IsSubType(dtype, kNumber) && !IsSubType(dtype, kString) && !IsSubType(dtype, kTensorType)) {
|
||||
MS_LOG(EXCEPTION) << "Expect tensor type kNumber or kString or kTensor but got: " << dtype->ToString() << ".";
|
||||
}
|
||||
auto abs_csr_tensor = std::make_shared<abstract::AbstractCSRTensor>(dtype, shape_);
|
||||
return abs_csr_tensor;
|
||||
}
|
||||
} // namespace tensor
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -547,6 +547,63 @@ class MS_CORE_API Tensor final : public MetaTensor {
|
|||
};
|
||||
using TensorPtr = std::shared_ptr<Tensor>;
|
||||
using TensorPtrList = std::vector<std::shared_ptr<Tensor>>;
|
||||
|
||||
// CSRTensor entity class
|
||||
class MS_CORE_API CSRTensor : public MetaTensor {
|
||||
public:
|
||||
abstract::AbstractBasePtr ToAbstract() override;
|
||||
|
||||
/// \brief Create CSRTensor with given data type from another tensor.
|
||||
///
|
||||
/// \param[in] indptr [Tensor] The indices pointer.
|
||||
/// \param[in] indices [Tensor] The indices.
|
||||
/// \param[in] values [Tensor] The values.
|
||||
/// \param[in] shape The shape represented by ShapeVector of the CSRensor.
|
||||
CSRTensor(const TensorPtr indptr, const TensorPtr indices, const TensorPtr values, const ShapeVector &shape);
|
||||
|
||||
/// Destructor of CSRTensor.
|
||||
~CSRTensor() override = default;
|
||||
|
||||
/// \brief Gets CSRTensor's indptr.
|
||||
///
|
||||
/// \return [TensorPtr] The indices pointer.
|
||||
TensorPtr GetIndptr() { return indptr_; }
|
||||
|
||||
/// \brief Gets CSRTensor's indices.
|
||||
///
|
||||
/// \return [TensorPtr] The indices.
|
||||
TensorPtr GetIndices() { return indices_; }
|
||||
|
||||
/// \brief Gets CSRTensor's values.
|
||||
///
|
||||
/// \return [TensorPtr] The values.
|
||||
TensorPtr GetValues() { return values_; }
|
||||
|
||||
/// \brief Gets CSRTensor's shape.
|
||||
///
|
||||
/// \return [ShapeVector] The shape of the tensor.
|
||||
const ShapeVector &shape() const { return shape_; }
|
||||
|
||||
/// \brief Compare two tensor objects to see if they have same data type, shape and data address.
|
||||
///
|
||||
/// \param[in] tensor The Tensor object to be compared.
|
||||
/// \return True if having same type, shape and data address, otherwise false.
|
||||
bool operator==(const CSRTensor &csr_tensor) const;
|
||||
|
||||
/// \brief Get display information of this Tensor.
|
||||
///
|
||||
/// \return The display information of this Tensor.
|
||||
std::string ToString() const override;
|
||||
|
||||
TypePtr Dtype() const { return values_->Dtype(); }
|
||||
|
||||
private:
|
||||
TensorPtr indptr_;
|
||||
TensorPtr indices_;
|
||||
TensorPtr values_;
|
||||
ShapeVector shape_{};
|
||||
};
|
||||
using CSRTensorPtr = std::shared_ptr<CSRTensor>;
|
||||
} // namespace tensor
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -394,6 +394,12 @@ sparse_tensor_get_values = Primitive('SparseTensorGetValues')
|
|||
sparse_tensor_get_indices = Primitive('SparseTensorGetIndices')
|
||||
sparse_tensor_get_dense_shape = Primitive('SparseTensorGetDenseShape')
|
||||
|
||||
make_csr_tensor = Primitive('MakeCSRTensor')
|
||||
csr_tensor_get_values = Primitive('CSRTensorGetValues')
|
||||
csr_tensor_get_indices = Primitive('CSRTensorGetIndices')
|
||||
csr_tensor_get_indptr = Primitive('CSRTensorGetIndptr')
|
||||
csr_tensor_get_shape = Primitive('CSRTensorGetDenseShape')
|
||||
|
||||
tensor_operator_registry.register('all', P.ReduceAll)
|
||||
tensor_operator_registry.register('any', P.ReduceAny)
|
||||
tensor_operator_registry.register('abs', P.Abs)
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""smoke tests for CSR operations"""
|
||||
|
||||
import pytest
|
||||
from mindspore import Tensor, CSRTensor, ms_function
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
|
||||
def compare_csr(csr1, csr2):
|
||||
assert isinstance(csr1, CSRTensor)
|
||||
assert isinstance(csr2, CSRTensor)
|
||||
assert (csr1.indptr.asnumpy() == csr2.indptr.asnumpy()).all()
|
||||
assert (csr1.indices.asnumpy() == csr2.indices.asnumpy()).all()
|
||||
assert (csr1.values.asnumpy() == csr2.values.asnumpy()).all()
|
||||
assert csr1.shape == csr2.shape
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_make_csr():
|
||||
"""
|
||||
Feature: Test CSRTensor Constructor in Graph and PyNative.
|
||||
Description: Test CSRTensor(indptr, indices, values, shape) and CSRTensor(CSRTensor)
|
||||
Expectation: Success.
|
||||
"""
|
||||
indptr = Tensor([0, 1, 2])
|
||||
indices = Tensor([0, 1])
|
||||
values = Tensor([1, 2], dtype=mstype.float32)
|
||||
shape = (2, 6)
|
||||
def test_pynative():
|
||||
return CSRTensor(indptr, indices, values, shape)
|
||||
test_graph = ms_function(test_pynative)
|
||||
|
||||
csr1 = test_pynative()
|
||||
csr2 = test_graph()
|
||||
compare_csr(csr1, csr2)
|
||||
csr3 = CSRTensor(csr_tensor=csr2)
|
||||
compare_csr(csr3, csr2)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_csr_attr():
|
||||
"""
|
||||
Feature: Test CSRTensor GetAttr in Graph and PyNative.
|
||||
Description: Test CSRTensor.indptr, CSRTensor.indices, CSRTensor.values, CSRTensor.shape.
|
||||
Expectation: Success.
|
||||
"""
|
||||
indptr = Tensor([0, 1, 2])
|
||||
indices = Tensor([0, 1])
|
||||
values = Tensor([1, 2], dtype=mstype.float32)
|
||||
shape = (2, 6)
|
||||
def test_pynative():
|
||||
csr = CSRTensor(indptr, indices, values, shape)
|
||||
return csr.indptr, csr.indices, csr.values, csr.shape
|
||||
test_graph = ms_function(test_pynative)
|
||||
|
||||
csr1_tuple = test_pynative()
|
||||
csr2_tuple = test_graph()
|
||||
|
||||
csr1 = CSRTensor(*csr1_tuple)
|
||||
csr2 = CSRTensor(*csr2_tuple)
|
||||
compare_csr(csr1, csr2)
|
Loading…
Reference in New Issue