optimize updateoutput in gpu

This commit is contained in:
chujinjin 2020-07-31 11:42:14 +08:00
parent 6eddd65cf1
commit 1cb8d9daf3
3 changed files with 20 additions and 4 deletions

View File

@ -426,7 +426,12 @@ py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &gr
if (op_run_info.value != nullptr) {
std::vector<tensor::TensorPtr> pre_output_tensors;
TensorValueToTensor(op_run_info.value, &pre_output_tensors);
std::copy(pre_output_tensors.begin(), pre_output_tensors.end(), std::back_inserter(outputs));
for (auto &pre_output : pre_output_tensors) {
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(pre_output->data_type(), pre_output->shape());
tensor->set_device_address(pre_output->device_address());
tensor->set_dirty(false);
outputs.emplace_back(tensor);
}
} else {
UpdateOutputs(graph, &outputs, input_tensors);
}

View File

@ -300,7 +300,18 @@ py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph
}
// Fetch outputs
VectorRef outputs;
UpdateOutputs(kernel_graph, &outputs, input_tensors);
if (op_run_info.value != nullptr) {
std::vector<tensor::TensorPtr> pre_output_tensors;
TensorValueToTensor(op_run_info.value, &pre_output_tensors);
for (auto &pre_output : pre_output_tensors) {
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(pre_output->data_type(), pre_output->shape());
tensor->set_device_address(pre_output->device_address());
tensor->set_dirty(false);
outputs.emplace_back(tensor);
}
} else {
UpdateOutputs(kernel_graph, &outputs, input_tensors);
}
// Trans output to tuple
auto output_tensors = TransformBaseRefListToTuple(outputs);
if (!utils::isa<PyObjectRef>(output_tensors) ||

View File

@ -565,9 +565,9 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
if (session == nullptr) {
session = session::SessionFactory::Get().Create(device_target);
MS_EXCEPTION_IF_NULL(session);
session->Init(ms_context->device_id());
}
MS_EXCEPTION_IF_NULL(session);
session->Init(ms_context->device_id());
std::vector<tensor::TensorPtr> input_tensors;
std::vector<int> tensors_mask;