From 343b17c61fc9c6e7ff1f1ca8b956fa504fba0d47 Mon Sep 17 00:00:00 2001
From: yanglf1121
Date: Sun, 21 Nov 2021 21:56:08 +0800
Subject: [PATCH] support csr in while loop

---
 .../backend/optimizer/pass/sparse_process.cc  | 26 +++++++-
 .../ccsrc/backend/session/cpu_session.cc      |  2 +-
 .../runtime/device/cpu/cpu_kernel_runtime.cc  |  6 +-
 mindspore/core/abstract/abstract_value.h      |  1 +
 tests/st/sparse/test_csr.py                   | 63 +++++++++++++++++++
 5 files changed, 91 insertions(+), 7 deletions(-)

diff --git a/mindspore/ccsrc/backend/optimizer/pass/sparse_process.cc b/mindspore/ccsrc/backend/optimizer/pass/sparse_process.cc
index e4f9a0adae6..d5ce6ecae57 100644
--- a/mindspore/ccsrc/backend/optimizer/pass/sparse_process.cc
+++ b/mindspore/ccsrc/backend/optimizer/pass/sparse_process.cc
@@ -26,6 +26,24 @@
 
 namespace mindspore {
 namespace opt {
+// Convert CSRTensor Parameter or ValueNode to Tuple by setting its abstract.
+void AbstractCSRToAbstractTuple(const AnfNodePtr &sparse) {
+  MS_EXCEPTION_IF_NULL(sparse);
+  if (!(sparse->isa<Parameter>() || sparse->isa<ValueNode>())) {
+    return;
+  }
+  auto param_abs = sparse->abstract();
+  MS_EXCEPTION_IF_NULL(param_abs);
+  if (param_abs->isa<abstract::AbstractCSRTensor>()) {
+    auto abs_sparse = param_abs->cast<abstract::AbstractCSRTensorPtr>();
+    std::vector<AbstractBasePtr> abstract_list{abs_sparse->indptr(), abs_sparse->indices(), abs_sparse->values(),
+                                               abs_sparse->dense_shape()};
+    auto abs_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
+    abs_tuple->set_type(abs_tuple->BuildType());
+    sparse->set_abstract(abs_tuple);
+  }
+}
+
 const AnfNodePtr SparseProcess::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                         const EquivPtr &) const {
   MS_EXCEPTION_IF_NULL(func_graph);
@@ -60,14 +78,16 @@ const AnfNodePtr SparseProcess::Process(const FuncGraphPtr &func_graph, const An
   } else if (sparse_attr_map.find(prim_name) != sparse_attr_map.end()) {
     const auto &inputs = cnode->inputs();
     // Inputs should be [sparse_getattr, sparse]
+    if (inputs.size() <= 1) {
+      MS_LOG_EXCEPTION << "For SparseGetAttr, CNode must have 2 inputs (Prim, Sparse)";
+    }
     constexpr size_t sparse_index = 1;
-    AnfNodePtr sparse = inputs[sparse_index];
-    MS_EXCEPTION_IF_NULL(sparse);
+    AbstractCSRToAbstractTuple(inputs[sparse_index]);
     int64_t index = sparse_attr_map.at(prim_name);
     auto cons_node = NewValueNode(index);
     AbstractBasePtr aptr = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(index));
     cons_node->set_abstract(aptr);
-    auto new_node = NewCNode({NewValueNode(prim::kPrimTupleGetItem), sparse, cons_node}, func_graph);
+    auto new_node = NewCNode({NewValueNode(prim::kPrimTupleGetItem), inputs[sparse_index], cons_node}, func_graph);
     new_node->set_abstract(node->abstract());
     return new_node;
   }
diff --git a/mindspore/ccsrc/backend/session/cpu_session.cc b/mindspore/ccsrc/backend/session/cpu_session.cc
index 064a971e771..3a8689d7d45 100644
--- a/mindspore/ccsrc/backend/session/cpu_session.cc
+++ b/mindspore/ccsrc/backend/session/cpu_session.cc
@@ -170,7 +170,7 @@ void CPUSession::CreateOutputTensors(const GraphId &graph_id, const std::vector<
 void CPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
                                const std::vector<tensor::TensorPtr> &inputs_const) const {
   MS_EXCEPTION_IF_NULL(kernel_graph);
-  auto &input_nodes = kernel_graph->inputs();
+  auto &input_nodes = kernel_graph->input_nodes();
   if (input_nodes.size() != inputs_const.size()) {
     MS_LOG(EXCEPTION) << "Input size " << inputs_const.size() << " is not equal to input node size " << input_nodes.size();
   }
diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc
index 038fcca8a69..d9629c0526e 100644
--- a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc
+++ b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc
@@ -130,7 +130,7 @@ void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph
 
 void CPUKernelRuntime::AssignInputNodeAddress(const session::KernelGraph *kernel_graph) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
-  for (auto &item : kernel_graph->inputs()) {
+  for (auto &item : kernel_graph->input_nodes()) {
     MS_EXCEPTION_IF_NULL(item);
     if (item->isa<Parameter>()) {
       auto output_num = AnfAlgo::GetOutputTensorNum(item);
@@ -281,7 +281,7 @@ void CPUKernelRuntime::CreateOutputTensors(session::KernelGraph *kernel_graph,
   MS_EXCEPTION_IF_NULL(kernel_graph);
   MS_EXCEPTION_IF_NULL(outputs);
   MS_EXCEPTION_IF_NULL(tensor_to_node);
-  auto &input_nodes = kernel_graph->inputs();
+  auto &input_nodes = kernel_graph->input_nodes();
   if (input_nodes.size() != inputs.size()) {
     MS_LOG(EXCEPTION) << "Input size " << inputs.size() << " is not equal to input node size " << input_nodes.size();
   }
@@ -305,7 +305,7 @@ void CPUKernelRuntime::CreateOutputTensors(session::KernelGraph *kernel_graph,
 
 void CPUKernelRuntime::BindInputTensorAddressPtr(const session::KernelGraph &kernel_graph,
                                                  const std::vector<tensor::TensorPtr> &inputs) {
-  auto &input_nodes = kernel_graph.inputs();
+  auto &input_nodes = kernel_graph.input_nodes();
   if (input_nodes.size() != inputs.size()) {
     MS_LOG(EXCEPTION) << "Input size" << inputs.size() << " is not equal to input node size " << input_nodes.size();
   }
diff --git a/mindspore/core/abstract/abstract_value.h b/mindspore/core/abstract/abstract_value.h
index 36321e02541..d63c96ef532 100644
--- a/mindspore/core/abstract/abstract_value.h
+++ b/mindspore/core/abstract/abstract_value.h
@@ -1516,6 +1516,7 @@ class MS_CORE_API AbstractCSRTensor : public AbstractUndetermined {
   AbstractTensorPtr values_;
   AbstractTuplePtr dense_shape_;
 };
+using AbstractCSRTensorPtr = std::shared_ptr<AbstractCSRTensor>;
 
 class AbstractMonad : public AbstractBase {
  public:
diff --git a/tests/st/sparse/test_csr.py b/tests/st/sparse/test_csr.py
index 9ce2951ec2d..b023ac9958f 100644
--- a/tests/st/sparse/test_csr.py
+++ b/tests/st/sparse/test_csr.py
@@ -15,9 +15,13 @@
 """smoke tests for CSR operations"""
 
 import pytest
+import numpy as np
+
 from mindspore import Tensor, CSRTensor, ms_function
 from mindspore.common import dtype as mstype
+from mindspore import nn, context
 
+context.set_context(mode=context.GRAPH_MODE)
 
 def compare_csr(csr1, csr2):
     assert isinstance(csr1, CSRTensor)
@@ -82,3 +86,62 @@ def test_csr_attr():
     csr1 = CSRTensor(*csr1_tuple)
     csr2 = CSRTensor(*csr2_tuple)
     compare_csr(csr1, csr2)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_csr_tensor_in_while():
+    """
+    Feature: Test CSRTensor in while loop.
+    Description: Test CSRTensor computation in while loop.
+    Expectation: Success.
+ """ + class CSRTensorValuesDouble(nn.Cell): + + def construct(self, x): + indptr = x.indptr + indices = x.indices + values = x.values * 2 + shape = x.shape + return CSRTensor(indptr, indices, values, shape) + + class CSRTensorValuesAdd2(nn.Cell): + + def construct(self, x): + indptr = x.indptr + indices = x.indices + values = x.values + 2 + shape = x.shape + return CSRTensor(indptr, indices, values, shape) + + class CSRTensorWithControlWhile(nn.Cell): + def __init__(self, shape): + super().__init__() + self.op1 = CSRTensorValuesDouble() + self.op2 = CSRTensorValuesAdd2() + self.shape = shape + + @ms_function + def construct(self, a, b, indptr, indices, values): + x = CSRTensor(indptr, indices, values, self.shape) + x = self.op2(x) + while a > b: + x = self.op1(x) + b = b + 1 + return x + a = Tensor(3, mstype.int32) + b = Tensor(0, mstype.int32) + indptr = Tensor([0, 1, 2]) + indices = Tensor([0, 1]) + values = Tensor([1, 2], dtype=mstype.float32) + shape = (2, 6) + net = CSRTensorWithControlWhile(shape) + out = net(a, b, indptr, indices, values) + assert np.allclose(out.indptr.asnumpy(), indptr.asnumpy(), .0, .0) + assert np.allclose(out.indices.asnumpy(), indices.asnumpy(), .0, .0) + assert np.allclose((values.asnumpy() + 2) * 8, out.values.asnumpy(), .0, .0) + assert shape == out.shape