Support PyNative CSR ops

This commit is contained in:
yanglf1121 2021-12-10 14:55:16 +08:00
parent 3b3a6da5da
commit feefdae8e3
9 changed files with 65 additions and 52 deletions

View File

@@ -52,11 +52,11 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern
 #endif
   auto optimizer = std::make_shared<GraphOptimizer>();
   auto common_pm = std::make_shared<PassManager>("common_pm");
-  common_pm->AddPass(std::make_shared<SparseProcess>());
   common_pm->AddPass(std::make_shared<AddDynamicShapeAttr>());
   common_pm->AddPass(std::make_shared<ReduceSumOptimizer>());
   common_pm->AddPass(std::make_shared<ConvertConstInputToAttr>());
   common_pm->AddPass(std::make_shared<CustomOpConstInputToAttr>());
+  common_pm->AddPass(std::make_shared<SparseProcess>());
   common_pm->AddPass(std::make_shared<ConvertAttrToUnifyMindIR>());
   common_pm->AddPass(std::make_shared<ConstToAttrStridedSliceGradPass>());
   common_pm->AddPass(std::make_shared<ConvertConstInputToTensorInput>());

View File

@@ -54,9 +54,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
   Register(prim::kPrimReduceAny->name(), {1});
   Register(prim::kPrimUnsortedSegmentMin->name(), {2});
   Register(prim::kPrimUnsortedSegmentMax->name(), {2});
-  Register(prim::kPrimCSRMul->name(), {3});
-  Register(prim::kPrimCSRReduceSum->name(), {3, 4});
-  Register(prim::kPrimCSRMV->name(), {3});
+  Register(prim::kPrimCSRReduceSum->name(), {1});
   Register(kSparseGatherV2OpName, {2});
   Register(kUnsortedSegmentProdOpName, {2});
   Register(kSimpleMeanGradOpName, {1});
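
With the registration above, the constant axis argument of CSRReduceSum (input index 1) is folded into a primitive attribute during backend optimization, so the Python call site simply passes the axis positionally. A minimal sketch of the user-facing call, assuming the module paths used by the test file below and a GPU backend with CSR kernels:

from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.common.tensor import Tensor, CSRTensor
from mindspore.ops.operations import _csr_ops

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

csr_reducesum = _csr_ops.CSRReduceSum()
csr = CSRTensor(Tensor([0, 1, 2]), Tensor([0, 1]),
                Tensor([2, 1], dtype=mstype.float32), (2, 4))
# The constant axis input is converted to an attribute by the pass above.
out = csr_reducesum(csr, 1)   # dense result: [[2.], [1.]]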

View File

@@ -65,9 +65,6 @@ bool SplitValueNode(const AnfNodePtr &node, std::vector<AnfNodePtr> *new_inputs)
   new_inputs->push_back(NewValueNodeAndSetAbstract(csr_tensor->GetIndptr(), csr_abs->indptr()));
   new_inputs->push_back(NewValueNodeAndSetAbstract(csr_tensor->GetIndices(), csr_abs->indices()));
   new_inputs->push_back(NewValueNodeAndSetAbstract(csr_tensor->GetValues(), csr_abs->values()));
-  auto shape_node = NewValueNode(csr_tensor->shape());
-  shape_node->set_abstract(csr_abs->dense_shape());
-  new_inputs->push_back(shape_node);
   return true;
 }
@@ -80,7 +77,9 @@ bool SplitCNode(const AnfNodePtr &node, std::vector<AnfNodePtr> *new_inputs) {
     return false;
   auto sparse_inputs = cnode->inputs();
-  for (size_t j = 1; j < sparse_inputs.size(); ++j) {
+  // Skip the last input: it always represents the shape, which has already
+  // been registered as a primitive attribute.
+  for (size_t j = 1; j < sparse_inputs.size() - 1; ++j) {
     new_inputs->push_back(sparse_inputs[j]);
   }
   return true;
@@ -107,6 +106,7 @@ const AnfNodePtr SparseProcess::Process(const FuncGraphPtr &func_graph, const An
   (void)inputs.insert(inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
   auto new_node = cnode->func_graph()->NewCNode(inputs);
   auto abs_sparse = dyn_cast<abstract::AbstractCSRTensor>(node->abstract());
+  MS_EXCEPTION_IF_NULL(abs_sparse);
   std::vector<AbstractBasePtr> abstract_list{abs_sparse->indptr(), abs_sparse->indices(), abs_sparse->values(),
                                              abs_sparse->dense_shape()};
   auto abs_res = std::make_shared<abstract::AbstractTuple>(abstract_list);
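
Taken together, the hunks above make SparseProcess drop the trailing shape input when flattening a CSRTensor, since the const-input-to-attr passes now run first and have already turned the shape into a primitive attribute. A conceptual Python sketch of the flattening (hypothetical helper, not the actual C++ pass):

def flatten_csr_inputs(csr_tensor):
    """Hypothetical mirror of SplitValueNode/SplitCNode."""
    # indptr, indices, and values become explicit inputs; the shape is
    # omitted because it already lives on the primitive as an attribute.
    return (csr_tensor.indptr, csr_tensor.indices, csr_tensor.values)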

View File

@@ -483,6 +483,25 @@ void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tenso
   input_tensors->emplace_back(tensor_ptr);
 }
+
+void ConvertCSRTensorToTensorList(const py::object &input_object, const PrimitivePtr &op_prim,
+                                  std::vector<tensor::TensorPtr> *input_tensors) {
+  MS_EXCEPTION_IF_NULL(op_prim);
+  MS_EXCEPTION_IF_NULL(input_tensors);
+  if (!py::isinstance<tensor::CSRTensor>(input_object)) {
+    MS_LOG(EXCEPTION) << "The input should be a CSRTensor!";
+  }
+  auto input_names = op_prim->GetAttr(kAttrInputNames);
+  if (input_names == nullptr) {
+    MS_LOG(DEBUG) << "input_names is nullptr";
+    return;
+  }
+  auto csr_inputs = py::cast<tensor::CSRTensor>(input_object);
+  input_tensors->emplace_back(csr_inputs.GetIndptr());
+  input_tensors->emplace_back(csr_inputs.GetIndices());
+  input_tensors->emplace_back(csr_inputs.GetValues());
+  op_prim->set_attr("is_csr", MakeValue(true));
+}

 void ConvertMultiPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
                                   std::vector<tensor::TensorPtr> *input_tensors, int64_t *const tensor_mask) {
   MS_EXCEPTION_IF_NULL(op_prim);
@@ -535,6 +554,9 @@ void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr
   } else if (py::isinstance<py::tuple>(input_object)) {
     ConvertMultiPyObjectToTensor(input_object, op_prim, input_tensors, tensor_mask);
     return;
+  } else if (py::isinstance<tensor::CSRTensor>(input_object)) {
+    ConvertCSRTensorToTensorList(input_object, op_prim, input_tensors);
+    return;
   } else if (py::isinstance<py::none>(input_object)) {
     return;
   } else {
@@ -884,7 +906,8 @@ py::object GetDstType(const TypeId &type_id) {
 }

 bool IsPyObjTypeInvalid(const py::object &obj) {
-  return !py::isinstance<tensor::Tensor>(obj) && !py::isinstance<py::int_>(obj) && !py::isinstance<py::float_>(obj);
+  return !py::isinstance<tensor::Tensor>(obj) && !py::isinstance<tensor::CSRTensor>(obj) &&
+         !py::isinstance<py::int_>(obj) && !py::isinstance<py::float_>(obj);
 }

 inline bool IsNopPrim(const std::string &op_name) {
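
These three hunks add the PyNative half of the feature: when an operator receives a CSRTensor, ConvertCSRTensorToTensorList expands it into its component tensors and tags the primitive, and CSRTensor is accepted as a valid Python argument type. A rough Python mirror of the conversion (hypothetical names, for illustration only):

def convert_csr_to_tensor_list(csr_tensor, prim_attrs, input_tensors):
    """Hypothetical Python equivalent of ConvertCSRTensorToTensorList."""
    # Expand the CSRTensor into its three component tensors ...
    input_tensors += [csr_tensor.indptr, csr_tensor.indices, csr_tensor.values]
    # ... and tag the primitive, mirroring op_prim->set_attr("is_csr", ...).
    prim_attrs["is_csr"] = True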

View File

@@ -90,8 +90,7 @@ const mindspore::HashMap<std::string, int64_t> sparse_attr_map = {{prim::kPrimCS
 // make_sparse_set records all make_sparse primitives, and tries to replace
 // make_sparse with make_tuple; used in the backend common optimization pass:
 // sparse_process.cc
-const mindspore::HashSet<std::string> make_sparse_set = {{prim::kPrimMakeCSRTensor->name()},
-                                                         {prim::kPrimMakeSparseTensor->name()}};
+const mindspore::HashSet<std::string> make_sparse_set = {{prim::kPrimMakeCSRTensor->name()}};
 // sparse_op_set records all sparse_compute operators, which take a sparse tensor
 // and (possibly) dense tensors; used in the backend common optimization pass:
 // sparse_process.cc

View File

@@ -2321,6 +2321,10 @@ class CSRTensor(CSRTensor_):
             return CSRTensor_.__repr__(self)
         return ''

+    def __mul__(self, other):
+        res = tensor_operator_registry.get('csr_mul')(self, other)
+        return CSRTensor(self.indptr, self.indices, res, self.shape)
+
     @property
     def indptr(self):
         return Tensor(self._indptr)
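
With __mul__ in place, elementwise multiplication of a CSRTensor by a dense Tensor works directly in PyNative mode and returns a new CSRTensor with the same sparsity pattern (same indptr and indices, new values). A short sketch using the operands from the test below:

from mindspore.common import dtype as mstype
from mindspore.common.tensor import Tensor, CSRTensor

csr = CSRTensor(Tensor([0, 1, 2]), Tensor([0, 1]),
                Tensor([2, 1], dtype=mstype.float32), (2, 4))
dense = Tensor([[1., 1, 1, 1], [1, 1, 1, 1]], dtype=mstype.float32)

out = csr * dense   # dispatches to tensor_operator_registry.get('csr_mul')
# Multiplying by all-ones leaves the stored values unchanged:
# out.values -> [2., 1.]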

View File

@@ -21,7 +21,7 @@ from ...composite import base
 from ...operations._inner_ops import TensorCopySlices, SliceGetItem
 from ....common import dtype as mstype
 from ....common._register_for_tensor import tensor_operator_registry
-from ....common.tensor import Tensor
+from ....common.tensor import Tensor, CSRTensor

 slice_get_item = SliceGetItem()
 hyper_map = base.HyperMap()
@@ -96,6 +96,8 @@ def _tensor_sub(self, other):

 def _tensor_mul(self, other):
     if isinstance(other, (tuple, list)):
         other = sequence_to_tensor(other, F.dtype(self))
+    elif isinstance(other, CSRTensor):
+        return other * self
     return F.mul(self, other)
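
Because Tensor.__mul__ funnels through _tensor_mul, the new branch also makes dense * csr work by delegating to CSRTensor.__mul__; multiplication is elementwise, so swapping the operands gives the same result. Continuing the sketch above:

sparse1 = csr * dense    # CSRTensor.__mul__
sparse2 = dense * csr    # Tensor.__mul__ -> _tensor_mul -> csr * dense
# Both return CSRTensors with identical values (see the test below).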

View File

@@ -478,6 +478,7 @@ tensor_operator_registry.register('gather_nd', gather_nd)
 tensor_operator_registry.register('stack', P.Stack)
 tensor_operator_registry.register('log', log)
 tensor_operator_registry.register('floor', floor)
 # support sparse tensor operators
+tensor_operator_registry.register('csr_mul', csr_mul)
 __all__ = [name for name in dir() if name[0] != "_"]
 __all__.remove('Primitive')
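
This registration is what CSRTensor.__mul__ (in the tensor.py hunk above) resolves at call time; the lookup is a plain name-based dispatch. Continuing the earlier sketch:

csr_mul_fn = tensor_operator_registry.get('csr_mul')   # registered above
res_values = csr_mul_fn(csr, dense)   # values tensor for the result CSRTensor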

View File

@@ -211,51 +211,37 @@ def test_csr_ops():
     Description: Test CSRReduceSum, CSRMul, CSRMV.
     Expectation: Success.
     """
-    class CSRReduceSumNet(nn.Cell):
-        def __init__(self):
-            super(CSRReduceSumNet, self).__init__()
-            self.op = _csr_ops.CSRReduceSum()
-
-        def construct(self, indptr, indices, values, dense_shape, axis):
-            csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
-            return self.op(csr_tensor, axis)
-
-    class CSRMulNet(nn.Cell):
-        def __init__(self):
-            super(CSRMulNet, self).__init__()
-            self.op = _csr_ops.CSRMul()
-
-        def construct(self, indptr, indices, values, dense_shape, dense):
-            csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
-            return self.op(csr_tensor, dense)
-
-    class CSRMVNet(nn.Cell):
-        def __init__(self):
-            super(CSRMVNet, self).__init__()
-            self.op = _csr_ops.CSRMV()
-
-        def construct(self, indptr, indices, values, dense_shape, dense):
-            csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
-            return self.op(csr_tensor, dense)
+    csr_reducesum = _csr_ops.CSRReduceSum()
+    csrmv = _csr_ops.CSRMV()

     indptr = Tensor([0, 1, 2])
     indices = Tensor([0, 1])
     values = Tensor([2, 1], dtype=mstype.float32)
     dense_shape = (2, 4)
     dense_tensor = Tensor([[1., 1, 1, 1], [1, 1, 1, 1]], dtype=mstype.float32)
     dense_vector = Tensor([[1.], [1], [1], [1]], dtype=mstype.float32)

-    net1 = CSRReduceSumNet()
-    out1 = net1(indptr, indices, values, dense_shape, 1)
+    def test_ops_pynative(indptr, indices, values, dense_shape):
+        csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
+        dense1 = csr_reducesum(csr_tensor, 1)
+        dense2 = csrmv(csr_tensor, dense_vector)
+        sparse1 = csr_tensor * dense_tensor
+        sparse2 = dense_tensor * csr_tensor
+        return dense1, dense2, sparse1, sparse2
+
+    test_ops_graph = ms_function(test_ops_pynative)
+
+    pynative_res = test_ops_pynative(indptr, indices, values, dense_shape)
+    graph_res = test_ops_graph(indptr, indices, values, dense_shape)
     expect1 = np.array([[2.], [1.]], dtype=np.float32)
-    assert np.allclose(out1.asnumpy(), expect1)
-
-    net2 = CSRMulNet()
-    out2 = net2(indptr, indices, values, dense_shape, dense_tensor)
-    expect2 = np.array([2., 1.], dtype=np.float32)
-    assert np.allclose(out2.asnumpy(), expect2)
-
-    net3 = CSRMVNet()
-    out3 = net3(indptr, indices, values, dense_shape, dense_vector)
-    expect3 = np.array([[2.], [1.]], dtype=np.float32)
-    assert np.allclose(out3.asnumpy(), expect3)
+    expect2 = np.array([[2.], [1.]], dtype=np.float32)
+    expect3 = np.array([2., 1.], dtype=np.float32)
+    assert np.allclose(pynative_res[0].asnumpy(), expect1)
+    assert np.allclose(pynative_res[1].asnumpy(), expect2)
+    assert np.allclose(pynative_res[2].values.asnumpy(), expect3)
+    assert np.allclose(pynative_res[3].values.asnumpy(), expect3)
+    assert np.allclose(graph_res[0].asnumpy(), expect1)
+    assert np.allclose(graph_res[1].asnumpy(), expect2)
+    assert np.allclose(graph_res[2].values.asnumpy(), expect3)
+    assert np.allclose(graph_res[3].values.asnumpy(), expect3)
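
For reference, the CSR operands in this test encode the dense matrix [[2, 0, 0, 0], [0, 1, 0, 0]]: row sums give expect1 = [[2.], [1.]], the product with an all-ones vector equals the row sums (expect2), and elementwise multiplication by all-ones leaves the stored values [2., 1.] (expect3) unchanged. A quick NumPy cross-check of the same arithmetic:

import numpy as np

dense = np.array([[2., 0., 0., 0.], [0., 1., 0., 0.]], dtype=np.float32)
assert np.allclose(dense.sum(axis=1, keepdims=True), [[2.], [1.]])     # expect1
assert np.allclose(dense @ np.ones((4, 1), np.float32), [[2.], [1.]])  # expect2
assert np.allclose(dense[dense != 0], [2., 1.])                        # expect3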