forked from mindspore-Ecosystem/mindspore
zero_copy for ge output tensor
This commit is contained in:
parent
e6ea014082
commit
96b80c7195
|
@ -423,16 +423,26 @@ bool GeGraphExecutor::RunGraph(const FuncGraphPtr &graph, const std::vector<tens
|
|||
const auto &[output_node, idx] = common::AnfAlgo::FetchRealNodeSkipMonadControl(graph_outputs[i]);
|
||||
const auto &tensor = ge_outputs[i];
|
||||
auto output_addr = AnfAlgo::GetMutableOutputAddr(output_node, idx);
|
||||
output_addr->set_ptr(device_context_->device_res_manager_->AllocateMemory(tensor->GetSize()));
|
||||
output_addr->SetSize(tensor->GetSize());
|
||||
output_addr->set_is_ptr_persisted(false);
|
||||
|
||||
if (output_addr->GetSize() < LongToSize(UlongToLong(tensor->GetSize()))) {
|
||||
MS_LOG(EXCEPTION) << "Output node " << output_node->DebugString() << "'s mem size " << output_addr->GetSize()
|
||||
<< " is less than actual output size " << tensor->GetSize();
|
||||
::ge::Placement dp = tensor->GetTensorDesc().GetPlacement();
|
||||
auto ge_data = tensor->ResetData().release();
|
||||
MS_EXCEPTION_IF_NULL(ge_data);
|
||||
if (dp != ::ge::kPlacementDevice) {
|
||||
constexpr int64_t kTensorAlignBytes = 64;
|
||||
if (reinterpret_cast<uintptr_t>(ge_data) % kTensorAlignBytes != 0) {
|
||||
MS_LOG(EXCEPTION) << "Skip zero-copy ge tensor " << reinterpret_cast<uintptr_t>(ge_data)
|
||||
<< ", bytes not aligned with expected.";
|
||||
}
|
||||
if (me_types[i] == TypeId::kObjectTypeString) {
|
||||
MS_LOG(EXCEPTION) << "It is not supported that Output node " << output_node->DebugString()
|
||||
<< "'s output data type is string now.";
|
||||
}
|
||||
MS_LOG(INFO) << "Zero-copy ge tensor " << reinterpret_cast<uintptr_t>(ge_data) << " as aligned with "
|
||||
<< kTensorAlignBytes << " types.";
|
||||
output_addr->set_ptr(ge_data);
|
||||
output_addr->SetSize(tensor->GetSize());
|
||||
output_addr->set_is_ptr_persisted(false);
|
||||
}
|
||||
// memcpy_s does not support data that is more than 2GB
|
||||
(void)memcpy(reinterpret_cast<uint8_t *>(output_addr->GetMutablePtr()), tensor->GetData(), tensor->GetSize());
|
||||
|
||||
auto actual_shapes = tensor->GetTensorDesc().GetShape().GetDims();
|
||||
UpdateOutputNodeShape(output_node, idx, me_types[i], actual_shapes);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue