!26811 Support CSRTensor in while loop and subgraph
Merge pull request !26811 from 杨林枫/csr_in_while
This commit is contained in:
commit
c5fac5aba4
|
@ -26,6 +26,24 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
|
// Convert CSRTensor Parameter or ValueNode to Tuple by setting its abstract.
|
||||||
|
void AbstractCSRToAbstractTuple(const AnfNodePtr &sparse) {
|
||||||
|
MS_EXCEPTION_IF_NULL(sparse);
|
||||||
|
if (!(sparse->isa<Parameter>() || sparse->isa<ValueNode>())) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto param_abs = sparse->abstract();
|
||||||
|
MS_EXCEPTION_IF_NULL(param_abs);
|
||||||
|
if (param_abs->isa<abstract::AbstractCSRTensor>()) {
|
||||||
|
auto abs_sparse = param_abs->cast<abstract::AbstractCSRTensorPtr>();
|
||||||
|
std::vector<AbstractBasePtr> abstract_list{abs_sparse->indptr(), abs_sparse->indices(), abs_sparse->values(),
|
||||||
|
abs_sparse->dense_shape()};
|
||||||
|
auto abs_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
||||||
|
abs_tuple->set_type(abs_tuple->BuildType());
|
||||||
|
sparse->set_abstract(abs_tuple);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const AnfNodePtr SparseProcess::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
const AnfNodePtr SparseProcess::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||||
const EquivPtr &) const {
|
const EquivPtr &) const {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
@ -60,14 +78,16 @@ const AnfNodePtr SparseProcess::Process(const FuncGraphPtr &func_graph, const An
|
||||||
} else if (sparse_attr_map.find(prim_name) != sparse_attr_map.end()) {
|
} else if (sparse_attr_map.find(prim_name) != sparse_attr_map.end()) {
|
||||||
const auto &inputs = cnode->inputs();
|
const auto &inputs = cnode->inputs();
|
||||||
// Inputs should be [sparse_getattr, sparse]
|
// Inputs should be [sparse_getattr, sparse]
|
||||||
|
if (inputs.size() <= 1) {
|
||||||
|
MS_LOG_EXCEPTION << "For SparseGetAttr, CNode must have 2 inputs (Prim, Sparse)";
|
||||||
|
}
|
||||||
constexpr size_t sparse_index = 1;
|
constexpr size_t sparse_index = 1;
|
||||||
AnfNodePtr sparse = inputs[sparse_index];
|
AbstractCSRToAbstractTuple(inputs[sparse_index]);
|
||||||
MS_EXCEPTION_IF_NULL(sparse);
|
|
||||||
int64_t index = sparse_attr_map.at(prim_name);
|
int64_t index = sparse_attr_map.at(prim_name);
|
||||||
auto cons_node = NewValueNode(index);
|
auto cons_node = NewValueNode(index);
|
||||||
AbstractBasePtr aptr = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(index));
|
AbstractBasePtr aptr = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(index));
|
||||||
cons_node->set_abstract(aptr);
|
cons_node->set_abstract(aptr);
|
||||||
auto new_node = NewCNode({NewValueNode(prim::kPrimTupleGetItem), sparse, cons_node}, func_graph);
|
auto new_node = NewCNode({NewValueNode(prim::kPrimTupleGetItem), inputs[sparse_index], cons_node}, func_graph);
|
||||||
new_node->set_abstract(node->abstract());
|
new_node->set_abstract(node->abstract());
|
||||||
return new_node;
|
return new_node;
|
||||||
}
|
}
|
||||||
|
|
|
@ -170,7 +170,7 @@ void CPUSession::CreateOutputTensors(const GraphId &graph_id, const std::vector<
|
||||||
void CPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
void CPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
||||||
const std::vector<tensor::TensorPtr> &inputs_const) const {
|
const std::vector<tensor::TensorPtr> &inputs_const) const {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||||
auto &input_nodes = kernel_graph->inputs();
|
auto &input_nodes = kernel_graph->input_nodes();
|
||||||
if (input_nodes.size() != inputs_const.size()) {
|
if (input_nodes.size() != inputs_const.size()) {
|
||||||
MS_LOG(EXCEPTION) << "Input size " << inputs_const.size() << " is not equal to input node size "
|
MS_LOG(EXCEPTION) << "Input size " << inputs_const.size() << " is not equal to input node size "
|
||||||
<< input_nodes.size();
|
<< input_nodes.size();
|
||||||
|
|
|
@ -130,7 +130,7 @@ void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph
|
||||||
|
|
||||||
void CPUKernelRuntime::AssignInputNodeAddress(const session::KernelGraph *kernel_graph) {
|
void CPUKernelRuntime::AssignInputNodeAddress(const session::KernelGraph *kernel_graph) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||||
for (auto &item : kernel_graph->inputs()) {
|
for (auto &item : kernel_graph->input_nodes()) {
|
||||||
MS_EXCEPTION_IF_NULL(item);
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
if (item->isa<Parameter>()) {
|
if (item->isa<Parameter>()) {
|
||||||
auto output_num = AnfAlgo::GetOutputTensorNum(item);
|
auto output_num = AnfAlgo::GetOutputTensorNum(item);
|
||||||
|
@ -281,7 +281,7 @@ void CPUKernelRuntime::CreateOutputTensors(session::KernelGraph *kernel_graph,
|
||||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||||
MS_EXCEPTION_IF_NULL(outputs);
|
MS_EXCEPTION_IF_NULL(outputs);
|
||||||
MS_EXCEPTION_IF_NULL(tensor_to_node);
|
MS_EXCEPTION_IF_NULL(tensor_to_node);
|
||||||
auto &input_nodes = kernel_graph->inputs();
|
auto &input_nodes = kernel_graph->input_nodes();
|
||||||
if (input_nodes.size() != inputs.size()) {
|
if (input_nodes.size() != inputs.size()) {
|
||||||
MS_LOG(EXCEPTION) << "Input size " << inputs.size() << " is not equal to input node size " << input_nodes.size();
|
MS_LOG(EXCEPTION) << "Input size " << inputs.size() << " is not equal to input node size " << input_nodes.size();
|
||||||
}
|
}
|
||||||
|
@ -305,7 +305,7 @@ void CPUKernelRuntime::CreateOutputTensors(session::KernelGraph *kernel_graph,
|
||||||
|
|
||||||
void CPUKernelRuntime::BindInputTensorAddressPtr(const session::KernelGraph &kernel_graph,
|
void CPUKernelRuntime::BindInputTensorAddressPtr(const session::KernelGraph &kernel_graph,
|
||||||
const std::vector<tensor::TensorPtr> &inputs) {
|
const std::vector<tensor::TensorPtr> &inputs) {
|
||||||
auto &input_nodes = kernel_graph.inputs();
|
auto &input_nodes = kernel_graph.input_nodes();
|
||||||
if (input_nodes.size() != inputs.size()) {
|
if (input_nodes.size() != inputs.size()) {
|
||||||
MS_LOG(EXCEPTION) << "Input size" << inputs.size() << " is not equal to input node size " << input_nodes.size();
|
MS_LOG(EXCEPTION) << "Input size" << inputs.size() << " is not equal to input node size " << input_nodes.size();
|
||||||
}
|
}
|
||||||
|
|
|
@ -1516,6 +1516,7 @@ class MS_CORE_API AbstractCSRTensor : public AbstractUndetermined {
|
||||||
AbstractTensorPtr values_;
|
AbstractTensorPtr values_;
|
||||||
AbstractTuplePtr dense_shape_;
|
AbstractTuplePtr dense_shape_;
|
||||||
};
|
};
|
||||||
|
using AbstractCSRTensorPtr = std::shared_ptr<AbstractCSRTensor>;
|
||||||
|
|
||||||
class AbstractMonad : public AbstractBase {
|
class AbstractMonad : public AbstractBase {
|
||||||
public:
|
public:
|
||||||
|
|
|
@ -15,9 +15,13 @@
|
||||||
"""smoke tests for CSR operations"""
|
"""smoke tests for CSR operations"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from mindspore import Tensor, CSRTensor, ms_function
|
from mindspore import Tensor, CSRTensor, ms_function
|
||||||
from mindspore.common import dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
|
from mindspore import nn, context
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
|
||||||
def compare_csr(csr1, csr2):
|
def compare_csr(csr1, csr2):
|
||||||
assert isinstance(csr1, CSRTensor)
|
assert isinstance(csr1, CSRTensor)
|
||||||
|
@ -82,3 +86,62 @@ def test_csr_attr():
|
||||||
csr1 = CSRTensor(*csr1_tuple)
|
csr1 = CSRTensor(*csr1_tuple)
|
||||||
csr2 = CSRTensor(*csr2_tuple)
|
csr2 = CSRTensor(*csr2_tuple)
|
||||||
compare_csr(csr1, csr2)
|
compare_csr(csr1, csr2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_csr_tensor_in_while():
|
||||||
|
"""
|
||||||
|
Feature: Test CSRTensor in while loop.
|
||||||
|
Description: Test CSRTensor computation in while loop.
|
||||||
|
Expectation: Success.
|
||||||
|
"""
|
||||||
|
class CSRTensorValuesDouble(nn.Cell):
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
indptr = x.indptr
|
||||||
|
indices = x.indices
|
||||||
|
values = x.values * 2
|
||||||
|
shape = x.shape
|
||||||
|
return CSRTensor(indptr, indices, values, shape)
|
||||||
|
|
||||||
|
class CSRTensorValuesAdd2(nn.Cell):
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
indptr = x.indptr
|
||||||
|
indices = x.indices
|
||||||
|
values = x.values + 2
|
||||||
|
shape = x.shape
|
||||||
|
return CSRTensor(indptr, indices, values, shape)
|
||||||
|
|
||||||
|
class CSRTensorWithControlWhile(nn.Cell):
|
||||||
|
def __init__(self, shape):
|
||||||
|
super().__init__()
|
||||||
|
self.op1 = CSRTensorValuesDouble()
|
||||||
|
self.op2 = CSRTensorValuesAdd2()
|
||||||
|
self.shape = shape
|
||||||
|
|
||||||
|
@ms_function
|
||||||
|
def construct(self, a, b, indptr, indices, values):
|
||||||
|
x = CSRTensor(indptr, indices, values, self.shape)
|
||||||
|
x = self.op2(x)
|
||||||
|
while a > b:
|
||||||
|
x = self.op1(x)
|
||||||
|
b = b + 1
|
||||||
|
return x
|
||||||
|
a = Tensor(3, mstype.int32)
|
||||||
|
b = Tensor(0, mstype.int32)
|
||||||
|
indptr = Tensor([0, 1, 2])
|
||||||
|
indices = Tensor([0, 1])
|
||||||
|
values = Tensor([1, 2], dtype=mstype.float32)
|
||||||
|
shape = (2, 6)
|
||||||
|
net = CSRTensorWithControlWhile(shape)
|
||||||
|
out = net(a, b, indptr, indices, values)
|
||||||
|
assert np.allclose(out.indptr.asnumpy(), indptr.asnumpy(), .0, .0)
|
||||||
|
assert np.allclose(out.indices.asnumpy(), indices.asnumpy(), .0, .0)
|
||||||
|
assert np.allclose((values.asnumpy() + 2) * 8, out.values.asnumpy(), .0, .0)
|
||||||
|
assert shape == out.shape
|
||||||
|
|
Loading…
Reference in New Issue