forked from mindspore-Ecosystem/mindspore
fix-bug-of-null-output-addr-of-tensor-in-valuenode
This commit is contained in:
parent
3da8cc98c5
commit
28e3121fbc
|
@ -157,8 +157,6 @@ def tuple_to_array(x):
|
|||
|
||||
def stop_gradient(x):
|
||||
"""Implement `stop_gradient`."""
|
||||
if isinstance(x, Tensor):
|
||||
return Tensor(x.asnumpy())
|
||||
return x
|
||||
|
||||
|
||||
|
|
|
@ -314,6 +314,18 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
|
|||
py::tuple err_ret(0);
|
||||
return std::move(err_ret);
|
||||
}
|
||||
if (op_exec_info->op_name == "stop_gradient" && py::isinstance<tensor::Tensor>(result)) {
|
||||
py::tuple tuple_result(1);
|
||||
auto tensor = py::cast<tensor::TensorPtr>(result);
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr());
|
||||
new_tensor->set_device_address(tensor->device_address());
|
||||
new_tensor->set_sync_status(tensor->sync_status());
|
||||
tuple_result[0] = new_tensor;
|
||||
*status = PYNATIVE_SUCCESS;
|
||||
MS_LOG(INFO) << "RunOpInVM end";
|
||||
return std::move(tuple_result);
|
||||
}
|
||||
|
||||
// execute op
|
||||
py::tuple tuple_result = py::make_tuple(result);
|
||||
|
|
Loading…
Reference in New Issue