forked from mindspore-Ecosystem/mindspore
zero_copy for ge output tensor
This commit is contained in:
parent
e6ea014082
commit
96b80c7195
|
@ -423,16 +423,26 @@ bool GeGraphExecutor::RunGraph(const FuncGraphPtr &graph, const std::vector<tens
|
||||||
const auto &[output_node, idx] = common::AnfAlgo::FetchRealNodeSkipMonadControl(graph_outputs[i]);
|
const auto &[output_node, idx] = common::AnfAlgo::FetchRealNodeSkipMonadControl(graph_outputs[i]);
|
||||||
const auto &tensor = ge_outputs[i];
|
const auto &tensor = ge_outputs[i];
|
||||||
auto output_addr = AnfAlgo::GetMutableOutputAddr(output_node, idx);
|
auto output_addr = AnfAlgo::GetMutableOutputAddr(output_node, idx);
|
||||||
output_addr->set_ptr(device_context_->device_res_manager_->AllocateMemory(tensor->GetSize()));
|
::ge::Placement dp = tensor->GetTensorDesc().GetPlacement();
|
||||||
|
auto ge_data = tensor->ResetData().release();
|
||||||
|
MS_EXCEPTION_IF_NULL(ge_data);
|
||||||
|
if (dp != ::ge::kPlacementDevice) {
|
||||||
|
constexpr int64_t kTensorAlignBytes = 64;
|
||||||
|
if (reinterpret_cast<uintptr_t>(ge_data) % kTensorAlignBytes != 0) {
|
||||||
|
MS_LOG(EXCEPTION) << "Skip zero-copy ge tensor " << reinterpret_cast<uintptr_t>(ge_data)
|
||||||
|
<< ", bytes not aligned with expected.";
|
||||||
|
}
|
||||||
|
if (me_types[i] == TypeId::kObjectTypeString) {
|
||||||
|
MS_LOG(EXCEPTION) << "It is not supported that Output node " << output_node->DebugString()
|
||||||
|
<< "'s output data type is string now.";
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Zero-copy ge tensor " << reinterpret_cast<uintptr_t>(ge_data) << " as aligned with "
|
||||||
|
<< kTensorAlignBytes << " types.";
|
||||||
|
output_addr->set_ptr(ge_data);
|
||||||
output_addr->SetSize(tensor->GetSize());
|
output_addr->SetSize(tensor->GetSize());
|
||||||
output_addr->set_is_ptr_persisted(false);
|
output_addr->set_is_ptr_persisted(false);
|
||||||
|
|
||||||
if (output_addr->GetSize() < LongToSize(UlongToLong(tensor->GetSize()))) {
|
|
||||||
MS_LOG(EXCEPTION) << "Output node " << output_node->DebugString() << "'s mem size " << output_addr->GetSize()
|
|
||||||
<< " is less than actual output size " << tensor->GetSize();
|
|
||||||
}
|
}
|
||||||
// memcpy_s does not support data larger than 2GB
|
|
||||||
(void)memcpy(reinterpret_cast<uint8_t *>(output_addr->GetMutablePtr()), tensor->GetData(), tensor->GetSize());
|
|
||||||
auto actual_shapes = tensor->GetTensorDesc().GetShape().GetDims();
|
auto actual_shapes = tensor->GetTensorDesc().GetShape().GetDims();
|
||||||
UpdateOutputNodeShape(output_node, idx, me_types[i], actual_shapes);
|
UpdateOutputNodeShape(output_node, idx, me_types[i], actual_shapes);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue