!26811 Support CSRTensor in while loop and subgraph
Merge pull request !26811 from 杨林枫/csr_in_while
This commit is contained in:
commit
c5fac5aba4
|
@ -26,6 +26,24 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// Convert CSRTensor Parameter or ValueNode to Tuple by setting its abstract.
|
||||
void AbstractCSRToAbstractTuple(const AnfNodePtr &sparse) {
|
||||
MS_EXCEPTION_IF_NULL(sparse);
|
||||
if (!(sparse->isa<Parameter>() || sparse->isa<ValueNode>())) {
|
||||
return;
|
||||
}
|
||||
auto param_abs = sparse->abstract();
|
||||
MS_EXCEPTION_IF_NULL(param_abs);
|
||||
if (param_abs->isa<abstract::AbstractCSRTensor>()) {
|
||||
auto abs_sparse = param_abs->cast<abstract::AbstractCSRTensorPtr>();
|
||||
std::vector<AbstractBasePtr> abstract_list{abs_sparse->indptr(), abs_sparse->indices(), abs_sparse->values(),
|
||||
abs_sparse->dense_shape()};
|
||||
auto abs_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
||||
abs_tuple->set_type(abs_tuple->BuildType());
|
||||
sparse->set_abstract(abs_tuple);
|
||||
}
|
||||
}
|
||||
|
||||
const AnfNodePtr SparseProcess::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
@ -60,14 +78,16 @@ const AnfNodePtr SparseProcess::Process(const FuncGraphPtr &func_graph, const An
|
|||
} else if (sparse_attr_map.find(prim_name) != sparse_attr_map.end()) {
|
||||
const auto &inputs = cnode->inputs();
|
||||
// Inputs should be [sparse_getattr, sparse]
|
||||
if (inputs.size() <= 1) {
|
||||
MS_LOG_EXCEPTION << "For SparseGetAttr, CNode must have 2 inputs (Prim, Sparse)";
|
||||
}
|
||||
constexpr size_t sparse_index = 1;
|
||||
AnfNodePtr sparse = inputs[sparse_index];
|
||||
MS_EXCEPTION_IF_NULL(sparse);
|
||||
AbstractCSRToAbstractTuple(inputs[sparse_index]);
|
||||
int64_t index = sparse_attr_map.at(prim_name);
|
||||
auto cons_node = NewValueNode(index);
|
||||
AbstractBasePtr aptr = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(index));
|
||||
cons_node->set_abstract(aptr);
|
||||
auto new_node = NewCNode({NewValueNode(prim::kPrimTupleGetItem), sparse, cons_node}, func_graph);
|
||||
auto new_node = NewCNode({NewValueNode(prim::kPrimTupleGetItem), inputs[sparse_index], cons_node}, func_graph);
|
||||
new_node->set_abstract(node->abstract());
|
||||
return new_node;
|
||||
}
|
||||
|
|
|
@ -170,7 +170,7 @@ void CPUSession::CreateOutputTensors(const GraphId &graph_id, const std::vector<
|
|||
void CPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
||||
const std::vector<tensor::TensorPtr> &inputs_const) const {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto &input_nodes = kernel_graph->inputs();
|
||||
auto &input_nodes = kernel_graph->input_nodes();
|
||||
if (input_nodes.size() != inputs_const.size()) {
|
||||
MS_LOG(EXCEPTION) << "Input size " << inputs_const.size() << " is not equal to input node size "
|
||||
<< input_nodes.size();
|
||||
|
|
|
@ -130,7 +130,7 @@ void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph
|
|||
|
||||
void CPUKernelRuntime::AssignInputNodeAddress(const session::KernelGraph *kernel_graph) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
for (auto &item : kernel_graph->inputs()) {
|
||||
for (auto &item : kernel_graph->input_nodes()) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
if (item->isa<Parameter>()) {
|
||||
auto output_num = AnfAlgo::GetOutputTensorNum(item);
|
||||
|
@ -281,7 +281,7 @@ void CPUKernelRuntime::CreateOutputTensors(session::KernelGraph *kernel_graph,
|
|||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
MS_EXCEPTION_IF_NULL(outputs);
|
||||
MS_EXCEPTION_IF_NULL(tensor_to_node);
|
||||
auto &input_nodes = kernel_graph->inputs();
|
||||
auto &input_nodes = kernel_graph->input_nodes();
|
||||
if (input_nodes.size() != inputs.size()) {
|
||||
MS_LOG(EXCEPTION) << "Input size " << inputs.size() << " is not equal to input node size " << input_nodes.size();
|
||||
}
|
||||
|
@ -305,7 +305,7 @@ void CPUKernelRuntime::CreateOutputTensors(session::KernelGraph *kernel_graph,
|
|||
|
||||
void CPUKernelRuntime::BindInputTensorAddressPtr(const session::KernelGraph &kernel_graph,
|
||||
const std::vector<tensor::TensorPtr> &inputs) {
|
||||
auto &input_nodes = kernel_graph.inputs();
|
||||
auto &input_nodes = kernel_graph.input_nodes();
|
||||
if (input_nodes.size() != inputs.size()) {
|
||||
MS_LOG(EXCEPTION) << "Input size" << inputs.size() << " is not equal to input node size " << input_nodes.size();
|
||||
}
|
||||
|
|
|
@ -1516,6 +1516,7 @@ class MS_CORE_API AbstractCSRTensor : public AbstractUndetermined {
|
|||
AbstractTensorPtr values_;
|
||||
AbstractTuplePtr dense_shape_;
|
||||
};
|
||||
using AbstractCSRTensorPtr = std::shared_ptr<AbstractCSRTensor>;
|
||||
|
||||
class AbstractMonad : public AbstractBase {
|
||||
public:
|
||||
|
|
|
@ -15,9 +15,13 @@
|
|||
"""smoke tests for CSR operations"""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor, CSRTensor, ms_function
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore import nn, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
def compare_csr(csr1, csr2):
|
||||
assert isinstance(csr1, CSRTensor)
|
||||
|
@ -82,3 +86,62 @@ def test_csr_attr():
|
|||
csr1 = CSRTensor(*csr1_tuple)
|
||||
csr2 = CSRTensor(*csr2_tuple)
|
||||
compare_csr(csr1, csr2)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_csr_tensor_in_while():
|
||||
"""
|
||||
Feature: Test CSRTensor in while loop.
|
||||
Description: Test CSRTensor computation in while loop.
|
||||
Expectation: Success.
|
||||
"""
|
||||
class CSRTensorValuesDouble(nn.Cell):
|
||||
|
||||
def construct(self, x):
|
||||
indptr = x.indptr
|
||||
indices = x.indices
|
||||
values = x.values * 2
|
||||
shape = x.shape
|
||||
return CSRTensor(indptr, indices, values, shape)
|
||||
|
||||
class CSRTensorValuesAdd2(nn.Cell):
|
||||
|
||||
def construct(self, x):
|
||||
indptr = x.indptr
|
||||
indices = x.indices
|
||||
values = x.values + 2
|
||||
shape = x.shape
|
||||
return CSRTensor(indptr, indices, values, shape)
|
||||
|
||||
class CSRTensorWithControlWhile(nn.Cell):
|
||||
def __init__(self, shape):
|
||||
super().__init__()
|
||||
self.op1 = CSRTensorValuesDouble()
|
||||
self.op2 = CSRTensorValuesAdd2()
|
||||
self.shape = shape
|
||||
|
||||
@ms_function
|
||||
def construct(self, a, b, indptr, indices, values):
|
||||
x = CSRTensor(indptr, indices, values, self.shape)
|
||||
x = self.op2(x)
|
||||
while a > b:
|
||||
x = self.op1(x)
|
||||
b = b + 1
|
||||
return x
|
||||
a = Tensor(3, mstype.int32)
|
||||
b = Tensor(0, mstype.int32)
|
||||
indptr = Tensor([0, 1, 2])
|
||||
indices = Tensor([0, 1])
|
||||
values = Tensor([1, 2], dtype=mstype.float32)
|
||||
shape = (2, 6)
|
||||
net = CSRTensorWithControlWhile(shape)
|
||||
out = net(a, b, indptr, indices, values)
|
||||
assert np.allclose(out.indptr.asnumpy(), indptr.asnumpy(), .0, .0)
|
||||
assert np.allclose(out.indices.asnumpy(), indices.asnumpy(), .0, .0)
|
||||
assert np.allclose((values.asnumpy() + 2) * 8, out.values.asnumpy(), .0, .0)
|
||||
assert shape == out.shape
|
||||
|
|
Loading…
Reference in New Issue