Zero-copy for GE output tensors

This commit is contained in:
xiao_yao1994 2023-02-18 17:22:54 +08:00
parent e6ea014082
commit 96b80c7195
1 changed file with 19 additions and 9 deletions

View File

@@ -423,16 +423,26 @@ bool GeGraphExecutor::RunGraph(const FuncGraphPtr &graph, const std::vector<tens
const auto &[output_node, idx] = common::AnfAlgo::FetchRealNodeSkipMonadControl(graph_outputs[i]); const auto &[output_node, idx] = common::AnfAlgo::FetchRealNodeSkipMonadControl(graph_outputs[i]);
const auto &tensor = ge_outputs[i]; const auto &tensor = ge_outputs[i];
auto output_addr = AnfAlgo::GetMutableOutputAddr(output_node, idx); auto output_addr = AnfAlgo::GetMutableOutputAddr(output_node, idx);
output_addr->set_ptr(device_context_->device_res_manager_->AllocateMemory(tensor->GetSize())); ::ge::Placement dp = tensor->GetTensorDesc().GetPlacement();
auto ge_data = tensor->ResetData().release();
MS_EXCEPTION_IF_NULL(ge_data);
if (dp != ::ge::kPlacementDevice) {
constexpr int64_t kTensorAlignBytes = 64;
if (reinterpret_cast<uintptr_t>(ge_data) % kTensorAlignBytes != 0) {
MS_LOG(EXCEPTION) << "Skip zero-copy ge tensor " << reinterpret_cast<uintptr_t>(ge_data)
<< ", bytes not aligned with expected.";
}
if (me_types[i] == TypeId::kObjectTypeString) {
MS_LOG(EXCEPTION) << "It is not supported that Output node " << output_node->DebugString()
<< "'s output data type is string now.";
}
MS_LOG(INFO) << "Zero-copy ge tensor " << reinterpret_cast<uintptr_t>(ge_data) << " as aligned with "
<< kTensorAlignBytes << " types.";
output_addr->set_ptr(ge_data);
output_addr->SetSize(tensor->GetSize()); output_addr->SetSize(tensor->GetSize());
output_addr->set_is_ptr_persisted(false); output_addr->set_is_ptr_persisted(false);
if (output_addr->GetSize() < LongToSize(UlongToLong(tensor->GetSize()))) {
MS_LOG(EXCEPTION) << "Output node " << output_node->DebugString() << "'s mem size " << output_addr->GetSize()
<< " is less than actual output size " << tensor->GetSize();
} }
// memcpy_s does not support data that more than 2GB
(void)memcpy(reinterpret_cast<uint8_t *>(output_addr->GetMutablePtr()), tensor->GetData(), tensor->GetSize());
auto actual_shapes = tensor->GetTensorDesc().GetShape().GetDims(); auto actual_shapes = tensor->GetTensorDesc().GetShape().GetDims();
UpdateOutputNodeShape(output_node, idx, me_types[i], actual_shapes); UpdateOutputNodeShape(output_node, idx, me_types[i], actual_shapes);
} }