!26811 Support CSRTensor in while loop and subgraph

Merge pull request !26811 from 杨林枫/csr_in_while
i-robot 2021-12-01 06:38:36 +00:00 committed by Gitee
commit c5fac5aba4
5 changed files with 91 additions and 7 deletions


@@ -26,6 +26,24 @@
namespace mindspore {
namespace opt {
// Convert CSRTensor Parameter or ValueNode to Tuple by setting its abstract.
void AbstractCSRToAbstractTuple(const AnfNodePtr &sparse) {
  MS_EXCEPTION_IF_NULL(sparse);
  if (!(sparse->isa<Parameter>() || sparse->isa<ValueNode>())) {
    return;
  }
  auto param_abs = sparse->abstract();
  MS_EXCEPTION_IF_NULL(param_abs);
  if (param_abs->isa<abstract::AbstractCSRTensor>()) {
    auto abs_sparse = param_abs->cast<abstract::AbstractCSRTensorPtr>();
    std::vector<AbstractBasePtr> abstract_list{abs_sparse->indptr(), abs_sparse->indices(), abs_sparse->values(),
                                               abs_sparse->dense_shape()};
    auto abs_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
    abs_tuple->set_type(abs_tuple->BuildType());
    sparse->set_abstract(abs_tuple);
  }
}
const AnfNodePtr SparseProcess::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                        const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
@@ -60,14 +78,16 @@ const AnfNodePtr SparseProcess::Process(const FuncGraphPtr &func_graph, const An
  } else if (sparse_attr_map.find(prim_name) != sparse_attr_map.end()) {
    const auto &inputs = cnode->inputs();
    // Inputs should be [sparse_getattr, sparse]
    if (inputs.size() <= 1) {
      MS_LOG_EXCEPTION << "For SparseGetAttr, CNode must have 2 inputs (Prim, Sparse)";
    }
    constexpr size_t sparse_index = 1;
-    AnfNodePtr sparse = inputs[sparse_index];
-    MS_EXCEPTION_IF_NULL(sparse);
+    AbstractCSRToAbstractTuple(inputs[sparse_index]);
    int64_t index = sparse_attr_map.at(prim_name);
    auto cons_node = NewValueNode(index);
    AbstractBasePtr aptr = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(index));
    cons_node->set_abstract(aptr);
-    auto new_node = NewCNode({NewValueNode(prim::kPrimTupleGetItem), sparse, cons_node}, func_graph);
+    auto new_node = NewCNode({NewValueNode(prim::kPrimTupleGetItem), inputs[sparse_index], cons_node}, func_graph);
    new_node->set_abstract(node->abstract());
    return new_node;
  }
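The two hunks above implement the lowering at the IR level: AbstractCSRToAbstractTuple re-types a CSRTensor Parameter or ValueNode as a 4-tuple, and the getattr branch then rewrites each CSRTensor attribute read into a TupleGetItem at the index looked up in sparse_attr_map. A plain-Python sketch of that mapping (not MindSpore internals; the index order is assumed from abstract_list above, and the attribute-to-index dict is illustrative only):

# Hypothetical name; mirrors the (indptr, indices, values, dense_shape) order used in abstract_list.
SPARSE_ATTR_INDEX = {"indptr": 0, "indices": 1, "values": 2, "shape": 3}

def lower_csr_getattr(csr_as_tuple, attr_name):
    # Once the CSRTensor abstract has been turned into a tuple, an attribute read
    # is just a tuple_getitem at the mapped index (e.g. x.values -> x[2]).
    return csr_as_tuple[SPARSE_ATTR_INDEX[attr_name]]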


@@ -170,7 +170,7 @@ void CPUSession::CreateOutputTensors(const GraphId &graph_id, const std::vector<
void CPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
                               const std::vector<tensor::TensorPtr> &inputs_const) const {
  MS_EXCEPTION_IF_NULL(kernel_graph);
-  auto &input_nodes = kernel_graph->inputs();
+  auto &input_nodes = kernel_graph->input_nodes();
  if (input_nodes.size() != inputs_const.size()) {
    MS_LOG(EXCEPTION) << "Input size " << inputs_const.size() << " is not equal to input node size "
                      << input_nodes.size();


@@ -130,7 +130,7 @@ void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph
void CPUKernelRuntime::AssignInputNodeAddress(const session::KernelGraph *kernel_graph) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
-  for (auto &item : kernel_graph->inputs()) {
+  for (auto &item : kernel_graph->input_nodes()) {
    MS_EXCEPTION_IF_NULL(item);
    if (item->isa<Parameter>()) {
      auto output_num = AnfAlgo::GetOutputTensorNum(item);
@@ -281,7 +281,7 @@ void CPUKernelRuntime::CreateOutputTensors(session::KernelGraph *kernel_graph,
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_EXCEPTION_IF_NULL(outputs);
  MS_EXCEPTION_IF_NULL(tensor_to_node);
-  auto &input_nodes = kernel_graph->inputs();
+  auto &input_nodes = kernel_graph->input_nodes();
  if (input_nodes.size() != inputs.size()) {
    MS_LOG(EXCEPTION) << "Input size " << inputs.size() << " is not equal to input node size " << input_nodes.size();
  }
@@ -305,7 +305,7 @@ void CPUKernelRuntime::CreateOutputTensors(session::KernelGraph *kernel_graph,
void CPUKernelRuntime::BindInputTensorAddressPtr(const session::KernelGraph &kernel_graph,
                                                 const std::vector<tensor::TensorPtr> &inputs) {
-  auto &input_nodes = kernel_graph.inputs();
+  auto &input_nodes = kernel_graph.input_nodes();
  if (input_nodes.size() != inputs.size()) {
    MS_LOG(EXCEPTION) << "Input size" << inputs.size() << " is not equal to input node size " << input_nodes.size();
  }


@@ -1516,6 +1516,7 @@ class MS_CORE_API AbstractCSRTensor : public AbstractUndetermined {
  AbstractTensorPtr values_;
  AbstractTuplePtr dense_shape_;
};
+using AbstractCSRTensorPtr = std::shared_ptr<AbstractCSRTensor>;

class AbstractMonad : public AbstractBase {
 public:


@@ -15,9 +15,13 @@
"""smoke tests for CSR operations"""
import pytest
import numpy as np
from mindspore import Tensor, CSRTensor, ms_function
from mindspore.common import dtype as mstype
from mindspore import nn, context
context.set_context(mode=context.GRAPH_MODE)


def compare_csr(csr1, csr2):
    assert isinstance(csr1, CSRTensor)
@@ -82,3 +86,62 @@ def test_csr_attr():
    csr1 = CSRTensor(*csr1_tuple)
    csr2 = CSRTensor(*csr2_tuple)
    compare_csr(csr1, csr2)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_csr_tensor_in_while():
    """
    Feature: Test CSRTensor in while loop.
    Description: Test CSRTensor computation in while loop.
    Expectation: Success.
    """
    class CSRTensorValuesDouble(nn.Cell):
        def construct(self, x):
            indptr = x.indptr
            indices = x.indices
            values = x.values * 2
            shape = x.shape
            return CSRTensor(indptr, indices, values, shape)

    class CSRTensorValuesAdd2(nn.Cell):
        def construct(self, x):
            indptr = x.indptr
            indices = x.indices
            values = x.values + 2
            shape = x.shape
            return CSRTensor(indptr, indices, values, shape)

    class CSRTensorWithControlWhile(nn.Cell):
        def __init__(self, shape):
            super().__init__()
            self.op1 = CSRTensorValuesDouble()
            self.op2 = CSRTensorValuesAdd2()
            self.shape = shape

        @ms_function
        def construct(self, a, b, indptr, indices, values):
            x = CSRTensor(indptr, indices, values, self.shape)
            x = self.op2(x)
            while a > b:
                x = self.op1(x)
                b = b + 1
            return x

    a = Tensor(3, mstype.int32)
    b = Tensor(0, mstype.int32)
    indptr = Tensor([0, 1, 2])
    indices = Tensor([0, 1])
    values = Tensor([1, 2], dtype=mstype.float32)
    shape = (2, 6)
    net = CSRTensorWithControlWhile(shape)
    out = net(a, b, indptr, indices, values)
    assert np.allclose(out.indptr.asnumpy(), indptr.asnumpy(), .0, .0)
    assert np.allclose(out.indices.asnumpy(), indices.asnumpy(), .0, .0)
    assert np.allclose((values.asnumpy() + 2) * 8, out.values.asnumpy(), .0, .0)
    assert shape == out.shape
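
For reference, the factor of 8 in the final allclose follows from the loop trip count: op2 adds 2 to the values once, then the while body runs three times (b goes 0 -> 3 while a == 3), doubling the values each pass. A standalone NumPy check of the same arithmetic (not part of the diff):

import numpy as np

vals = np.array([1, 2], dtype=np.float32)
expected = (vals + 2) * 2 ** 3   # op2's +2, then three doublings in the while loop
print(expected)                  # prints [24. 32.]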