!31737 adapt csrtensor runtime

Merge pull request !31737 from wangrao124/pr_csr_param
This commit is contained in:
i-robot 2022-03-31 03:30:20 +00:00 committed by Gitee
commit fcb0d72e38
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
15 changed files with 254 additions and 56 deletions

View File

@ -28,6 +28,9 @@ namespace opt {
using CSRTensor = mindspore::tensor::CSRTensor;
using CSRTensorPtr = mindspore::tensor::CSRTensorPtr;
constexpr auto kCSRValueNodeNum = 2;
constexpr auto kSparseAttrIndex = 1;
// Convert CSRTensor Parameter or ValueNode to Tuple by setting its abstract.
void AbstractCSRToAbstractTuple(const AnfNodePtr &sparse) {
MS_EXCEPTION_IF_NULL(sparse);
@ -77,6 +80,9 @@ bool SplitParameter(const AnfNodePtr &node, std::vector<AnfNodePtr> *new_inputs,
MS_EXCEPTION_IF_NULL(node);
auto node_abs = node->abstract();
MS_EXCEPTION_IF_NULL(node_abs);
static HashMap<AnfNodePtr, std::vector<AnfNodePtr>> csr_params_map;
auto param = node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
if (node_abs->isa<abstract::AbstractCSRTensor>()) {
auto param_abs = node_abs->cast<abstract::AbstractCSRTensorPtr>();
MS_EXCEPTION_IF_NULL(param_abs);
@ -94,6 +100,20 @@ bool SplitParameter(const AnfNodePtr &node, std::vector<AnfNodePtr> *new_inputs,
// Set CSRTensor Parameter abstract to Tensor by its values.
node->set_abstract(param_abs->values()->Broaden());
new_inputs->push_back(node);
// Set csr_params_map
if (csr_params_map.find(node) == csr_params_map.end()) {
csr_params_map[node].emplace_back(new_indptr);
csr_params_map[node].emplace_back(new_indices);
}
return true;
// If the cnode has a csr_tensor_param which has been split, use the map to find its indptr and indices.
} else if (node_abs->isa<abstract::AbstractTensor>() && csr_params_map.find(node) != csr_params_map.end()) {
if (csr_params_map[node].size() != kCSRValueNodeNum) {
MS_LOG(ERROR) << "csr_params_map[" << node->DebugString() << "] has " << csr_params_map[node].size()
<< " inputs, but expect two inputs! They are all added in new_inputs.";
}
new_inputs->insert(new_inputs->end(), csr_params_map[node].begin(), csr_params_map[node].end());
new_inputs->push_back(node);
return true;
}
return false;
@ -103,8 +123,8 @@ bool SplitCNode(const AnfNodePtr &node, std::vector<AnfNodePtr> *new_inputs) {
auto cnode = node->cast<CNodePtr>();
auto sparse_prim = common::AnfAlgo::GetCNodePrimitive(cnode);
MS_EXCEPTION_IF_NULL(sparse_prim);
// Currently, only MakeCSR and MakeTuple nodes can be split.
if (make_sparse_set.count(sparse_prim->name()) == 0 && sparse_prim->name().compare(prim::kPrimMakeTuple->name()) != 0)
// Currently, only MakeCSR/MakeCOO and MakeTuple nodes can be split.
if (make_sparse_set.count(sparse_prim->name()) <= 0 && sparse_prim->name().compare(prim::kPrimMakeTuple->name()) != 0)
return false;
auto sparse_inputs = cnode->inputs();
@ -116,8 +136,10 @@ bool SplitCNode(const AnfNodePtr &node, std::vector<AnfNodePtr> *new_inputs) {
return true;
}
std::vector<AbstractBasePtr> GetAbstractList(const AnfNodePtr &node, const std::string &prim_name) {
std::vector<AbstractBasePtr> GetAbstractList(const AnfNodePtr &node, const PrimitivePtr &prim) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(prim);
std::string prim_name = prim->name();
if (prim_name == prim::kPrimMakeCSRTensor->name()) {
auto abs_sparse = dyn_cast<abstract::AbstractCSRTensor>(node->abstract());
MS_EXCEPTION_IF_NULL(abs_sparse);
@ -130,6 +152,85 @@ std::vector<AbstractBasePtr> GetAbstractList(const AnfNodePtr &node, const std::
return {};
}
// Replace a MakeCSRTensor/MakeCOOTensor node with an equivalent MakeTuple node
// whose abstract is the tuple of the sparse tensor's component abstracts.
// `kernel_graph` may be null (the caller obtains it via a cast that can fail
// when the func graph is not a KernelGraph); the front-backend map is only
// updated when a kernel graph is available.
CNodePtr ConvertMakeSparseToMakeTuple(const AnfNodePtr &node, const KernelGraphPtr &kernel_graph) {
  MS_EXCEPTION_IF_NULL(node);
  CNodePtr cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  std::vector<AnfNodePtr> inputs;
  (void)inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  // Reuse the MakeSparse operands as-is; index 0 is the primitive and is skipped.
  (void)inputs.insert(inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
  auto new_node = NewCNode(inputs, cnode->func_graph());
  MS_EXCEPTION_IF_NULL(new_node);
  std::vector<AbstractBasePtr> abstract_list = GetAbstractList(node, common::AnfAlgo::GetCNodePrimitive(cnode));
  auto abs_res = std::make_shared<abstract::AbstractTuple>(abstract_list);
  new_node->set_abstract(abs_res);
  new_node->set_scope(cnode->scope());
  // NOTE(review): the previous version both asserted kernel_graph non-null and
  // guarded on it; keep only the guard so non-kernel-graph callers still work.
  if (kernel_graph != nullptr) {
    kernel_graph->FrontBackendlMapUpdate(cnode, new_node);
  }
  return new_node;
}
// Replace a CSRTensorGetIndptr/GetIndices/GetValues/... node with
// TupleGetItem(sparse, index), where `index` is the component's position
// inside the converted tuple.
// `kernel_graph` may be null; the front-backend map is updated only when present.
CNodePtr ConvertSparseGetAttrToTupleGetItem(int64_t index, const AnfNodePtr &node, const KernelGraphPtr &kernel_graph) {
  MS_EXCEPTION_IF_NULL(node);
  CNodePtr cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  const auto &inputs = cnode->inputs();
  // Inputs must be [sparse_getattr_prim, sparse].
  if (inputs.size() <= kSparseAttrIndex) {
    MS_LOG(EXCEPTION) << "For SparseGetAttr, CNode must have 2 inputs (Prim, Sparse)";
  }
  // Ensure the sparse operand's abstract is a tuple so TupleGetItem type-checks.
  AbstractCSRToAbstractTuple(inputs[kSparseAttrIndex]);
  auto index_node = NewValueNode(index);
  AbstractBasePtr index_abs = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(index));
  index_node->set_abstract(index_abs);
  auto new_node =
    NewCNode({NewValueNode(prim::kPrimTupleGetItem), inputs[kSparseAttrIndex], index_node}, cnode->func_graph());
  MS_EXCEPTION_IF_NULL(new_node);
  new_node->set_abstract(node->abstract());
  // NOTE(review): dropped the MS_EXCEPTION_IF_NULL(kernel_graph) that
  // contradicted this guard; null kernel_graph is a legitimate caller state.
  if (kernel_graph != nullptr) {
    kernel_graph->FrontBackendlMapUpdate(cnode, new_node);
  }
  return new_node;
}
// Rebuild a sparse op's input list so that every CSRTensor operand (CNode,
// ValueNode, or Parameter) is split into its component tensors.
// Returns the rebuilt CNode, or nullptr when the node was already processed.
// `kernel_graph` may be null; the front-backend map is updated only when present.
CNodePtr FetchInputsForSparseOP(const AnfNodePtr &node, const KernelGraphPtr &kernel_graph) {
  MS_EXCEPTION_IF_NULL(node);
  CNodePtr cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // Keep the pass idempotent: a node split once must not be split again.
  if (cnode->GetAttr("has_been_split") != nullptr) {
    MS_LOG(INFO) << "Do not process CNode " << cnode << " (" << cnode->DebugString() << "), because it has been split.";
    return nullptr;
  }
  const auto &inputs = cnode->inputs();
  std::vector<AnfNodePtr> new_inputs;
  new_inputs.push_back(inputs[0]);
  for (size_t i = 1; i < inputs.size(); ++i) {
    if (inputs[i]->isa<CNode>()) {
      if (SplitCNode(inputs[i], &new_inputs)) continue;
    } else if (inputs[i]->isa<ValueNode>()) {
      if (SplitValueNode(inputs[i], &new_inputs, kernel_graph)) continue;
    } else if (inputs[i]->isa<Parameter>()) {
      // 1. Split CSRTensor param to multiple tensors.
      // 2. Set CSRTensor abstract to AbstractTensor that is related its values.
      if (SplitParameter(inputs[i], &new_inputs, kernel_graph)) continue;
    }
    // Non-sparse input (or split declined): forward unchanged.
    new_inputs.push_back(inputs[i]);
  }
  auto new_node = NewCNode(new_inputs, cnode->func_graph());
  MS_EXCEPTION_IF_NULL(new_node);
  new_node->set_abstract(node->abstract());
  // Set attr "has_been_split" to prevent the node is split more than once.
  new_node->AddAttr("has_been_split", MakeValue(true));
  // NOTE(review): removed MS_EXCEPTION_IF_NULL(kernel_graph) — it contradicted
  // this guard and would break callers whose func graph is not a KernelGraph.
  if (kernel_graph != nullptr) {
    kernel_graph->FrontBackendlMapUpdate(cnode, new_node);
  }
  return new_node;
}
const AnfNodePtr SparseProcess::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
@ -143,56 +244,13 @@ const AnfNodePtr SparseProcess::Process(const FuncGraphPtr &func_graph, const An
MS_EXCEPTION_IF_NULL(prim);
std::string prim_name = prim->name();
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
// cnode is a MakeSparse node
if (make_sparse_set.find(prim_name) != make_sparse_set.end()) {
std::vector<AnfNodePtr> inputs;
(void)inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
(void)inputs.insert(inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
auto new_node = cnode->func_graph()->NewCNode(inputs);
std::vector<AbstractBasePtr> abstract_list = GetAbstractList(node, prim_name);
auto abs_res = std::make_shared<abstract::AbstractTuple>(abstract_list);
new_node->set_abstract(abs_res);
new_node->set_scope(cnode->scope());
if (kernel_graph != nullptr) {
kernel_graph->FrontBackendlMapUpdate(cnode, new_node);
}
return new_node;
// cnode is a SparseGetAttr node
return ConvertMakeSparseToMakeTuple(node, kernel_graph);
} else if (sparse_attr_map.find(prim_name) != sparse_attr_map.end()) {
const auto &inputs = cnode->inputs();
// Inputs should be [sparse_getattr, sparse]
if (inputs.size() <= 1) {
MS_LOG_EXCEPTION << "For SparseGetAttr, CNode must have 2 inputs (Prim, Sparse)";
}
constexpr size_t sparse_index = 1;
AbstractCSRToAbstractTuple(inputs[sparse_index]);
int64_t index = sparse_attr_map.at(prim_name);
auto cons_node = NewValueNode(index);
AbstractBasePtr aptr = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(index));
cons_node->set_abstract(aptr);
auto new_node = NewCNode({NewValueNode(prim::kPrimTupleGetItem), inputs[sparse_index], cons_node}, func_graph);
new_node->set_abstract(node->abstract());
return new_node;
// ComputeSparse node: SparseTensorDenseMatmul, CSRMul, CSRReduceSum
return ConvertSparseGetAttrToTupleGetItem(sparse_attr_map.at(prim_name), node, kernel_graph);
} else if (sparse_op_set.find(prim_name) != sparse_op_set.end()) {
const auto &inputs = cnode->inputs();
std::vector<AnfNodePtr> new_inputs;
new_inputs.push_back(inputs[0]);
for (size_t i = 1; i < inputs.size(); ++i) {
if (inputs[i]->isa<CNode>()) {
if (SplitCNode(inputs[i], &new_inputs)) continue;
} else if (inputs[i]->isa<ValueNode>()) {
if (SplitValueNode(inputs[i], &new_inputs, kernel_graph)) continue;
} else if (inputs[i]->isa<Parameter>()) {
if (SplitParameter(inputs[i], &new_inputs, kernel_graph)) continue;
}
new_inputs.push_back(inputs[i]);
}
auto new_node = cnode->func_graph()->NewCNode(new_inputs);
new_node->set_abstract(node->abstract());
return new_node;
return FetchInputsForSparseOP(node, kernel_graph);
}
return nullptr;
}
} // namespace opt

View File

@ -23,6 +23,10 @@
namespace mindspore {
namespace opt {
// Process SparseOPs:
// 1. Convert "MakeCSRTensor/MakeCOOTensor/..." to MakeTuple
// 2. Convert "CSRTensorGetIndptr/..." to TupleGetItem
// 3. Process inputs for SparseOPs, e.g., split CSRTensor input to multiple tensor inputs.
class SparseProcess : public PatternProcessPass {
public:
explicit SparseProcess(bool multigraph = true) : PatternProcessPass("sparse_process", multigraph) {}

View File

@ -225,6 +225,11 @@ class COMMON_EXPORT AnfAlgo {
// Get the real output node and indexes of get item, make tuple, depend, load.
static AnfNodePtr GetTupleIndexes(const AnfNodePtr &node, std::vector<size_t> *index_stack);
static bool IsNopNode(const AnfNodePtr &node);
// Return true iff `node` carries an AbstractCSRTensor abstract.
// Throws (via MS_EXCEPTION_IF_NULL) when the node or its abstract is null.
static bool CheckAbsCSRTensor(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  const auto &abs = node->abstract();
  MS_EXCEPTION_IF_NULL(abs);
  return abs->isa<abstract::AbstractCSRTensor>();
}
};
} // namespace common
} // namespace mindspore

View File

@ -111,6 +111,17 @@ void ValueTupleToValue(const ValuePtr &value, std::vector<ValuePtr> *const value
(void)values->emplace_back(element);
}
}
} else if (value->isa<tensor::CSRTensor>()) {
auto csr_tensor = value->cast<tensor::CSRTensorPtr>();
MS_EXCEPTION_IF_NULL(csr_tensor);
MS_EXCEPTION_IF_NULL(csr_tensor->GetIndptr());
MS_EXCEPTION_IF_NULL(csr_tensor->GetIndices());
MS_EXCEPTION_IF_NULL(csr_tensor->GetValues());
values->emplace_back(csr_tensor->GetIndptr());
values->emplace_back(csr_tensor->GetIndices());
values->emplace_back(csr_tensor->GetValues());
std::transform(csr_tensor->shape().begin(), csr_tensor->shape().end(), std::back_inserter(*values),
[](int64_t n) { return std::make_shared<Int64Imm>(n); });
} else {
(void)values->emplace_back(value);
}

View File

@ -322,7 +322,12 @@ void CreateDeviceTensorForFrontNode(const KernelWithIndex &front_node_with_index
}
// Set type.
TypeId type_id = common::AnfAlgo::GetOutputInferDataType(node, front_node_with_index.second);
TypeId type_id = kTypeUnknown;
if (common::AnfAlgo::CheckAbsCSRTensor(node)) {
type_id = node->abstract()->cast<abstract::AbstractCSRTensorPtr>()->GetTypeIdAt(front_node_with_index.second);
} else {
type_id = common::AnfAlgo::GetOutputInferDataType(node, front_node_with_index.second);
}
if (builder->GetAllOutputDeviceTypes().size() > front_node_with_index.second) {
builder->SetOutputDeviceType(type_id, front_node_with_index.second);
} else {
@ -334,7 +339,12 @@ void CreateDeviceTensorForFrontNode(const KernelWithIndex &front_node_with_index
}
// Create device tensor.
size_t size = AnfAlgo::GetOutputTensorMemSize(node, front_node_with_index.second);
size_t size = 0;
if (common::AnfAlgo::CheckAbsCSRTensor(node)) {
size = node->cast<ValueNodePtr>()->value()->cast<tensor::CSRTensorPtr>()->GetSizeAt(front_node_with_index.second);
} else {
size = AnfAlgo::GetOutputTensorMemSize(node, front_node_with_index.second);
}
device::DeviceAddressPtr address =
device_context->CreateDeviceAddress(nullptr, size, kOpFormat_DEFAULT, type_id, ShapeVector());
MS_EXCEPTION_IF_NULL(address);
@ -570,6 +580,10 @@ KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const
if (front_node != nullptr) {
MS_LOG(DEBUG) << "Front node:" << front_node->DebugString() << " index:0"
<< " for backend node:" << backend_node->DebugString();
// Adapt CSRTensor to new runtime
if (common::AnfAlgo::CheckAbsCSRTensor(front_node)) {
return {front_node, kCsrTensorValuesIndex};
}
return {front_node, 0};
}
const auto &front_node_with_index = graph->GetFrontNodeByInternalParameter(backend_node);

View File

@ -59,6 +59,7 @@ constexpr size_t kCsrTensorIndPtrIndex = 0;
constexpr size_t kCsrTensorIndicesIndex = 1;
constexpr size_t kCsrTensorValuesIndex = 2;
constexpr size_t kCsrTensorDenseShapeIndex = 3;
constexpr size_t kCsrParamOutputSize = 3;
constexpr size_t kCooTensorIndicesIndex = 0;
constexpr size_t kCooTensorValuesIndex = 1;
constexpr size_t kCooTensorDenseShapeIndex = 2;

View File

@ -1493,6 +1493,13 @@ void ControlNodeScheduler::LinkDataArrowForKernelActor(const GraphCompilerInfo &
}
}
void SetCSRTensorIndex(AnfNodePtr front_node, KernelWithIndex *front_node_with_index) {
if (front_node != nullptr && common::AnfAlgo::CheckAbsCSRTensor(front_node)) {
MS_EXCEPTION_IF_NULL(front_node_with_index);
front_node_with_index->second = kCsrTensorValuesIndex;
}
}
void ControlNodeScheduler::LinkDataArrowByKernelGraph(const KernelGraphPtr &graph, ControlActor *const entrance_actor,
const ControlNodeParserPtr &parser) {
MS_EXCEPTION_IF_NULL(graph);
@ -1533,6 +1540,9 @@ void ControlNodeScheduler::LinkDataArrowByKernelGraph(const KernelGraphPtr &grap
from_node_with_index = tuple_node_with_index;
}
// Adapt CSRTensor to new runtime
SetCSRTensorIndex(front_node, &from_node_with_index);
if (common::AnfAlgo::CheckPrimitiveType(from_node_with_index.first, prim::kPrimTupleGetItem)) {
MS_LOG(WARNING) << "Input node:" << from_node_with_index.first->DebugString()
<< " for graph:" << graph->ToString() << " is a tuple get item";

View File

@ -58,6 +58,33 @@ bool NodeDeviceAddressExist(const DeviceContext *device_context, const AnfNodePt
return false;
}
// Create and attach a device address for each component output (indptr,
// indices, values) of a CSRTensor parameter, sized from its abstract's
// per-component shape and dtype.
void SetCSRParamAddr(const AnfNodePtr &node, size_t output_size, const DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(node->abstract());
  MS_EXCEPTION_IF_NULL(device_context);
  auto abs_csr_tensor = node->abstract()->cast<abstract::AbstractCSRTensorPtr>();
  MS_EXCEPTION_IF_NULL(abs_csr_tensor);
  for (size_t i = 0; i < output_size; ++i) {
    auto abs_tensor = abs_csr_tensor->GetAbsTensorAt(i);
    MS_EXCEPTION_IF_NULL(abs_tensor);
    TypeId type_id = abs_tensor->BuildType()->type_id();
    ShapeVector shape_vec = abs_tensor->shape()->shape();
    // BUG FIX: the original called std::transform(res.begin(), res.end(),
    // back_inserter(shape_vec), ...) — source and destination were swapped, so
    // `res` stayed empty and tensor_size always degenerated to type_size.
    // Transform the shape into `res` instead.
    std::vector<size_t> res;
    (void)std::transform(shape_vec.begin(), shape_vec.end(), std::back_inserter(res),
                         [](int64_t num) { return static_cast<size_t>(num); });
    size_t type_size = GetTypeByte(TypeIdToType(type_id));
    // Byte size = element size * product of all dims (type_size is the seed).
    size_t tensor_size = std::accumulate(res.begin(), res.end(), type_size, std::multiplies<size_t>());
    auto device_address = device_context->CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, type_id,
                                                              trans::GetRuntimePaddingShape(node, i));
    MS_EXCEPTION_IF_NULL(device_address);
    // Parameters live in persistent memory; other node kinds do not.
    device_address->set_from_persistent_mem(node->isa<Parameter>());
    MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(node) << " addr:" << device_address;
    AnfAlgo::SetOutputAddr(device_address, i, node.get());
  }
}
void CreateParameterDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(graph);
@ -93,6 +120,10 @@ void CreateParameterDeviceAddress(const DeviceContext *device_context, const Ker
// Create device address for anf node in nodes_list
for (const auto &item : nodes_list) {
auto output_size = common::AnfAlgo::GetOutputTensorNum(item);
if (common::AnfAlgo::CheckAbsCSRTensor(item)) {
SetCSRParamAddr(item, kCsrParamOutputSize, device_context);
continue;
}
for (size_t index = 0; index < output_size; index++) {
TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
if (output_type_id == kTypeUnknown) {

View File

@ -156,10 +156,9 @@ std::vector<KernelWithIndex> GetAllOutputWithIndexInner(const AnfNodePtr &node)
}
// If the node is a call, the outputs num should get from the abstract.
if (AnfAlgo::IsCallNode(node) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
auto abstract = node->abstract();
MS_EXCEPTION_IF_NULL(abstract);
outputs_num = AnfAlgo::GetOutputNumByAbstract(abstract);
if (AnfAlgo::IsCallNode(node) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem) ||
AnfAlgo::CheckAbsCSRTensor(node)) {
outputs_num = AnfAlgo::GetOutputNumByAbstract(node->abstract());
}
// The output may be the tuple of node, so need visit all the outputs of node.

View File

@ -1848,6 +1848,36 @@ std::string AbstractCSRTensor::ToString() const {
return buffer.str();
}
// Return the component abstract tensor at `index`:
// indptr (0), indices (1), or values (2). Any other index raises an exception.
const AbstractTensorPtr AbstractCSRTensor::GetAbsTensorAt(size_t index) const {
  switch (index) {
    case kIndptrIdx:
      return indptr_;
    case kIndicesIdx:
      return indices_;
    case kValuesIdx:
      return values_;
    default:
      break;
  }
  MS_LOG(EXCEPTION) << "Invalid index: " << index << " for abstract: " << ToString();
  return nullptr;  // Unreachable: MS_LOG(EXCEPTION) throws.
}
// Return the element type id of the component at `index`:
// indptr (0), indices (1), values (2), or a dense-shape element (3..3+rank-1).
// Any other index raises an exception.
const TypeId AbstractCSRTensor::GetTypeIdAt(size_t index) const {
  if (index == kIndptrIdx) {
    MS_EXCEPTION_IF_NULL(indptr_);
    return indptr_->element()->BuildType()->type_id();
  } else if (index == kIndicesIdx) {
    MS_EXCEPTION_IF_NULL(indices_);
    return indices_->element()->BuildType()->type_id();
  } else if (index == kValuesIdx) {
    MS_EXCEPTION_IF_NULL(values_);
    return values_->element()->BuildType()->type_id();
  } else if (index >= kShapeIdx) {
    // BUG FIX: the original dereferenced dense_shape_ in the range condition
    // (dense_shape_->elements().size()) BEFORE the null check; validate first.
    MS_EXCEPTION_IF_NULL(dense_shape_);
    if (index < kShapeIdx + dense_shape_->elements().size()) {
      return dense_shape_->elements()[index - kShapeIdx]->BuildType()->type_id();
    }
  }
  MS_LOG(EXCEPTION) << "Invalid index: " << index << " for abstract: " << ToString();
  return kTypeUnknown;  // Unreachable: MS_LOG(EXCEPTION) throws.
}
AbstractBasePtr AbstractUMonad::Join(const AbstractBasePtr &other) {
MS_EXCEPTION_IF_NULL(other);
if (!other->isa<AbstractUMonad>()) {

View File

@ -1596,7 +1596,8 @@ class MS_CORE_API AbstractCSRTensor : public AbstractUndetermined {
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden() const override;
AbstractBasePtr BroadenWithShape() const;
const AbstractTensorPtr GetAbsTensorAt(size_t index) const;
const TypeId GetTypeIdAt(size_t index) const;
std::string ToString() const override;
private:
@ -1604,6 +1605,10 @@ class MS_CORE_API AbstractCSRTensor : public AbstractUndetermined {
AbstractTensorPtr indices_;
AbstractTensorPtr values_;
AbstractTuplePtr dense_shape_;
static constexpr size_t kIndptrIdx = 0;
static constexpr size_t kIndicesIdx = 1;
static constexpr size_t kValuesIdx = 2;
static constexpr size_t kShapeIdx = 3;
};
using AbstractCSRTensorPtr = std::shared_ptr<AbstractCSRTensor>;

View File

@ -729,6 +729,23 @@ abstract::AbstractBasePtr CSRTensor::ToAbstract() {
return abs_csr_tensor;
}
// Return the byte size of the component at `index`:
// indptr (0), indices (1), values (2) report their tensor data size; a
// dense-shape element (3..3+rank-1) is stored as an int64 scalar.
// Any other index raises an exception.
const size_t CSRTensor::GetSizeAt(size_t index) const {
  if (index == kIndptrIdx) {
    MS_EXCEPTION_IF_NULL(indptr_);
    return indptr_->data().nbytes();
  } else if (index == kIndicesIdx) {
    MS_EXCEPTION_IF_NULL(indices_);
    return indices_->data().nbytes();
  } else if (index == kValuesIdx) {
    MS_EXCEPTION_IF_NULL(values_);
    return values_->data().nbytes();
  } else if (index >= kShapeIdx && index < kShapeIdx + shape().size()) {
    // BUG FIX: the range test used kIndicesIdx (1) as the lower bound; the
    // shape components start at kShapeIdx (3), matching AbstractCSRTensor.
    return sizeof(int64_t);
  }
  MS_LOG(EXCEPTION) << "Invalid index: " << index << " for CSRTensor: " << ToString();
  // Unreachable: MS_LOG(EXCEPTION) throws. The original returned kTypeUnknown
  // (a TypeId enum) as a size_t; return 0 instead.
  return 0;
}
std::string COOTensor::ToString() const {
std::ostringstream buf;
MS_EXCEPTION_IF_NULL(indices_);

View File

@ -614,6 +614,8 @@ class MS_CORE_API CSRTensor : public MetaSparseTensor {
return false;
}
const size_t GetSizeAt(size_t index) const;
/// \brief Get display information of this Tensor.
///
/// \return The display information of this Tensor.
@ -623,6 +625,10 @@ class MS_CORE_API CSRTensor : public MetaSparseTensor {
TensorPtr indptr_;
TensorPtr indices_;
TensorPtr values_;
static constexpr size_t kIndptrIdx = 0;
static constexpr size_t kIndicesIdx = 1;
static constexpr size_t kValuesIdx = 2;
static constexpr size_t kShapeIdx = 3;
};
using CSRTensorPtr = std::shared_ptr<CSRTensor>;

View File

@ -329,6 +329,10 @@ size_t AnfUtils::GetOutputTensorNum(const AnfNodePtr &node) {
res = tuple_type->size();
} else if (type->isa<TypeNone>()) {
res = 0;
} else if (type->isa<CSRTensorType>()) {
// Currently, CSRTensor only supports 2-D matrix (shape has 2 values). 5 outputs = 3 Tensors + 2 shape values.
constexpr size_t kCSRTensorOutputNum = 5;
res = kCSRTensorOutputNum;
} else {
res = 1;
}

View File

@ -25,6 +25,9 @@ csr_mv_op_info = AkgGpuRegOp("CSRMV") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F32_Default, \
DataType.F32_Default, \
DataType.F32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default, \
DataType.F32_Default, \
DataType.F32_Default) \
.get_op_info()
@op_info_register(csr_mv_op_info)