support pynative with sparse return
This commit is contained in:
parent
7f54da8b6a
commit
3ec16ca135
|
@ -377,6 +377,15 @@ void GradExecutor::MakeNewTopGraph(const string &cell_id, const py::object &cell
|
||||||
void GradExecutor::SetForwardLastNodeInfo(const ValuePtr &v, const std::string &obj_id) const {
|
void GradExecutor::SetForwardLastNodeInfo(const ValuePtr &v, const std::string &obj_id) const {
|
||||||
MS_EXCEPTION_IF_NULL(v);
|
MS_EXCEPTION_IF_NULL(v);
|
||||||
auto output_node = GetObjNode(v, obj_id);
|
auto output_node = GetObjNode(v, obj_id);
|
||||||
|
if (v->isa<tensor::CSRTensor>()) {
|
||||||
|
auto csr_tensorptr = v->cast<tensor::CSRTensorPtr>();
|
||||||
|
auto value_ptr = csr_tensorptr->GetValues();
|
||||||
|
output_node = GetObjNode(value_ptr, PyNativeAlgo::Common::GetIdByValue(value_ptr));
|
||||||
|
} else if (v->isa<tensor::COOTensor>()) {
|
||||||
|
auto coo_tensorptr = v->cast<tensor::COOTensorPtr>();
|
||||||
|
auto value_ptr = coo_tensorptr->GetValues();
|
||||||
|
output_node = GetObjNode(value_ptr, PyNativeAlgo::Common::GetIdByValue(value_ptr));
|
||||||
|
}
|
||||||
MS_EXCEPTION_IF_NULL(output_node);
|
MS_EXCEPTION_IF_NULL(output_node);
|
||||||
if (top_cell()->dynamic_shape()) {
|
if (top_cell()->dynamic_shape()) {
|
||||||
abstract::AbstractBasePtr last_node_abs = nullptr;
|
abstract::AbstractBasePtr last_node_abs = nullptr;
|
||||||
|
|
|
@ -5796,6 +5796,7 @@ class COOTensor(COOTensor_):
|
||||||
validator.check_coo_tensor_input(indices, values, shape)
|
validator.check_coo_tensor_input(indices, values, shape)
|
||||||
validator.check_coo_tensor_shape(indices.shape, values.shape, shape)
|
validator.check_coo_tensor_shape(indices.shape, values.shape, shape)
|
||||||
validator.check_coo_tensor_dtype(indices.dtype)
|
validator.check_coo_tensor_dtype(indices.dtype)
|
||||||
|
indices = tensor_operator_registry.get('stop_gradient')(indices)
|
||||||
COOTensor_.__init__(self, indices, values, shape)
|
COOTensor_.__init__(self, indices, values, shape)
|
||||||
self.init_finished = True
|
self.init_finished = True
|
||||||
|
|
||||||
|
@ -6106,6 +6107,8 @@ class CSRTensor(CSRTensor_):
|
||||||
validator.check_csr_tensor_input(indptr, indices, values, shape)
|
validator.check_csr_tensor_input(indptr, indices, values, shape)
|
||||||
validator.check_csr_tensor_shape(indptr.shape, indices.shape, values.shape, shape)
|
validator.check_csr_tensor_shape(indptr.shape, indices.shape, values.shape, shape)
|
||||||
validator.check_csr_tensor_dtype(indptr.dtype, indices.dtype)
|
validator.check_csr_tensor_dtype(indptr.dtype, indices.dtype)
|
||||||
|
indptr = tensor_operator_registry.get('stop_gradient')(indptr)
|
||||||
|
indices = tensor_operator_registry.get('stop_gradient')(indices)
|
||||||
CSRTensor_.__init__(self, indptr, indices, values, shape)
|
CSRTensor_.__init__(self, indptr, indices, values, shape)
|
||||||
self.init_finished = True
|
self.init_finished = True
|
||||||
|
|
||||||
|
|
|
@ -396,6 +396,7 @@ tensor_operator_registry.register('split', P.Split)
|
||||||
tensor_operator_registry.register('select', P.Select)
|
tensor_operator_registry.register('select', P.Select)
|
||||||
tensor_operator_registry.register('zeros_like', P.ZerosLike)
|
tensor_operator_registry.register('zeros_like', P.ZerosLike)
|
||||||
tensor_operator_registry.register('scalar_to_tensor', scalar_to_tensor)
|
tensor_operator_registry.register('scalar_to_tensor', scalar_to_tensor)
|
||||||
|
tensor_operator_registry.register('stop_gradient', stop_gradient)
|
||||||
tensor_operator_registry.register('masked_fill', masked_fill)
|
tensor_operator_registry.register('masked_fill', masked_fill)
|
||||||
tensor_operator_registry.register('masked_select', masked_select)
|
tensor_operator_registry.register('masked_select', masked_select)
|
||||||
tensor_operator_registry.register('nonzero', nonzero)
|
tensor_operator_registry.register('nonzero', nonzero)
|
||||||
|
|
Loading…
Reference in New Issue