support pynative with sparse return

This commit is contained in:
yanglf1121 2022-09-14 17:29:10 +08:00
parent 7f54da8b6a
commit 3ec16ca135
3 changed files with 13 additions and 0 deletions

View File

@ -377,6 +377,15 @@ void GradExecutor::MakeNewTopGraph(const string &cell_id, const py::object &cell
void GradExecutor::SetForwardLastNodeInfo(const ValuePtr &v, const std::string &obj_id) const { void GradExecutor::SetForwardLastNodeInfo(const ValuePtr &v, const std::string &obj_id) const {
MS_EXCEPTION_IF_NULL(v); MS_EXCEPTION_IF_NULL(v);
auto output_node = GetObjNode(v, obj_id); auto output_node = GetObjNode(v, obj_id);
if (v->isa<tensor::CSRTensor>()) {
auto csr_tensorptr = v->cast<tensor::CSRTensorPtr>();
auto value_ptr = csr_tensorptr->GetValues();
output_node = GetObjNode(value_ptr, PyNativeAlgo::Common::GetIdByValue(value_ptr));
} else if (v->isa<tensor::COOTensor>()) {
auto coo_tensorptr = v->cast<tensor::COOTensorPtr>();
auto value_ptr = coo_tensorptr->GetValues();
output_node = GetObjNode(value_ptr, PyNativeAlgo::Common::GetIdByValue(value_ptr));
}
MS_EXCEPTION_IF_NULL(output_node); MS_EXCEPTION_IF_NULL(output_node);
if (top_cell()->dynamic_shape()) { if (top_cell()->dynamic_shape()) {
abstract::AbstractBasePtr last_node_abs = nullptr; abstract::AbstractBasePtr last_node_abs = nullptr;

View File

@ -5796,6 +5796,7 @@ class COOTensor(COOTensor_):
validator.check_coo_tensor_input(indices, values, shape) validator.check_coo_tensor_input(indices, values, shape)
validator.check_coo_tensor_shape(indices.shape, values.shape, shape) validator.check_coo_tensor_shape(indices.shape, values.shape, shape)
validator.check_coo_tensor_dtype(indices.dtype) validator.check_coo_tensor_dtype(indices.dtype)
indices = tensor_operator_registry.get('stop_gradient')(indices)
COOTensor_.__init__(self, indices, values, shape) COOTensor_.__init__(self, indices, values, shape)
self.init_finished = True self.init_finished = True
@ -6106,6 +6107,8 @@ class CSRTensor(CSRTensor_):
validator.check_csr_tensor_input(indptr, indices, values, shape) validator.check_csr_tensor_input(indptr, indices, values, shape)
validator.check_csr_tensor_shape(indptr.shape, indices.shape, values.shape, shape) validator.check_csr_tensor_shape(indptr.shape, indices.shape, values.shape, shape)
validator.check_csr_tensor_dtype(indptr.dtype, indices.dtype) validator.check_csr_tensor_dtype(indptr.dtype, indices.dtype)
indptr = tensor_operator_registry.get('stop_gradient')(indptr)
indices = tensor_operator_registry.get('stop_gradient')(indices)
CSRTensor_.__init__(self, indptr, indices, values, shape) CSRTensor_.__init__(self, indptr, indices, values, shape)
self.init_finished = True self.init_finished = True

View File

@ -396,6 +396,7 @@ tensor_operator_registry.register('split', P.Split)
tensor_operator_registry.register('select', P.Select) tensor_operator_registry.register('select', P.Select)
tensor_operator_registry.register('zeros_like', P.ZerosLike) tensor_operator_registry.register('zeros_like', P.ZerosLike)
tensor_operator_registry.register('scalar_to_tensor', scalar_to_tensor) tensor_operator_registry.register('scalar_to_tensor', scalar_to_tensor)
tensor_operator_registry.register('stop_gradient', stop_gradient)
tensor_operator_registry.register('masked_fill', masked_fill) tensor_operator_registry.register('masked_fill', masked_fill)
tensor_operator_registry.register('masked_select', masked_select) tensor_operator_registry.register('masked_select', masked_select)
tensor_operator_registry.register('nonzero', nonzero) tensor_operator_registry.register('nonzero', nonzero)