!44373 runtime: support the input and output of a ref node in different subgraphs

Merge pull request !44373 from limingqi107/bug_fix4
This commit is contained in:
i-robot 2022-10-26 01:41:26 +00:00 committed by Gitee
commit 8ad405032d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
14 changed files with 177 additions and 76 deletions

View File

@ -545,7 +545,7 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod
MS_EXCEPTION_IF_NULL(kernel_info);
auto addr = kernel_info->GetMutableOutputAddr(output_idx);
if (addr == nullptr) {
MS_LOG(EXCEPTION) << "Output_idx" << output_idx << " of node " << node->DebugString()
MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString()
<< " output addr is not exist." << trace::DumpSourceLines(node);
}
return addr;

View File

@ -51,8 +51,7 @@ class ApplyProximalGradientDescentCpuKernelMod : public NativeCpuKernelMod {
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutInRef(0, 0)
.AddOutInRef(1, 1),
.AddOutInRef(0, 0),
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
@ -60,8 +59,7 @@ class ApplyProximalGradientDescentCpuKernelMod : public NativeCpuKernelMod {
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutInRef(0, 0)
.AddOutInRef(1, 1)};
.AddOutInRef(0, 0)};
return support_list;
}

View File

@ -282,8 +282,15 @@ void DeviceAddressUtils::UpdateDeviceAddress(const session::AnfWithOutIndex &cur
MS_EXCEPTION_IF_NULL(cur_node_output_addr);
if (origin_node_output_addr.get() != cur_node_output_addr.get()) {
if (origin_node_output_addr->GetDeviceType() != cur_node_output_addr->GetDeviceType()) {
MS_LOG(EXCEPTION) << "Device type is not equal: ref origin kernel is " << origin_pair.first->fullname_with_scope()
<< ", index is " << origin_pair.second << ", device type is "
<< origin_node_output_addr->GetDeviceType() << "; cur kernel is "
<< cur_pair.first->fullname_with_scope() << ", index is " << cur_pair.second
<< ", device type is " << cur_node_output_addr->GetDeviceType();
}
MS_LOG(INFO) << "Update device address: ref origin kernel is " << origin_pair.first->fullname_with_scope()
<< ", index is " << origin_pair.second << ", cur kernel is " << cur_pair.first->fullname_with_scope()
<< ", index is " << origin_pair.second << "; cur kernel is " << cur_pair.first->fullname_with_scope()
<< ", index is " << cur_pair.second;
AnfAlgo::SetOutputAddr(origin_node_output_addr, cur_pair.second, cur_pair.first.get());
// Update the reference count of device address.
@ -293,7 +300,7 @@ void DeviceAddressUtils::UpdateDeviceAddress(const session::AnfWithOutIndex &cur
origin_node_output_addr->ResetRefCount();
} else {
MS_LOG(INFO) << "No need update device address: ref origin kernel is " << origin_pair.first->fullname_with_scope()
<< ", index is " << origin_pair.second << ", cur kernel is " << cur_pair.first->fullname_with_scope()
<< ", index is " << origin_pair.second << "; cur kernel is " << cur_pair.first->fullname_with_scope()
<< ", index is " << cur_pair.second;
}
}

View File

@ -107,8 +107,9 @@ void CopyActor::FetchDeviceTensor(OpContext<DeviceTensor> *const context) {
if (device_tensor_store_keys_.size() > 0) {
const auto &device_tensor_store_node = device_tensor_store_keys_[0].second;
MS_EXCEPTION_IF_NULL(device_tensor_store_node);
input_device_tensor_[0] =
DeviceTensorStore::GetInstance().Fetch(device_tensor_store_node.get(), input_device_context->GetDeviceType());
input_device_tensor_[0] = DeviceTensorStore::GetInstance()
.Fetch(device_tensor_store_node.get(), input_device_context->GetDeviceType())
.get();
if (input_device_tensor_[0] == nullptr) {
std::string error_info =
GetAID().Name() + " get device tensor store failed: " + device_tensor_store_node->fullname_with_scope() +
@ -116,8 +117,9 @@ void CopyActor::FetchDeviceTensor(OpContext<DeviceTensor> *const context) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
output_device_tensor_[0] =
DeviceTensorStore::GetInstance().Fetch(device_tensor_store_node.get(), output_device_context->GetDeviceType());
output_device_tensor_[0] = DeviceTensorStore::GetInstance()
.Fetch(device_tensor_store_node.get(), output_device_context->GetDeviceType())
.get();
if (output_device_tensor_[0] == nullptr) {
std::string error_info =
GetAID().Name() + " get device tensor store failed: " + device_tensor_store_node->fullname_with_scope() +

View File

@ -172,42 +172,44 @@ void ValueTupleToValue(const ValuePtr &value, std::vector<ValuePtr> *const value
}
}
// Make each ref-node output in `graph` share the device address (or at least the raw data ptr) of its
// ref input, so both sides of the ref relationship point at the same storage.
void UpdateRefNodeOutputDeviceAddress(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
// Ref map: output (node, index) -> input (node, index) it aliases.
auto ref_node_map = graph->GetRefMap();
// NOTE(review): `auto iter` copies each map entry per iteration; `const auto &iter` would avoid the copy.
for (auto iter : ref_node_map) {
auto &output_pair = iter.first;
auto &input_pair = iter.second;
auto &ref_node = output_pair.first;
auto output_index = output_pair.second;
auto &input_node = input_pair.first;
auto input_node_output_index = input_pair.second;
auto input_addr = AnfAlgo::GetMutableOutputAddr(input_node, input_node_output_index, false);
auto ref_node_output_addr = AnfAlgo::GetMutableOutputAddr(ref_node, output_index, false);
// Just compare shared_ptr of two DeviceAddress.
// The ptr of DeviceAddress may still be nullptr.
if (input_addr != ref_node_output_addr) {
// AnfAlgo::SetOutputAddr cannot update the device_address of frontend Tensor
// if the output of RefNode is used by subsequent nodes.
// Because the frontend Tensor is copied from backend Tensor and the shared_ptr of Tensor is different.
MS_EXCEPTION_IF_NULL(input_addr);
if (input_addr->GetMutablePtr() == nullptr) {
// No backing memory yet: safe to re-point the ref node at the input's DeviceAddress object.
AnfAlgo::SetOutputAddr(input_addr, output_index, ref_node.get());
} else {
// Memory already allocated: keep the ref node's DeviceAddress object but alias the raw data ptr.
ref_node_output_addr->set_ptr(input_addr->GetMutablePtr());
}
}
}
}
void UpdateGraphsRefNodeAddress(const std::vector<KernelGraphPtr> &graphs) {
// The device address of input ref node may be modified by input tensor, so need update the device address of ref node.
void UpdateDeviceAddressByRefInputNode(const std::vector<KernelGraphPtr> &graphs,
const std::set<AnfNode *> &modified_input_nodes) {
for (const auto &graph : graphs) {
// The DeviceAddress of the graph parameter has been updated.
// The output address of RefNode needs to be consistent with the address of parameter.
MS_EXCEPTION_IF_NULL(graph);
if (!graph->is_graph_run_mode()) {
UpdateRefNodeOutputDeviceAddress(graph);
// The DeviceAddress of the graph parameter has been updated.
if (graph->is_graph_run_mode()) {
continue;
}
for (auto &iter : graph->GetRefMap()) {
auto &output_pair = iter.first;
auto &input_pair = iter.second;
MS_EXCEPTION_IF_NULL(output_pair.first);
MS_EXCEPTION_IF_NULL(input_pair.first);
if (modified_input_nodes.count(input_pair.first.get()) == 0) {
continue;
}
// The output device tensor of ref node actor can't be changed in the running, and only the ptr of output device
// address can be modified. And need set `ref_count` to `SIZE_MAX` for avoiding clean. So only support the
// persistent device tensor.
if (!IsPersistentDeviceTensor(input_pair.first)) {
MS_LOG(EXCEPTION) << "The input parameter: " << input_pair.first->fullname_with_scope()
<< " isn't the ref parameter which used by the ref node: "
<< output_pair.first->fullname_with_scope();
}
MS_LOG(INFO) << "Update the ptr of ref node: " << output_pair.first->fullname_with_scope()
<< " by the modified ref input parameter: " << input_pair.first->fullname_with_scope();
auto ref_node_output_addr = AnfAlgo::GetMutableOutputAddr(output_pair.first, output_pair.second, false);
MS_EXCEPTION_IF_NULL(ref_node_output_addr);
const auto &front_input_node = graph->GetFrontAnfByBackendAnf(input_pair.first);
const auto &input_addr =
DeviceTensorStore::GetInstance().Fetch(front_input_node.get(), ref_node_output_addr->GetDeviceType());
MS_EXCEPTION_IF_NULL(input_addr);
ref_node_output_addr->set_ptr(input_addr->GetMutablePtr());
ref_node_output_addr->set_original_ref_count(SIZE_MAX);
ref_node_output_addr->ResetRefCount();
}
}
}
@ -312,7 +314,7 @@ void DataPrepareActor::UpdateDynamicShape(const AnfNodePtr &input_node, const Te
void DataPrepareActor::UpdateDeviceAddressForDataNode(const AnfNodePtr &input_node, const TensorPtr &input_tensor,
const KernelGraphPtr &graph,
const DeviceContext *device_context) const {
const DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(input_tensor);
MS_EXCEPTION_IF_NULL(graph);
@ -343,8 +345,9 @@ void DataPrepareActor::UpdateDeviceAddressForDataNode(const AnfNodePtr &input_no
tensor_address->type_id() != device_address->type_id()))) {
return;
}
if (tensor_address != nullptr) {
// Assign tensor address to input data node and set `ref_count` to `SIZE_MAX` for avoiding clean
if ((tensor_address != nullptr) && (tensor_address != device_address)) {
// Assign tensor address to input data node and set `ref_count` to `SIZE_MAX` for avoiding clean.
(void)address_modified_input_nodes_.insert(input_node.get());
AnfAlgo::SetOutputAddr(tensor_address, 0, input_node.get());
tensor_address->SetNodeIndex(input_node, 0);
tensor_address->set_original_ref_count(SIZE_MAX);
@ -401,7 +404,11 @@ void DataPrepareActor::PrepareData(const std::vector<std::vector<TensorPtr>> &in
return;
}
MS_EXCEPTION_IF_NULL(graph_compiler_info_);
UpdateGraphsRefNodeAddress(graph_compiler_info_->graphs_);
if (!address_modified_input_nodes_.empty()) {
UpdateDeviceAddressByRefInputNode(graph_compiler_info_->graphs_, address_modified_input_nodes_);
address_modified_input_nodes_.clear();
}
// Debug actor is blocked, must wait debug actor callback message to process continue.
if (debug_aid_ != nullptr && strategy_ == GraphExecutionStrategy::kPipeline) {
SendDebugReq(context);
@ -724,7 +731,7 @@ void DataPrepareActor::CopyDataFromDeviceTensorStore(const AnfNodePtr &front_nod
// Prepare the device data for persistent device tensor of weight node from host tensor.
void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node, const AnfNodePtr &front_node,
const TensorPtr &tensor, const DeviceContext *device_context,
OpContext<DeviceTensor> *const context) const {
OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(backend_node);
MS_EXCEPTION_IF_NULL(front_node);
MS_EXCEPTION_IF_NULL(device_context);
@ -755,20 +762,7 @@ void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node,
}
MS_EXCEPTION_IF_NULL(host_tensor_address);
if (host_tensor_address->GetDeviceType() == device_tensor->GetDeviceType()) {
// In the scenario of training + inference , the device address of the weight node can not be changed when
// multi-graphs sink mode is set.
if (device_tensor->is_ptr_persisted() && (host_tensor_address != device_tensor)) {
if (!Copy(device_tensor.get(), host_tensor_address.get())) {
std::string error_info = "Sync data error.";
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(real_strategy_, (*context), error_info);
}
host_tensor_address = device_tensor;
tensor->set_device_address(device_tensor);
} else {
AnfAlgo::SetOutputAddr(host_tensor_address, 0, backend_node.get());
}
} else {
if (host_tensor_address->GetDeviceType() != device_tensor->GetDeviceType()) {
MS_LOG(INFO) << "The device type is not equal, host tensor type:" << host_tensor_address->GetDeviceType()
<< ", device tensor type:" << device_tensor->GetDeviceType();
// The fake heterogeneous scenario.
@ -778,6 +772,20 @@ void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node,
tensor->set_device_address(device_tensor);
is_need_sync = true;
}
} else if (host_tensor_address != device_tensor) {
// In the scenario of training + inference , the device address of the weight node can not be changed when
// multi-graphs sink mode is set.
if (device_tensor->is_ptr_persisted()) {
if (!Copy(device_tensor.get(), host_tensor_address.get())) {
std::string error_info = "Sync data error.";
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(real_strategy_, (*context), error_info);
}
host_tensor_address = device_tensor;
tensor->set_device_address(device_tensor);
} else {
(void)address_modified_input_nodes_.insert(backend_node.get());
AnfAlgo::SetOutputAddr(host_tensor_address, 0, backend_node.get());
}
}
}
// Maybe the same host_tensor_address corresponds to the different front_node in shared weight scene,

View File

@ -22,6 +22,7 @@
#include <memory>
#include <utility>
#include <map>
#include <set>
#include "utils/hash_map.h"
#include "runtime/graph_scheduler/graph_compiler.h"
#include "runtime/graph_scheduler/actor/actor_common.h"
@ -77,7 +78,7 @@ class DataPrepareActor : public DebugAwareActor {
void UpdateDynamicShape(const AnfNodePtr &input_node, const TensorPtr &input_tensor) const;
void UpdateDeviceAddressForDataNode(const AnfNodePtr &input_node, const TensorPtr &input_tensor,
const KernelGraphPtr &graph, const DeviceContext *device_context) const;
const KernelGraphPtr &graph, const DeviceContext *device_context);
void PrepareDataForDeviceTensorStore(const std::vector<std::vector<TensorPtr>> &input_tensors,
OpContext<DeviceTensor> *const context);
@ -86,7 +87,7 @@ class DataPrepareActor : public DebugAwareActor {
// Prepare the device data for persistent device tensor of weight node from host tensor.
void PrepareDataForWeightNode(const AnfNodePtr &backend_node, const AnfNodePtr &front_node, const TensorPtr &tensor,
const DeviceContext *device_context, OpContext<DeviceTensor> *const context) const;
const DeviceContext *device_context, OpContext<DeviceTensor> *const context);
// Prepare the device data for persistent device tensor of value node.
void PrepareDataForValueNode(const ValueNodePtr &node, const AnfNodePtr &front_node,
const DeviceContext *device_context, OpContext<DeviceTensor> *const context) const;
@ -132,6 +133,9 @@ class DataPrepareActor : public DebugAwareActor {
std::vector<size_t> total_size_list_;
std::vector<const DeviceContext *> continuous_memory_device_contexts_;
std::vector<std::vector<TensorPtr>> init_tensors_;
// Record the address modified input nodes to refresh the ref node.
std::set<AnfNode *> address_modified_input_nodes_;
}; // namespace runtime
using DataPrepareActorPtr = std::shared_ptr<DataPrepareActor>;

View File

@ -525,8 +525,9 @@ void KernelActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *const context)
}
for (auto &device_tensor_store_key : device_tensor_store_keys_) {
auto device_tensor = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second.get(),
device_contexts_[0]->GetDeviceType());
auto device_tensor = DeviceTensorStore::GetInstance()
.Fetch(device_tensor_store_key.second.get(), device_contexts_[0]->GetDeviceType())
.get();
if (device_tensor == nullptr) {
std::string error_info =
GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->fullname_with_scope() +
@ -630,6 +631,20 @@ bool KernelActor::LaunchKernel(OpContext<DeviceTensor> *const) {
}
}
// Check the address of ref node.
for (const auto &ref : kernel_info_->out_in_ref_map()) {
size_t input_index = ref.second;
size_t output_index = ref.first;
MS_EXCEPTION_IF_CHECK_FAIL((launch_info_.inputs_.size() > input_index), "The ref input index is out of range.");
MS_EXCEPTION_IF_CHECK_FAIL((launch_info_.outputs_.size() > output_index), "The ref output index is out of range.");
MS_EXCEPTION_IF_NULL(launch_info_.inputs_[input_index]);
MS_EXCEPTION_IF_NULL(launch_info_.outputs_[output_index]);
if (launch_info_.inputs_[input_index]->addr != launch_info_.outputs_[output_index]->addr) {
MS_LOG(ERROR) << "Input address and output address are not equal of ref kernel actor: " << GetAID().Name();
return false;
}
}
MS_EXCEPTION_IF_NULL(device_contexts_[0]);
MS_LOG(DEBUG) << "Begin launch kernel of actor: " << GetAID().Name();
auto ret = device_contexts_[0]->kernel_executor_->LaunchKernel(
@ -666,7 +681,7 @@ void KernelActor::PostLaunchKernel(OpContext<DeviceTensor> *const context) {
// Note that SendMemoryFreeReq must be in front of SendOutput, because SendOutput will trigger SendMemoryAllocReq of
// the next actor and the actor is asynchronous execution. So it is necessary to ensure that SendMemoryFreeReq of the
// current actor is in front of SendMemoryAllocReq of the next actor. One is to reuse the memory more fully, the
// current actor is in front of SendMemoryAllocReq of the next actor. One is to reuse the memory more fully, the
// other is to ensure the execution order and avoid the illegal memory timing problem.
if (memory_free_list_.size() > 0) {
SendMemoryFreeReq(context);

View File

@ -126,8 +126,9 @@ void SuperKernelActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *const con
// Check device tensor store.
for (auto &device_tensor_store_key : device_tensor_store_keys_) {
auto input_device_tensor = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second.get(),
device_contexts_[0]->GetDeviceType());
auto input_device_tensor = DeviceTensorStore::GetInstance()
.Fetch(device_tensor_store_key.second.get(), device_contexts_[0]->GetDeviceType())
.get();
// Ge backend maybe nullptr.
if (input_device_tensor == nullptr) {
MS_LOG(WARNING) << "Failed get device tensor for node:" << device_tensor_store_key.second->DebugString()

View File

@ -83,7 +83,7 @@ class DeviceTensorStore {
}
}
DeviceTensor *Fetch(AnfNode *key, DeviceTensorType value_type) const {
DeviceTensorPtr Fetch(AnfNode *key, DeviceTensorType value_type) const {
MS_EXCEPTION_IF_NULL(key);
std::shared_lock<std::shared_mutex> lock(map_mutex_);
const auto &iter = device_tensors_.find(key);
@ -91,7 +91,7 @@ class DeviceTensorStore {
for (const auto &device_tensor : iter->second) {
MS_EXCEPTION_IF_NULL(device_tensor);
if (device_tensor->GetDeviceType() == value_type) {
return device_tensor.get();
return device_tensor;
}
}
}

View File

@ -568,7 +568,6 @@ void GraphCompiler::CreateDeviceAddress(const KernelGraphPtr &graph, const Devic
DeviceAddressUtils::CreateKernelWorkspaceDeviceAddress(device_context, graph);
DeviceAddressUtils::UpdateDeviceAddressForInplaceNode(graph);
DeviceAddressUtils::UpdateDeviceAddressForRefNode(graph);
MS_LOG(INFO) << "Status record: end create device address. graph id: " << graph->graph_id();
}

View File

@ -507,6 +507,7 @@ ActorSet *GraphScheduler::Transform(const GraphCompilerInfo &graph_compiler_info
const auto &actor_set = Build(graph_compiler_info);
MS_EXCEPTION_IF_NULL(actor_set);
CacheGraphOutputToActor(graph_compiler_info);
UpdateDeviceAddressByRefInternalParameter(graph_compiler_info);
Link(actor_set.get(), graph_compiler_info);
DumpActor(actor_set.get(), graph_compiler_info);
@ -780,6 +781,68 @@ void GraphScheduler::CacheGraphOutputToActor(const GraphCompilerInfo &graph_comp
}
}
// For every ref pair in each graph whose ref input is an internal parameter (i.e. the real producer
// lives in another subgraph), re-point the ref output at the real origin node's device address and
// fix up the original reference counts accordingly. Relies on graph_output_to_actor_ being filled,
// so it must run after CacheGraphOutputToActor and before Link (see the call site in Transform).
void GraphScheduler::UpdateDeviceAddressByRefInternalParameter(const GraphCompilerInfo &graph_compiler_info) {
for (const auto &graph : graph_compiler_info.graphs_) {
MS_EXCEPTION_IF_NULL(graph);
// No update is needed in graph run mode.
if (graph->is_graph_run_mode()) {
continue;
}
// Ref map: cur (ref output node, index) -> origin (ref input node, index).
for (const auto &ref_node_pair : graph->GetRefMap()) {
auto &cur_node_pair = ref_node_pair.first;
auto &origin_node_pair = ref_node_pair.second;
MS_EXCEPTION_IF_NULL(cur_node_pair.first);
MS_EXCEPTION_IF_NULL(origin_node_pair.first);
// Only ref inputs that are internal parameters need the update.
if (!IsInternalParameter(origin_node_pair.first, graph)) {
continue;
}
// Get the real origin node by the internal parameter.
auto front_output_with_index = graph->GetOriginFrontNodeByInternalParameter(origin_node_pair.first);
MS_EXCEPTION_IF_NULL(front_output_with_index.first);
if (graph_output_to_actor_.count(front_output_with_index) == 0) {
MS_LOG(EXCEPTION) << "Can't find graph output by front node:" << front_output_with_index.first->DebugString();
}
// Resolve through make_tuple/depend wrappers to the real kernel that produces the value.
auto real_origin_node_pair = graph_output_to_actor_[front_output_with_index].second;
real_origin_node_pair =
common::AnfAlgo::VisitKernelWithReturnType(real_origin_node_pair.first, real_origin_node_pair.second, false);
MS_EXCEPTION_IF_NULL(real_origin_node_pair.first);
auto cur_node_output_addr = AnfAlgo::GetMutableOutputAddr(cur_node_pair.first, cur_node_pair.second, false);
MS_EXCEPTION_IF_NULL(cur_node_output_addr);
auto origin_node_output_addr =
AnfAlgo::GetMutableOutputAddr(real_origin_node_pair.first, real_origin_node_pair.second, false);
// A persistent device tensor must be fetched from the device tensor store by device type, since the
// node itself may hold an address for a different device.
if (IsPersistentDeviceTensor(real_origin_node_pair.first)) {
front_output_with_index = common::AnfAlgo::VisitKernelWithReturnType(front_output_with_index.first,
front_output_with_index.second, false);
origin_node_output_addr = DeviceTensorStore::GetInstance().Fetch(front_output_with_index.first.get(),
cur_node_output_addr->GetDeviceType());
}
MS_EXCEPTION_IF_NULL(origin_node_output_addr);
// Skip if already shared; the device address can't be updated across heterogeneous device types.
if ((origin_node_output_addr.get() == cur_node_output_addr.get()) ||
(origin_node_output_addr->GetDeviceType() != cur_node_output_addr->GetDeviceType())) {
continue;
}
MS_LOG(INFO) << "Update device address by internal parameter: ref origin kernel is "
<< real_origin_node_pair.first->fullname_with_scope() << ", index is "
<< real_origin_node_pair.second << "; cur kernel is " << cur_node_pair.first->fullname_with_scope()
<< ", index is " << cur_node_pair.second << "; internal parameter is "
<< origin_node_pair.first->DebugString();
AnfAlgo::SetOutputAddr(origin_node_output_addr, cur_node_pair.second, cur_node_pair.first.get());
// Update the reference count of device address: the cur node's old address loses a user and the
// origin node's address gains one.
cur_node_output_addr->DecreaseOriginalRefCount();
cur_node_output_addr->ResetRefCount();
origin_node_output_addr->IncreaseOriginalRefCount();
origin_node_output_addr->ResetRefCount();
}
}
}
void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) {
MS_EXCEPTION_IF_NULL(actor_set);
std::vector<AbstractActor *> auto_monad_actors;

View File

@ -125,6 +125,9 @@ class BACKEND_EXPORT GraphScheduler {
// Cache the information of graph output node to actor between “build” and “link”, for linking between the tail of
// previous graph and the head of next graph.
void CacheGraphOutputToActor(const GraphCompilerInfo &graph_compiler_info);
// The input and output of ref node may be in the different subgraphs, so need the global subgraphs info to update the
// device address of ref node.
void UpdateDeviceAddressByRefInternalParameter(const GraphCompilerInfo &graph_compiler_info);
// The processing of actors linking.
// 1. The processing of linking data arrows.

View File

@ -163,7 +163,8 @@ DeviceTensor *MemorySwapNodeScheduler::GetNodeOutputDeviceTensor(
if (front_node == nullptr || front_node->isa<CNode>()) {
return device_address;
}
auto real_device_address = DeviceTensorStore::GetInstance().Fetch(front_node.get(), device_context->GetDeviceType());
auto real_device_address =
DeviceTensorStore::GetInstance().Fetch(front_node.get(), device_context->GetDeviceType()).get();
if (real_device_address != nullptr) {
return real_device_address;
}

View File

@ -61,7 +61,7 @@ class DeviceSync {
}
}
void DecreaseOriginalRefCount() {
if (original_ref_count_ > 0) {
if ((original_ref_count_ < SIZE_MAX) && (original_ref_count_ > 0)) {
original_ref_count_--;
}
}