forked from mindspore-Ecosystem/mindspore
!41975 support SparseTensor return in PyNative Back-Propagation
Merge pull request !41975 from 杨林枫/support_pynative_sparse_bprop
This commit is contained in:
commit
936d88acae
|
@ -377,6 +377,15 @@ void GradExecutor::MakeNewTopGraph(const string &cell_id, const py::object &cell
|
|||
void GradExecutor::SetForwardLastNodeInfo(const ValuePtr &v, const std::string &obj_id) const {
|
||||
MS_EXCEPTION_IF_NULL(v);
|
||||
auto output_node = GetObjNode(v, obj_id);
|
||||
if (v->isa<tensor::CSRTensor>()) {
|
||||
auto csr_tensorptr = v->cast<tensor::CSRTensorPtr>();
|
||||
auto value_ptr = csr_tensorptr->GetValues();
|
||||
output_node = GetObjNode(value_ptr, PyNativeAlgo::Common::GetIdByValue(value_ptr));
|
||||
} else if (v->isa<tensor::COOTensor>()) {
|
||||
auto coo_tensorptr = v->cast<tensor::COOTensorPtr>();
|
||||
auto value_ptr = coo_tensorptr->GetValues();
|
||||
output_node = GetObjNode(value_ptr, PyNativeAlgo::Common::GetIdByValue(value_ptr));
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(output_node);
|
||||
if (top_cell()->dynamic_shape()) {
|
||||
abstract::AbstractBasePtr last_node_abs = nullptr;
|
||||
|
|
|
@ -5848,6 +5848,7 @@ class COOTensor(COOTensor_):
|
|||
validator.check_coo_tensor_input(indices, values, shape)
|
||||
validator.check_coo_tensor_shape(indices.shape, values.shape, shape)
|
||||
validator.check_coo_tensor_dtype(indices.dtype)
|
||||
indices = tensor_operator_registry.get('stop_gradient')(indices)
|
||||
COOTensor_.__init__(self, indices, values, shape)
|
||||
self.init_finished = True
|
||||
|
||||
|
@ -6158,6 +6159,8 @@ class CSRTensor(CSRTensor_):
|
|||
validator.check_csr_tensor_input(indptr, indices, values, shape)
|
||||
validator.check_csr_tensor_shape(indptr.shape, indices.shape, values.shape, shape)
|
||||
validator.check_csr_tensor_dtype(indptr.dtype, indices.dtype)
|
||||
indptr = tensor_operator_registry.get('stop_gradient')(indptr)
|
||||
indices = tensor_operator_registry.get('stop_gradient')(indices)
|
||||
CSRTensor_.__init__(self, indptr, indices, values, shape)
|
||||
self.init_finished = True
|
||||
|
||||
|
|
|
@ -397,6 +397,7 @@ tensor_operator_registry.register('split', P.Split)
|
|||
tensor_operator_registry.register('select', P.Select)
|
||||
tensor_operator_registry.register('zeros_like', P.ZerosLike)
|
||||
tensor_operator_registry.register('scalar_to_tensor', scalar_to_tensor)
|
||||
tensor_operator_registry.register('stop_gradient', stop_gradient)
|
||||
tensor_operator_registry.register('masked_fill', masked_fill)
|
||||
tensor_operator_registry.register('masked_select', masked_select)
|
||||
tensor_operator_registry.register('nonzero', nonzero)
|
||||
|
|
Loading…
Reference in New Issue