forked from OSSInnovation/mindspore
optimize updateoutput in gpu
This commit is contained in:
parent
6eddd65cf1
commit
1cb8d9daf3
|
@ -426,7 +426,12 @@ py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &gr
|
|||
if (op_run_info.value != nullptr) {
|
||||
std::vector<tensor::TensorPtr> pre_output_tensors;
|
||||
TensorValueToTensor(op_run_info.value, &pre_output_tensors);
|
||||
std::copy(pre_output_tensors.begin(), pre_output_tensors.end(), std::back_inserter(outputs));
|
||||
for (auto &pre_output : pre_output_tensors) {
|
||||
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(pre_output->data_type(), pre_output->shape());
|
||||
tensor->set_device_address(pre_output->device_address());
|
||||
tensor->set_dirty(false);
|
||||
outputs.emplace_back(tensor);
|
||||
}
|
||||
} else {
|
||||
UpdateOutputs(graph, &outputs, input_tensors);
|
||||
}
|
||||
|
|
|
@ -300,7 +300,18 @@ py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph
|
|||
}
|
||||
// Fetch outputs
|
||||
VectorRef outputs;
|
||||
UpdateOutputs(kernel_graph, &outputs, input_tensors);
|
||||
if (op_run_info.value != nullptr) {
|
||||
std::vector<tensor::TensorPtr> pre_output_tensors;
|
||||
TensorValueToTensor(op_run_info.value, &pre_output_tensors);
|
||||
for (auto &pre_output : pre_output_tensors) {
|
||||
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(pre_output->data_type(), pre_output->shape());
|
||||
tensor->set_device_address(pre_output->device_address());
|
||||
tensor->set_dirty(false);
|
||||
outputs.emplace_back(tensor);
|
||||
}
|
||||
} else {
|
||||
UpdateOutputs(kernel_graph, &outputs, input_tensors);
|
||||
}
|
||||
// Trans output to tuple
|
||||
auto output_tensors = TransformBaseRefListToTuple(outputs);
|
||||
if (!utils::isa<PyObjectRef>(output_tensors) ||
|
||||
|
|
|
@ -565,9 +565,9 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
|
|||
|
||||
if (session == nullptr) {
|
||||
session = session::SessionFactory::Get().Create(device_target);
|
||||
MS_EXCEPTION_IF_NULL(session);
|
||||
session->Init(ms_context->device_id());
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(session);
|
||||
session->Init(ms_context->device_id());
|
||||
|
||||
std::vector<tensor::TensorPtr> input_tensors;
|
||||
std::vector<int> tensors_mask;
|
||||
|
|
Loading…
Reference in New Issue