forked from mindspore-Ecosystem/mindspore

support pynative csr op

parent 3b3a6da5da
commit feefdae8e3
@@ -52,11 +52,11 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern
#endif
  auto optimizer = std::make_shared<GraphOptimizer>();
  auto common_pm = std::make_shared<PassManager>("common_pm");
  common_pm->AddPass(std::make_shared<SparseProcess>());
  common_pm->AddPass(std::make_shared<AddDynamicShapeAttr>());
  common_pm->AddPass(std::make_shared<ReduceSumOptimizer>());
  common_pm->AddPass(std::make_shared<ConvertConstInputToAttr>());
  common_pm->AddPass(std::make_shared<CustomOpConstInputToAttr>());
  common_pm->AddPass(std::make_shared<SparseProcess>());
  common_pm->AddPass(std::make_shared<ConvertAttrToUnifyMindIR>());
  common_pm->AddPass(std::make_shared<ConstToAttrStridedSliceGradPass>());
  common_pm->AddPass(std::make_shared<ConvertConstInputToTensorInput>());
@@ -54,9 +54,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
  Register(prim::kPrimReduceAny->name(), {1});
  Register(prim::kPrimUnsortedSegmentMin->name(), {2});
  Register(prim::kPrimUnsortedSegmentMax->name(), {2});
  Register(prim::kPrimCSRMul->name(), {3});
  Register(prim::kPrimCSRReduceSum->name(), {3, 4});
  Register(prim::kPrimCSRMV->name(), {3});
  Register(prim::kPrimCSRReduceSum->name(), {1});
  Register(kSparseGatherV2OpName, {2});
  Register(kUnsortedSegmentProdOpName, {2});
  Register(kSimpleMeanGradOpName, {1});
@@ -65,9 +65,6 @@ bool SplitValueNode(const AnfNodePtr &node, std::vector<AnfNodePtr> *new_inputs)
  new_inputs->push_back(NewValueNodeAndSetAbstract(csr_tensor->GetIndptr(), csr_abs->indptr()));
  new_inputs->push_back(NewValueNodeAndSetAbstract(csr_tensor->GetIndices(), csr_abs->indices()));
  new_inputs->push_back(NewValueNodeAndSetAbstract(csr_tensor->GetValues(), csr_abs->values()));
  auto shape_node = NewValueNode(csr_tensor->shape());
  shape_node->set_abstract(csr_abs->dense_shape());
  new_inputs->push_back(shape_node);
  return true;
}
@@ -80,7 +77,9 @@ bool SplitCNode(const AnfNodePtr &node, std::vector<AnfNodePtr> *new_inputs) {
    return false;

  auto sparse_inputs = cnode->inputs();
  for (size_t j = 1; j < sparse_inputs.size(); ++j) {
  // skip the last input, as it always represents shape, and has already been
  // registered as primitive attribute.
  for (size_t j = 1; j < sparse_inputs.size() - 1; ++j) {
    new_inputs->push_back(sparse_inputs[j]);
  }
  return true;
@@ -107,6 +106,7 @@ const AnfNodePtr SparseProcess::Process(const FuncGraphPtr &func_graph, const An
    (void)inputs.insert(inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
    auto new_node = cnode->func_graph()->NewCNode(inputs);
    auto abs_sparse = dyn_cast<abstract::AbstractCSRTensor>(node->abstract());
    MS_EXCEPTION_IF_NULL(abs_sparse);
    std::vector<AbstractBasePtr> abstract_list{abs_sparse->indptr(), abs_sparse->indices(), abs_sparse->values(),
                                               abs_sparse->dense_shape()};
    auto abs_res = std::make_shared<abstract::AbstractTuple>(abstract_list);
@@ -483,6 +483,25 @@ void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tenso
  input_tensors->emplace_back(tensor_ptr);
}

void ConvertCSRTensorToTensorList(const py::object &input_object, const PrimitivePtr &op_prim,
                                  std::vector<tensor::TensorPtr> *input_tensors) {
  MS_EXCEPTION_IF_NULL(op_prim);
  MS_EXCEPTION_IF_NULL(input_tensors);
  if (!py::isinstance<tensor::CSRTensor>(input_object)) {
    MS_LOG(EXCEPTION) << "The input should be a csr_tensor! ";
  }
  auto input_names = op_prim->GetAttr(kAttrInputNames);
  if (input_names == nullptr) {
    MS_LOG(DEBUG) << "input_names are nullptr";
    return;
  }
  auto csr_inputs = py::cast<tensor::CSRTensor>(input_object);
  input_tensors->emplace_back(csr_inputs.GetIndptr());
  input_tensors->emplace_back(csr_inputs.GetIndices());
  input_tensors->emplace_back(csr_inputs.GetValues());
  op_prim->set_attr("is_csr", MakeValue(true));
}

void ConvertMultiPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
                                  std::vector<tensor::TensorPtr> *input_tensors, int64_t *const tensor_mask) {
  MS_EXCEPTION_IF_NULL(op_prim);
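The converter above is what lets a CSRTensor be passed directly to a CSR operator in PyNative mode: the Python-side object is unpacked into its indptr, indices, and values tensors before kernel launch, and the CSR-ness is recorded on the primitive. Below is a minimal PyNative sketch using the same data as the test at the end of this commit; the import paths and the availability of a backend that provides the CSR kernels are assumptions.

import numpy as np
from mindspore import Tensor, CSRTensor, context
from mindspore import dtype as mstype
from mindspore.ops.operations import _csr_ops  # assumed path, matching the test file

context.set_context(mode=context.PYNATIVE_MODE)

# 2 x 4 CSR matrix with non-zero entries 2.0 at (0, 0) and 1.0 at (1, 1).
indptr = Tensor([0, 1, 2])
indices = Tensor([0, 1])
values = Tensor([2, 1], dtype=mstype.float32)
csr_tensor = CSRTensor(indptr, indices, values, (2, 4))

# In PyNative mode the CSRTensor argument is expanded by
# ConvertCSRTensorToTensorList into three dense tensor inputs
# (indptr, indices, values) before the kernel is launched.
csr_reducesum = _csr_ops.CSRReduceSum()
out = csr_reducesum(csr_tensor, 1)
print(out.asnumpy())  # expected [[2.], [1.]] according to the test below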
@@ -535,6 +554,9 @@ void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr
  } else if (py::isinstance<py::tuple>(input_object)) {
    ConvertMultiPyObjectToTensor(input_object, op_prim, input_tensors, tensor_mask);
    return;
  } else if (py::isinstance<tensor::CSRTensor>(input_object)) {
    ConvertCSRTensorToTensorList(input_object, op_prim, input_tensors);
    return;
  } else if (py::isinstance<py::none>(input_object)) {
    return;
  } else {
@@ -884,7 +906,8 @@ py::object GetDstType(const TypeId &type_id) {
}

bool IsPyObjTypeInvalid(const py::object &obj) {
  return !py::isinstance<tensor::Tensor>(obj) && !py::isinstance<py::int_>(obj) && !py::isinstance<py::float_>(obj);
  return !py::isinstance<tensor::Tensor>(obj) && !py::isinstance<tensor::CSRTensor>(obj) &&
         !py::isinstance<py::int_>(obj) && !py::isinstance<py::float_>(obj);
}

inline bool IsNopPrim(const std::string &op_name) {
@@ -90,8 +90,7 @@ const mindspore::HashMap<std::string, int64_t> sparse_attr_map = {{prim::kPrimCS
// make_sparse_set records all make_sparse primitives, and tries to replace
// make_sparse with make_tuple, used in backend common optimization pass:
// sparse_process.cc
const mindspore::HashSet<std::string> make_sparse_set = {{prim::kPrimMakeCSRTensor->name()},
                                                         {prim::kPrimMakeSparseTensor->name()}};
const mindspore::HashSet<std::string> make_sparse_set = {{prim::kPrimMakeCSRTensor->name()}};
// sparse_op_set records all sparse_compute operators, which take sparse tensors
// and (possibly) dense tensors, used in backend common optimization pass:
// sparse_process.cc
@@ -2321,6 +2321,10 @@ class CSRTensor(CSRTensor_):
            return CSRTensor_.__repr__(self)
        return ''

    def __mul__(self, other):
        res = tensor_operator_registry.get('csr_mul')(self, other)
        return CSRTensor(self.indptr, self.indices, res, self.shape)

    @property
    def indptr(self):
        return Tensor(self._indptr)
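For reference, a small usage sketch of the new __mul__ overload, built on the same data as the test at the end of this commit; the import paths are assumptions, and a backend that provides the csr_mul kernel is required.

from mindspore import Tensor, CSRTensor
from mindspore import dtype as mstype

indptr = Tensor([0, 1, 2])
indices = Tensor([0, 1])
values = Tensor([2, 1], dtype=mstype.float32)
csr = CSRTensor(indptr, indices, values, (2, 4))
dense = Tensor([[1., 1, 1, 1], [1, 1, 1, 1]], dtype=mstype.float32)

# __mul__ looks up the registered 'csr_mul' operator, multiplies only the
# stored non-zero values by the matching dense entries, and rebuilds a
# CSRTensor with the original indptr/indices and shape.
out = csr * dense
print(out.values.asnumpy())  # expected [2. 1.]; the sparsity pattern is unchanged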
@@ -21,7 +21,7 @@ from ...composite import base
from ...operations._inner_ops import TensorCopySlices, SliceGetItem
from ....common import dtype as mstype
from ....common._register_for_tensor import tensor_operator_registry
from ....common.tensor import Tensor
from ....common.tensor import Tensor, CSRTensor

slice_get_item = SliceGetItem()
hyper_map = base.HyperMap()
@@ -96,6 +96,8 @@ def _tensor_sub(self, other):
def _tensor_mul(self, other):
    if isinstance(other, (tuple, list)):
        other = sequence_to_tensor(other, F.dtype(self))
    elif isinstance(other, CSRTensor):
        return other * self
    return F.mul(self, other)
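The new elif branch makes the multiplication insensitive to operand order: when the tensor-multiply helper patched above sees a CSRTensor on the right-hand side, it delegates back to CSRTensor.__mul__ via `other * self`. A short check mirroring the sparse1/sparse2 lines of the test below, with the same assumed imports and data as the earlier sketches:

from mindspore import Tensor, CSRTensor
from mindspore import dtype as mstype

indptr = Tensor([0, 1, 2])
indices = Tensor([0, 1])
values = Tensor([2, 1], dtype=mstype.float32)
csr = CSRTensor(indptr, indices, values, (2, 4))
dense = Tensor([[1., 1, 1, 1], [1, 1, 1, 1]], dtype=mstype.float32)

# dense * csr takes the new elif branch and returns other * self,
# so both orderings yield the same CSRTensor result.
left = (csr * dense).values.asnumpy()
right = (dense * csr).values.asnumpy()
assert (left == right).all()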
@@ -478,6 +478,7 @@ tensor_operator_registry.register('gather_nd', gather_nd)
tensor_operator_registry.register('stack', P.Stack)
tensor_operator_registry.register('log', log)
tensor_operator_registry.register('floor', floor)

# support sparse tensor operators
tensor_operator_registry.register('csr_mul', csr_mul)
__all__ = [name for name in dir() if name[0] != "_"]
__all__.remove('Primitive')
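This registration is what CSRTensor.__mul__ (earlier in this commit) looks up at call time; the registry indirection lets the tensor classes in mindspore/common call operators defined in mindspore/ops without importing them directly. A toy sketch of the pattern, using a hypothetical stand-in class rather than the actual MindSpore registry:

class _OpRegistry:
    """Minimal stand-in for tensor_operator_registry: maps a name to a callable."""
    def __init__(self):
        self._ops = {}

    def register(self, name, op):
        self._ops[name] = op

    def get(self, name):
        return self._ops[name]

registry = _OpRegistry()
# The ops layer registers the kernel-backed function under a string key...
registry.register('csr_mul', lambda csr, dense: "csr_mul kernel would run here")
# ...and the tensor layer retrieves it by name, avoiding a circular import.
print(registry.get('csr_mul')(None, None))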
@@ -211,51 +211,37 @@ def test_csr_ops():
    Description: Test CSRReduceSum, CSRMul, CSRMV.
    Expectation: Success.
    """
    class CSRReduceSumNet(nn.Cell):
        def __init__(self):
            super(CSRReduceSumNet, self).__init__()
            self.op = _csr_ops.CSRReduceSum()

        def construct(self, indptr, indices, values, dense_shape, axis):
            csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
            return self.op(csr_tensor, axis)

    class CSRMulNet(nn.Cell):
        def __init__(self):
            super(CSRMulNet, self).__init__()
            self.op = _csr_ops.CSRMul()

        def construct(self, indptr, indices, values, dense_shape, dense):
            csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
            return self.op(csr_tensor, dense)

    class CSRMVNet(nn.Cell):
        def __init__(self):
            super(CSRMVNet, self).__init__()
            self.op = _csr_ops.CSRMV()

        def construct(self, indptr, indices, values, dense_shape, dense):
            csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
            return self.op(csr_tensor, dense)
    csr_reducesum = _csr_ops.CSRReduceSum()
    csrmv = _csr_ops.CSRMV()

    indptr = Tensor([0, 1, 2])
    indices = Tensor([0, 1])
    values = Tensor([2, 1], dtype=mstype.float32)
    dense_shape = (2, 4)

    dense_tensor = Tensor([[1., 1, 1, 1], [1, 1, 1, 1]], dtype=mstype.float32)
    dense_vector = Tensor([[1.], [1], [1], [1]], dtype=mstype.float32)

    net1 = CSRReduceSumNet()
    out1 = net1(indptr, indices, values, dense_shape, 1)
    def test_ops_pynative(indptr, indices, values, dense_shape):
        csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
        dense1 = csr_reducesum(csr_tensor, 1)
        dense2 = csrmv(csr_tensor, dense_vector)
        sparse1 = csr_tensor * dense_tensor
        sparse2 = dense_tensor * csr_tensor
        return dense1, dense2, sparse1, sparse2

    test_ops_graph = ms_function(test_ops_pynative)

    pynative_res = test_ops_pynative(indptr, indices, values, dense_shape)
    graph_res = test_ops_graph(indptr, indices, values, dense_shape)
    expect1 = np.array([[2.], [1.]], dtype=np.float32)
    assert np.allclose(out1.asnumpy(), expect1)

    net2 = CSRMulNet()
    out2 = net2(indptr, indices, values, dense_shape, dense_tensor)
    expect2 = np.array([2., 1.], dtype=np.float32)
    assert np.allclose(out2.asnumpy(), expect2)

    net3 = CSRMVNet()
    out3 = net3(indptr, indices, values, dense_shape, dense_vector)
    expect3 = np.array([[2.], [1.]], dtype=np.float32)
    assert np.allclose(out3.asnumpy(), expect3)
    expect2 = np.array([[2.], [1.]], dtype=np.float32)
    expect3 = np.array([2., 1.], dtype=np.float32)
    assert np.allclose(pynative_res[0].asnumpy(), expect1)
    assert np.allclose(pynative_res[1].asnumpy(), expect2)
    assert np.allclose(pynative_res[2].values.asnumpy(), expect3)
    assert np.allclose(pynative_res[3].values.asnumpy(), expect3)
    assert np.allclose(graph_res[0].asnumpy(), expect1)
    assert np.allclose(graph_res[1].asnumpy(), expect2)
    assert np.allclose(graph_res[2].values.asnumpy(), expect3)
    assert np.allclose(graph_res[3].values.asnumpy(), expect3)