!31737 adapt csrtensor runtime
Merge pull request !31737 from wangrao124/pr_csr_param
This commit is contained in:
commit
fcb0d72e38
|
@ -28,6 +28,9 @@ namespace opt {
|
|||
using CSRTensor = mindspore::tensor::CSRTensor;
|
||||
using CSRTensorPtr = mindspore::tensor::CSRTensorPtr;
|
||||
|
||||
constexpr auto kCSRValueNodeNum = 2;
|
||||
constexpr auto kSparseAttrIndex = 1;
|
||||
|
||||
// Convert CSRTensor Parameter or ValueNode to Tuple by setting its abstract.
|
||||
void AbstractCSRToAbstractTuple(const AnfNodePtr &sparse) {
|
||||
MS_EXCEPTION_IF_NULL(sparse);
|
||||
|
@ -77,6 +80,9 @@ bool SplitParameter(const AnfNodePtr &node, std::vector<AnfNodePtr> *new_inputs,
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto node_abs = node->abstract();
|
||||
MS_EXCEPTION_IF_NULL(node_abs);
|
||||
static HashMap<AnfNodePtr, std::vector<AnfNodePtr>> csr_params_map;
|
||||
auto param = node->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(param);
|
||||
if (node_abs->isa<abstract::AbstractCSRTensor>()) {
|
||||
auto param_abs = node_abs->cast<abstract::AbstractCSRTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(param_abs);
|
||||
|
@ -94,6 +100,20 @@ bool SplitParameter(const AnfNodePtr &node, std::vector<AnfNodePtr> *new_inputs,
|
|||
// Set CSRTensor Parameter abstract to Tensor by its values.
|
||||
node->set_abstract(param_abs->values()->Broaden());
|
||||
new_inputs->push_back(node);
|
||||
// Set csr_params_map
|
||||
if (csr_params_map.find(node) == csr_params_map.end()) {
|
||||
csr_params_map[node].emplace_back(new_indptr);
|
||||
csr_params_map[node].emplace_back(new_indices);
|
||||
}
|
||||
return true;
|
||||
// If the cnode has a csr_tensor_param which has been split, use the map to find its indptr and indices.
|
||||
} else if (node_abs->isa<abstract::AbstractTensor>() && csr_params_map.find(node) != csr_params_map.end()) {
|
||||
if (csr_params_map[node].size() != kCSRValueNodeNum) {
|
||||
MS_LOG(ERROR) << "csr_params_map[" << node->DebugString() << "] has " << csr_params_map[node].size()
|
||||
<< " inputs, but expect two inputs! They are all added in new_inputs.";
|
||||
}
|
||||
new_inputs->insert(new_inputs->end(), csr_params_map[node].begin(), csr_params_map[node].end());
|
||||
new_inputs->push_back(node);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
@ -103,8 +123,8 @@ bool SplitCNode(const AnfNodePtr &node, std::vector<AnfNodePtr> *new_inputs) {
|
|||
auto cnode = node->cast<CNodePtr>();
|
||||
auto sparse_prim = common::AnfAlgo::GetCNodePrimitive(cnode);
|
||||
MS_EXCEPTION_IF_NULL(sparse_prim);
|
||||
// Currently, only MakeCSR and MakeTuple nodes can be split.
|
||||
if (make_sparse_set.count(sparse_prim->name()) == 0 && sparse_prim->name().compare(prim::kPrimMakeTuple->name()) != 0)
|
||||
// Currently, only MakeCSR/MakeCOO and MakeTuple nodes can be split.
|
||||
if (make_sparse_set.count(sparse_prim->name()) <= 0 && sparse_prim->name().compare(prim::kPrimMakeTuple->name()) != 0)
|
||||
return false;
|
||||
|
||||
auto sparse_inputs = cnode->inputs();
|
||||
|
@ -116,8 +136,10 @@ bool SplitCNode(const AnfNodePtr &node, std::vector<AnfNodePtr> *new_inputs) {
|
|||
return true;
|
||||
}
|
||||
|
||||
std::vector<AbstractBasePtr> GetAbstractList(const AnfNodePtr &node, const std::string &prim_name) {
|
||||
std::vector<AbstractBasePtr> GetAbstractList(const AnfNodePtr &node, const PrimitivePtr &prim) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::string prim_name = prim->name();
|
||||
if (prim_name == prim::kPrimMakeCSRTensor->name()) {
|
||||
auto abs_sparse = dyn_cast<abstract::AbstractCSRTensor>(node->abstract());
|
||||
MS_EXCEPTION_IF_NULL(abs_sparse);
|
||||
|
@ -130,6 +152,85 @@ std::vector<AbstractBasePtr> GetAbstractList(const AnfNodePtr &node, const std::
|
|||
return {};
|
||||
}
|
||||
|
||||
CNodePtr ConvertMakeSparseToMakeTuple(const AnfNodePtr &node, const KernelGraphPtr &kernel_graph) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
CNodePtr cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
(void)inputs.insert(inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
|
||||
|
||||
auto new_node = NewCNode(inputs, cnode->func_graph());
|
||||
std::vector<AbstractBasePtr> abstract_list = GetAbstractList(node, common::AnfAlgo::GetCNodePrimitive(cnode));
|
||||
auto abs_res = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
||||
new_node->set_abstract(abs_res);
|
||||
new_node->set_scope(cnode->scope());
|
||||
if (kernel_graph != nullptr) {
|
||||
kernel_graph->FrontBackendlMapUpdate(cnode, new_node);
|
||||
}
|
||||
return new_node;
|
||||
}
|
||||
|
||||
CNodePtr ConvertSparseGetAttrToTupleGetItem(int64_t index, const AnfNodePtr &node, const KernelGraphPtr &kernel_graph) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
CNodePtr cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
||||
const auto &inputs = cnode->inputs();
|
||||
if (inputs.size() <= kSparseAttrIndex) {
|
||||
MS_LOG(EXCEPTION) << "For SparseGetAttr, CNode must have 2 inputs (Prim, Sparse)";
|
||||
}
|
||||
AbstractCSRToAbstractTuple(inputs[kSparseAttrIndex]);
|
||||
auto index_node = NewValueNode(index);
|
||||
AbstractBasePtr index_abs = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(index));
|
||||
index_node->set_abstract(index_abs);
|
||||
auto new_node =
|
||||
NewCNode({NewValueNode(prim::kPrimTupleGetItem), inputs[kSparseAttrIndex], index_node}, cnode->func_graph());
|
||||
new_node->set_abstract(node->abstract());
|
||||
if (kernel_graph != nullptr) {
|
||||
kernel_graph->FrontBackendlMapUpdate(cnode, new_node);
|
||||
}
|
||||
return new_node;
|
||||
}
|
||||
|
||||
CNodePtr FetchInputsForSparseOP(const AnfNodePtr &node, const KernelGraphPtr &kernel_graph) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
CNodePtr cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
||||
if (cnode->GetAttr("has_been_split") != nullptr) {
|
||||
MS_LOG(INFO) << "Do not process CNode " << cnode << " (" << cnode->DebugString() << "), because it has been split.";
|
||||
return nullptr;
|
||||
}
|
||||
const auto &inputs = cnode->inputs();
|
||||
std::vector<AnfNodePtr> new_inputs;
|
||||
new_inputs.push_back(inputs[0]);
|
||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||
if (inputs[i]->isa<CNode>()) {
|
||||
if (SplitCNode(inputs[i], &new_inputs)) continue;
|
||||
} else if (inputs[i]->isa<ValueNode>()) {
|
||||
if (SplitValueNode(inputs[i], &new_inputs, kernel_graph)) continue;
|
||||
} else if (inputs[i]->isa<Parameter>()) {
|
||||
// 1. Split CSRTensor param to multiple tensors.
|
||||
// 2. Set CSRTensor abstract to AbstractTensor that is related its values.
|
||||
if (SplitParameter(inputs[i], &new_inputs, kernel_graph)) continue;
|
||||
}
|
||||
new_inputs.push_back(inputs[i]);
|
||||
}
|
||||
auto new_node = NewCNode(new_inputs, cnode->func_graph());
|
||||
new_node->set_abstract(node->abstract());
|
||||
// Set attr "has_been_split" to prevent the node is split more than once.
|
||||
new_node->AddAttr("has_been_split", MakeValue(true));
|
||||
if (kernel_graph != nullptr) {
|
||||
kernel_graph->FrontBackendlMapUpdate(cnode, new_node);
|
||||
}
|
||||
return new_node;
|
||||
}
|
||||
|
||||
const AnfNodePtr SparseProcess::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
@ -143,56 +244,13 @@ const AnfNodePtr SparseProcess::Process(const FuncGraphPtr &func_graph, const An
|
|||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::string prim_name = prim->name();
|
||||
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
|
||||
// cnode is a MakeSparse node
|
||||
if (make_sparse_set.find(prim_name) != make_sparse_set.end()) {
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
(void)inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
(void)inputs.insert(inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
|
||||
auto new_node = cnode->func_graph()->NewCNode(inputs);
|
||||
std::vector<AbstractBasePtr> abstract_list = GetAbstractList(node, prim_name);
|
||||
auto abs_res = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
||||
new_node->set_abstract(abs_res);
|
||||
new_node->set_scope(cnode->scope());
|
||||
if (kernel_graph != nullptr) {
|
||||
kernel_graph->FrontBackendlMapUpdate(cnode, new_node);
|
||||
}
|
||||
return new_node;
|
||||
// cnode is a SparseGetAttr node
|
||||
return ConvertMakeSparseToMakeTuple(node, kernel_graph);
|
||||
} else if (sparse_attr_map.find(prim_name) != sparse_attr_map.end()) {
|
||||
const auto &inputs = cnode->inputs();
|
||||
// Inputs should be [sparse_getattr, sparse]
|
||||
if (inputs.size() <= 1) {
|
||||
MS_LOG_EXCEPTION << "For SparseGetAttr, CNode must have 2 inputs (Prim, Sparse)";
|
||||
}
|
||||
constexpr size_t sparse_index = 1;
|
||||
AbstractCSRToAbstractTuple(inputs[sparse_index]);
|
||||
int64_t index = sparse_attr_map.at(prim_name);
|
||||
auto cons_node = NewValueNode(index);
|
||||
AbstractBasePtr aptr = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(index));
|
||||
cons_node->set_abstract(aptr);
|
||||
auto new_node = NewCNode({NewValueNode(prim::kPrimTupleGetItem), inputs[sparse_index], cons_node}, func_graph);
|
||||
new_node->set_abstract(node->abstract());
|
||||
return new_node;
|
||||
// ComputeSparse node: SparseTensorDenseMatmul, CSRMul, CSRReduceSum
|
||||
return ConvertSparseGetAttrToTupleGetItem(sparse_attr_map.at(prim_name), node, kernel_graph);
|
||||
} else if (sparse_op_set.find(prim_name) != sparse_op_set.end()) {
|
||||
const auto &inputs = cnode->inputs();
|
||||
std::vector<AnfNodePtr> new_inputs;
|
||||
new_inputs.push_back(inputs[0]);
|
||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||
if (inputs[i]->isa<CNode>()) {
|
||||
if (SplitCNode(inputs[i], &new_inputs)) continue;
|
||||
} else if (inputs[i]->isa<ValueNode>()) {
|
||||
if (SplitValueNode(inputs[i], &new_inputs, kernel_graph)) continue;
|
||||
} else if (inputs[i]->isa<Parameter>()) {
|
||||
if (SplitParameter(inputs[i], &new_inputs, kernel_graph)) continue;
|
||||
}
|
||||
new_inputs.push_back(inputs[i]);
|
||||
}
|
||||
auto new_node = cnode->func_graph()->NewCNode(new_inputs);
|
||||
new_node->set_abstract(node->abstract());
|
||||
return new_node;
|
||||
return FetchInputsForSparseOP(node, kernel_graph);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace opt
|
||||
|
|
|
@ -23,6 +23,10 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// Process SparseOPs:
|
||||
// 1. Convert "MakeCSRTensor/MakeCOOTensor/..." to MakeTuple
|
||||
// 2. Convert "CSRTensorGetIndptr/..." to TupleGetItem
|
||||
// 3. Process inputs for SparseOPs, e.g., split CSRTensor input to multiple tensor inputs.
|
||||
class SparseProcess : public PatternProcessPass {
|
||||
public:
|
||||
explicit SparseProcess(bool multigraph = true) : PatternProcessPass("sparse_process", multigraph) {}
|
||||
|
|
|
@ -225,6 +225,11 @@ class COMMON_EXPORT AnfAlgo {
|
|||
// Get the real output node and indexes of get item, make tuple, depend, load.
|
||||
static AnfNodePtr GetTupleIndexes(const AnfNodePtr &node, std::vector<size_t> *index_stack);
|
||||
static bool IsNopNode(const AnfNodePtr &node);
|
||||
static bool CheckAbsCSRTensor(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(node->abstract());
|
||||
return node->abstract()->isa<abstract::AbstractCSRTensor>();
|
||||
}
|
||||
};
|
||||
} // namespace common
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -111,6 +111,17 @@ void ValueTupleToValue(const ValuePtr &value, std::vector<ValuePtr> *const value
|
|||
(void)values->emplace_back(element);
|
||||
}
|
||||
}
|
||||
} else if (value->isa<tensor::CSRTensor>()) {
|
||||
auto csr_tensor = value->cast<tensor::CSRTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(csr_tensor);
|
||||
MS_EXCEPTION_IF_NULL(csr_tensor->GetIndptr());
|
||||
MS_EXCEPTION_IF_NULL(csr_tensor->GetIndices());
|
||||
MS_EXCEPTION_IF_NULL(csr_tensor->GetValues());
|
||||
values->emplace_back(csr_tensor->GetIndptr());
|
||||
values->emplace_back(csr_tensor->GetIndices());
|
||||
values->emplace_back(csr_tensor->GetValues());
|
||||
std::transform(csr_tensor->shape().begin(), csr_tensor->shape().end(), std::back_inserter(*values),
|
||||
[](int64_t n) { return std::make_shared<Int64Imm>(n); });
|
||||
} else {
|
||||
(void)values->emplace_back(value);
|
||||
}
|
||||
|
|
|
@ -322,7 +322,12 @@ void CreateDeviceTensorForFrontNode(const KernelWithIndex &front_node_with_index
|
|||
}
|
||||
|
||||
// Set type.
|
||||
TypeId type_id = common::AnfAlgo::GetOutputInferDataType(node, front_node_with_index.second);
|
||||
TypeId type_id = kTypeUnknown;
|
||||
if (common::AnfAlgo::CheckAbsCSRTensor(node)) {
|
||||
type_id = node->abstract()->cast<abstract::AbstractCSRTensorPtr>()->GetTypeIdAt(front_node_with_index.second);
|
||||
} else {
|
||||
type_id = common::AnfAlgo::GetOutputInferDataType(node, front_node_with_index.second);
|
||||
}
|
||||
if (builder->GetAllOutputDeviceTypes().size() > front_node_with_index.second) {
|
||||
builder->SetOutputDeviceType(type_id, front_node_with_index.second);
|
||||
} else {
|
||||
|
@ -334,7 +339,12 @@ void CreateDeviceTensorForFrontNode(const KernelWithIndex &front_node_with_index
|
|||
}
|
||||
|
||||
// Create device tensor.
|
||||
size_t size = AnfAlgo::GetOutputTensorMemSize(node, front_node_with_index.second);
|
||||
size_t size = 0;
|
||||
if (common::AnfAlgo::CheckAbsCSRTensor(node)) {
|
||||
size = node->cast<ValueNodePtr>()->value()->cast<tensor::CSRTensorPtr>()->GetSizeAt(front_node_with_index.second);
|
||||
} else {
|
||||
size = AnfAlgo::GetOutputTensorMemSize(node, front_node_with_index.second);
|
||||
}
|
||||
device::DeviceAddressPtr address =
|
||||
device_context->CreateDeviceAddress(nullptr, size, kOpFormat_DEFAULT, type_id, ShapeVector());
|
||||
MS_EXCEPTION_IF_NULL(address);
|
||||
|
@ -570,6 +580,10 @@ KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const
|
|||
if (front_node != nullptr) {
|
||||
MS_LOG(DEBUG) << "Front node:" << front_node->DebugString() << " index:0"
|
||||
<< " for backend node:" << backend_node->DebugString();
|
||||
// Adapt CSRTensor to new runtime
|
||||
if (common::AnfAlgo::CheckAbsCSRTensor(front_node)) {
|
||||
return {front_node, kCsrTensorValuesIndex};
|
||||
}
|
||||
return {front_node, 0};
|
||||
}
|
||||
const auto &front_node_with_index = graph->GetFrontNodeByInternalParameter(backend_node);
|
||||
|
|
|
@ -59,6 +59,7 @@ constexpr size_t kCsrTensorIndPtrIndex = 0;
|
|||
constexpr size_t kCsrTensorIndicesIndex = 1;
|
||||
constexpr size_t kCsrTensorValuesIndex = 2;
|
||||
constexpr size_t kCsrTensorDenseShapeIndex = 3;
|
||||
constexpr size_t kCsrParamOutputSize = 3;
|
||||
constexpr size_t kCooTensorIndicesIndex = 0;
|
||||
constexpr size_t kCooTensorValuesIndex = 1;
|
||||
constexpr size_t kCooTensorDenseShapeIndex = 2;
|
||||
|
|
|
@ -1493,6 +1493,13 @@ void ControlNodeScheduler::LinkDataArrowForKernelActor(const GraphCompilerInfo &
|
|||
}
|
||||
}
|
||||
|
||||
void SetCSRTensorIndex(AnfNodePtr front_node, KernelWithIndex *front_node_with_index) {
|
||||
if (front_node != nullptr && common::AnfAlgo::CheckAbsCSRTensor(front_node)) {
|
||||
MS_EXCEPTION_IF_NULL(front_node_with_index);
|
||||
front_node_with_index->second = kCsrTensorValuesIndex;
|
||||
}
|
||||
}
|
||||
|
||||
void ControlNodeScheduler::LinkDataArrowByKernelGraph(const KernelGraphPtr &graph, ControlActor *const entrance_actor,
|
||||
const ControlNodeParserPtr &parser) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
|
@ -1533,6 +1540,9 @@ void ControlNodeScheduler::LinkDataArrowByKernelGraph(const KernelGraphPtr &grap
|
|||
from_node_with_index = tuple_node_with_index;
|
||||
}
|
||||
|
||||
// Adapt CSRTensor to new runtime
|
||||
SetCSRTensorIndex(front_node, &from_node_with_index);
|
||||
|
||||
if (common::AnfAlgo::CheckPrimitiveType(from_node_with_index.first, prim::kPrimTupleGetItem)) {
|
||||
MS_LOG(WARNING) << "Input node:" << from_node_with_index.first->DebugString()
|
||||
<< " for graph:" << graph->ToString() << " is a tuple get item";
|
||||
|
|
|
@ -58,6 +58,33 @@ bool NodeDeviceAddressExist(const DeviceContext *device_context, const AnfNodePt
|
|||
return false;
|
||||
}
|
||||
|
||||
void SetCSRParamAddr(const AnfNodePtr &node, size_t output_size, const DeviceContext *device_context) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(node->abstract());
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
|
||||
auto abs_csr_tensor = node->abstract()->cast<abstract::AbstractCSRTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(abs_csr_tensor);
|
||||
for (size_t i = 0; i < output_size; ++i) {
|
||||
auto abs_tensor = abs_csr_tensor->GetAbsTensorAt(i);
|
||||
MS_EXCEPTION_IF_NULL(abs_tensor);
|
||||
|
||||
TypeId type_id = abs_tensor->BuildType()->type_id();
|
||||
ShapeVector shape_vec = abs_tensor->shape()->shape();
|
||||
std::vector<size_t> res;
|
||||
std::transform(res.begin(), res.end(), std::back_inserter(shape_vec),
|
||||
[](int64_t num) { return static_cast<size_t>(num); });
|
||||
size_t type_size = GetTypeByte(TypeIdToType(type_id));
|
||||
size_t tensor_size = std::accumulate(res.begin(), res.end(), type_size, std::multiplies<size_t>());
|
||||
|
||||
auto device_address = device_context->CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, type_id,
|
||||
trans::GetRuntimePaddingShape(node, i));
|
||||
device_address->set_from_persistent_mem(node->isa<Parameter>());
|
||||
MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(node) << " addr:" << device_address;
|
||||
AnfAlgo::SetOutputAddr(device_address, i, node.get());
|
||||
}
|
||||
}
|
||||
|
||||
void CreateParameterDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
|
@ -93,6 +120,10 @@ void CreateParameterDeviceAddress(const DeviceContext *device_context, const Ker
|
|||
// Create device address for anf node in nodes_list
|
||||
for (const auto &item : nodes_list) {
|
||||
auto output_size = common::AnfAlgo::GetOutputTensorNum(item);
|
||||
if (common::AnfAlgo::CheckAbsCSRTensor(item)) {
|
||||
SetCSRParamAddr(item, kCsrParamOutputSize, device_context);
|
||||
continue;
|
||||
}
|
||||
for (size_t index = 0; index < output_size; index++) {
|
||||
TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
|
||||
if (output_type_id == kTypeUnknown) {
|
||||
|
|
|
@ -156,10 +156,9 @@ std::vector<KernelWithIndex> GetAllOutputWithIndexInner(const AnfNodePtr &node)
|
|||
}
|
||||
|
||||
// If the node is a call, the outputs num should get from the abstract.
|
||||
if (AnfAlgo::IsCallNode(node) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
|
||||
auto abstract = node->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
outputs_num = AnfAlgo::GetOutputNumByAbstract(abstract);
|
||||
if (AnfAlgo::IsCallNode(node) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem) ||
|
||||
AnfAlgo::CheckAbsCSRTensor(node)) {
|
||||
outputs_num = AnfAlgo::GetOutputNumByAbstract(node->abstract());
|
||||
}
|
||||
|
||||
// The output may be the tuple of node, so need visit all the outputs of node.
|
||||
|
|
|
@ -1848,6 +1848,36 @@ std::string AbstractCSRTensor::ToString() const {
|
|||
return buffer.str();
|
||||
}
|
||||
|
||||
const AbstractTensorPtr AbstractCSRTensor::GetAbsTensorAt(size_t index) const {
|
||||
if (index == kIndptrIdx) {
|
||||
return indptr_;
|
||||
} else if (index == kIndicesIdx) {
|
||||
return indices_;
|
||||
} else if (index == kValuesIdx) {
|
||||
return values_;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Invalid index: " << index << " for abstract: " << ToString();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const TypeId AbstractCSRTensor::GetTypeIdAt(size_t index) const {
|
||||
if (index == kIndptrIdx) {
|
||||
MS_EXCEPTION_IF_NULL(indptr_);
|
||||
return indptr_->element()->BuildType()->type_id();
|
||||
} else if (index == kIndicesIdx) {
|
||||
MS_EXCEPTION_IF_NULL(indices_);
|
||||
return indices_->element()->BuildType()->type_id();
|
||||
} else if (index == kValuesIdx) {
|
||||
MS_EXCEPTION_IF_NULL(values_);
|
||||
return values_->element()->BuildType()->type_id();
|
||||
} else if (index >= kShapeIdx && index < kShapeIdx + dense_shape_->elements().size()) {
|
||||
MS_EXCEPTION_IF_NULL(dense_shape_);
|
||||
return dense_shape_->elements()[index - kShapeIdx]->BuildType()->type_id();
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Invalid index: " << index << " for abstract: " << ToString();
|
||||
return kTypeUnknown;
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractUMonad::Join(const AbstractBasePtr &other) {
|
||||
MS_EXCEPTION_IF_NULL(other);
|
||||
if (!other->isa<AbstractUMonad>()) {
|
||||
|
|
|
@ -1596,7 +1596,8 @@ class MS_CORE_API AbstractCSRTensor : public AbstractUndetermined {
|
|||
AbstractBasePtr Clone() const override;
|
||||
AbstractBasePtr Broaden() const override;
|
||||
AbstractBasePtr BroadenWithShape() const;
|
||||
|
||||
const AbstractTensorPtr GetAbsTensorAt(size_t index) const;
|
||||
const TypeId GetTypeIdAt(size_t index) const;
|
||||
std::string ToString() const override;
|
||||
|
||||
private:
|
||||
|
@ -1604,6 +1605,10 @@ class MS_CORE_API AbstractCSRTensor : public AbstractUndetermined {
|
|||
AbstractTensorPtr indices_;
|
||||
AbstractTensorPtr values_;
|
||||
AbstractTuplePtr dense_shape_;
|
||||
static constexpr size_t kIndptrIdx = 0;
|
||||
static constexpr size_t kIndicesIdx = 1;
|
||||
static constexpr size_t kValuesIdx = 2;
|
||||
static constexpr size_t kShapeIdx = 3;
|
||||
};
|
||||
using AbstractCSRTensorPtr = std::shared_ptr<AbstractCSRTensor>;
|
||||
|
||||
|
|
|
@ -729,6 +729,23 @@ abstract::AbstractBasePtr CSRTensor::ToAbstract() {
|
|||
return abs_csr_tensor;
|
||||
}
|
||||
|
||||
const size_t CSRTensor::GetSizeAt(size_t index) const {
|
||||
if (index == kIndptrIdx) {
|
||||
MS_EXCEPTION_IF_NULL(indptr_);
|
||||
return indptr_->data().nbytes();
|
||||
} else if (index == kIndicesIdx) {
|
||||
MS_EXCEPTION_IF_NULL(indices_);
|
||||
return indices_->data().nbytes();
|
||||
} else if (index == kValuesIdx) {
|
||||
MS_EXCEPTION_IF_NULL(values_);
|
||||
return values_->data().nbytes();
|
||||
} else if (index >= kIndicesIdx && index < kShapeIdx + shape().size()) {
|
||||
return sizeof(int64_t);
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Invalid index: " << index << " for CSRTensor: " << ToString();
|
||||
return kTypeUnknown;
|
||||
}
|
||||
|
||||
std::string COOTensor::ToString() const {
|
||||
std::ostringstream buf;
|
||||
MS_EXCEPTION_IF_NULL(indices_);
|
||||
|
|
|
@ -614,6 +614,8 @@ class MS_CORE_API CSRTensor : public MetaSparseTensor {
|
|||
return false;
|
||||
}
|
||||
|
||||
const size_t GetSizeAt(size_t index) const;
|
||||
|
||||
/// \brief Get display information of this Tensor.
|
||||
///
|
||||
/// \return The display information of this Tensor.
|
||||
|
@ -623,6 +625,10 @@ class MS_CORE_API CSRTensor : public MetaSparseTensor {
|
|||
TensorPtr indptr_;
|
||||
TensorPtr indices_;
|
||||
TensorPtr values_;
|
||||
static constexpr size_t kIndptrIdx = 0;
|
||||
static constexpr size_t kIndicesIdx = 1;
|
||||
static constexpr size_t kValuesIdx = 2;
|
||||
static constexpr size_t kShapeIdx = 3;
|
||||
};
|
||||
using CSRTensorPtr = std::shared_ptr<CSRTensor>;
|
||||
|
||||
|
|
|
@ -329,6 +329,10 @@ size_t AnfUtils::GetOutputTensorNum(const AnfNodePtr &node) {
|
|||
res = tuple_type->size();
|
||||
} else if (type->isa<TypeNone>()) {
|
||||
res = 0;
|
||||
} else if (type->isa<CSRTensorType>()) {
|
||||
// Currently, CSRTensor only supports 2-D matrix (shape has 2 values). 5 outputs = 3 Tensors + 2 shape values.
|
||||
constexpr size_t kCSRTensorOutputNum = 5;
|
||||
res = kCSRTensorOutputNum;
|
||||
} else {
|
||||
res = 1;
|
||||
}
|
||||
|
|
|
@ -25,6 +25,9 @@ csr_mv_op_info = AkgGpuRegOp("CSRMV") \
|
|||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F32_Default, \
|
||||
DataType.F32_Default, \
|
||||
DataType.F32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default, \
|
||||
DataType.F32_Default, \
|
||||
DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
@op_info_register(csr_mv_op_info)
|
||||
|
|
Loading…
Reference in New Issue