!41975 support SparseTensor return in PyNative Back-Propagation

Merge pull request !41975 from 杨林枫/support_pynative_sparse_bprop
This commit is contained in:
i-robot 2022-09-20 01:36:01 +00:00 committed by Gitee
commit 936d88acae
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 13 additions and 0 deletions

View File

@ -377,6 +377,15 @@ void GradExecutor::MakeNewTopGraph(const string &cell_id, const py::object &cell
void GradExecutor::SetForwardLastNodeInfo(const ValuePtr &v, const std::string &obj_id) const {
MS_EXCEPTION_IF_NULL(v);
auto output_node = GetObjNode(v, obj_id);
if (v->isa<tensor::CSRTensor>()) {
auto csr_tensorptr = v->cast<tensor::CSRTensorPtr>();
auto value_ptr = csr_tensorptr->GetValues();
output_node = GetObjNode(value_ptr, PyNativeAlgo::Common::GetIdByValue(value_ptr));
} else if (v->isa<tensor::COOTensor>()) {
auto coo_tensorptr = v->cast<tensor::COOTensorPtr>();
auto value_ptr = coo_tensorptr->GetValues();
output_node = GetObjNode(value_ptr, PyNativeAlgo::Common::GetIdByValue(value_ptr));
}
MS_EXCEPTION_IF_NULL(output_node);
if (top_cell()->dynamic_shape()) {
abstract::AbstractBasePtr last_node_abs = nullptr;

View File

@ -5848,6 +5848,7 @@ class COOTensor(COOTensor_):
validator.check_coo_tensor_input(indices, values, shape)
validator.check_coo_tensor_shape(indices.shape, values.shape, shape)
validator.check_coo_tensor_dtype(indices.dtype)
indices = tensor_operator_registry.get('stop_gradient')(indices)
COOTensor_.__init__(self, indices, values, shape)
self.init_finished = True
@ -6158,6 +6159,8 @@ class CSRTensor(CSRTensor_):
validator.check_csr_tensor_input(indptr, indices, values, shape)
validator.check_csr_tensor_shape(indptr.shape, indices.shape, values.shape, shape)
validator.check_csr_tensor_dtype(indptr.dtype, indices.dtype)
indptr = tensor_operator_registry.get('stop_gradient')(indptr)
indices = tensor_operator_registry.get('stop_gradient')(indices)
CSRTensor_.__init__(self, indptr, indices, values, shape)
self.init_finished = True

View File

@ -397,6 +397,7 @@ tensor_operator_registry.register('split', P.Split)
tensor_operator_registry.register('select', P.Select)
tensor_operator_registry.register('zeros_like', P.ZerosLike)
tensor_operator_registry.register('scalar_to_tensor', scalar_to_tensor)
tensor_operator_registry.register('stop_gradient', stop_gradient)
tensor_operator_registry.register('masked_fill', masked_fill)
tensor_operator_registry.register('masked_select', masked_select)
tensor_operator_registry.register('nonzero', nonzero)