forked from mindspore-Ecosystem/mindspore

actor runtime supports the inplace optimizer

This commit is contained in:
parent a2dc98f972
commit f619e85647
@@ -18,7 +18,6 @@
 #include <algorithm>
 #include <map>
 #include <set>
-#include <unordered_set>
 #include <functional>
 #include <numeric>
 #include "ir/anf.h"
@@ -50,8 +49,6 @@ constexpr size_t kNopNodeInputSize = 2;
 constexpr size_t kNopNodeRealInputIndex = 1;
 constexpr size_t kReturnDataIndex = 1;
 
-using PrimitiveSet = std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual>;
-
 const PrimitiveSet follow_first_input_prims = {prim::kPrimDepend, prim::kPrimLoad};
 
 bool IsShapeDynamic(const abstract::ShapePtr &shape) {
@@ -68,15 +65,6 @@ bool IsOneOfPrimitive(const AnfNodePtr &node, const PrimitiveSet &prim_set) {
   return (prim && prim_set.find(prim) != prim_set.end());
 }
 
-bool IsOneOfPrimitiveCNode(const AnfNodePtr &node, const PrimitiveSet &prim_set) {
-  MS_EXCEPTION_IF_NULL(node);
-  auto cnode = node->cast<CNodePtr>();
-  if (cnode == nullptr || cnode->size() == 0) {
-    return false;
-  }
-  return IsOneOfPrimitive(cnode->inputs().at(kAnfPrimitiveIndex), prim_set);
-}
-
 bool IsRealKernelCNode(const CNodePtr &cnode) {
   static const PrimitiveSet virtual_prims = {
     prim::kPrimImageSummary, prim::kPrimScalarSummary, prim::kPrimTensorSummary, prim::kPrimHistogramSummary,
@@ -284,7 +272,7 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr
     if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimUpdateState)) {
       return VisitKernelWithReturnType(cnode->input(kUpdateStateStateInput), index, visit_nop_node, return_types);
     }
-    if (IsOneOfPrimitiveCNode(cnode, follow_first_input_prims)) {
+    if (AnfAlgo::IsOneOfPrimitiveCNode(cnode, follow_first_input_prims)) {
       return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), index, visit_nop_node, return_types);
     }
     if (opt::IsNopNode(cnode) && visit_nop_node) {
@@ -2087,5 +2075,13 @@ bool AnfRuntimeAlgorithm::IsTensorBroadcast(const std::vector<size_t> &lhs, cons
   return false;
 }
 
+bool AnfRuntimeAlgorithm::IsOneOfPrimitiveCNode(const AnfNodePtr &node, const PrimitiveSet &prim_set) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto cnode = node->cast<CNodePtr>();
+  if (cnode == nullptr || cnode->size() == 0) {
+    return false;
+  }
+  return IsOneOfPrimitive(cnode->inputs().at(kAnfPrimitiveIndex), prim_set);
+}
 } // namespace session
 } // namespace mindspore
@@ -23,6 +23,7 @@
 #include <tuple>
 #include <utility>
 #include <memory>
+#include <unordered_set>
 #include "ir/anf.h"
 #include "ir/dtype.h"
 #include "base/base.h"
@@ -37,6 +38,7 @@
 
 namespace mindspore {
 namespace session {
+using PrimitiveSet = std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual>;
 using AnfVisitFuncion = std::function<Any(const AnfNodePtr &node, int index)>;
 using DeviceAddress = device::DeviceAddress;
 using DeviceAddressPtr = device::DeviceAddressPtr;
@@ -293,6 +295,7 @@ class AnfRuntimeAlgorithm {
     }
     return result;
   }
+  static bool IsOneOfPrimitiveCNode(const AnfNodePtr &node, const PrimitiveSet &prim_set);
 };
 } // namespace session
 using AnfAlgo = session::AnfRuntimeAlgorithm;
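Note: the hunks above move IsOneOfPrimitiveCNode out of the .cc file's local scope and expose it as a static method of AnfRuntimeAlgorithm, so callers outside the session layer (the graph scheduler's auto-monad linking below) can reuse the same check. A minimal self-contained sketch of the membership-test pattern, with plain std types standing in for MindSpore's PrimitivePtr/PrimitiveHasher/PrimitiveEqual (the stand-in types are illustrative, not the real API):

    #include <functional>
    #include <memory>
    #include <string>
    #include <unordered_set>

    // Stand-ins for mindspore's Primitive / PrimitivePtr.
    struct Primitive { std::string name; };
    using PrimitivePtr = std::shared_ptr<Primitive>;

    // Hash and equality on the primitive name, mirroring PrimitiveHasher/PrimitiveEqual.
    struct PrimitiveHasher {
      std::size_t operator()(const PrimitivePtr &p) const { return std::hash<std::string>()(p->name); }
    };
    struct PrimitiveEqual {
      bool operator()(const PrimitivePtr &a, const PrimitivePtr &b) const { return a->name == b->name; }
    };
    using PrimitiveSet = std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual>;

    // The membership test at the core of IsOneOfPrimitive / IsOneOfPrimitiveCNode.
    bool IsOneOfPrimitive(const PrimitivePtr &prim, const PrimitiveSet &prim_set) {
      return prim != nullptr && prim_set.find(prim) != prim_set.end();
    }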
@@ -50,6 +50,8 @@
 #include "utils/info.h"
 #include "load_mindir/load_model.h"
+#include "pipeline/jit/prim_bprop_optimizer.h"
+#include "mindrt/src/actor/actormgr.h"
 
 #if ((defined ENABLE_CPU) && (!defined _WIN32))
 #include "ps/constants.h"
 #include "ps/util.h"
@@ -1199,6 +1201,7 @@ void ClearResAtexit() {
   parse::Parser::CleanParserResource();
   parse::CleanDataClassToClassMap();
   trace::ClearTraceStack();
+  ActorMgr::GetActorMgrRef()->TerminateAll();
 }
 } // namespace pipeline
 } // namespace mindspore
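Note: ClearResAtexit now terminates all actors before the remaining pipeline resources are released, so no actor thread can touch freed state during process teardown. A rough standalone analogue of the terminate-then-join pattern (ActorManager here is a toy model; nothing beyond the TerminateAll name is taken from MindRT's API):

    #include <atomic>
    #include <thread>
    #include <vector>

    class ActorManager {
     public:
      void Spawn() {
        workers_.emplace_back([this] {
          while (running_.load()) { /* drain the actor mailbox */ }
        });
      }
      // Mirrors the intent of ActorMgr::TerminateAll(): stop and join every actor
      // thread so none of them outlives the resources cleared later at exit.
      void TerminateAll() {
        running_.store(false);
        for (auto &worker : workers_) {
          if (worker.joinable()) worker.join();
        }
      }
     private:
      std::atomic<bool> running_{true};
      std::vector<std::thread> workers_;
    };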
@@ -87,6 +87,14 @@ bool IsKernelActor(const AnfNodePtr &node) {
   return false;
 }
 
+bool IsSkippedKernelActor(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  if (IsKernelActor(node) && AnfAlgo::IsInplaceNode(node, "skip")) {
+    return true;
+  }
+  return false;
+}
+
 bool IsPersistentDeviceTensor(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   if (node->isa<ValueNode>()) {
@@ -97,6 +105,5 @@ bool IsPersistentDeviceTensor(const AnfNodePtr &node) {
   }
   return false;
 }
-
 } // namespace runtime
 } // namespace mindspore
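Note: a "skip" kernel is one the inplace optimizer marks as contributing no computation of its own (its output aliases an input), so the runtime must not schedule it. How the two predicates combine in the scheduler hunks below, modeled as a self-contained toy (Node and its fields are illustrative; the real check goes through AnfAlgo::IsInplaceNode and the primitive's attributes):

    #include <string>

    struct Node {
      bool is_kernel;   // would be IsKernelActor(node)
      bool skip_attr;   // would be AnfAlgo::IsInplaceNode(node, "skip")
      std::string name;
    };

    bool IsKernelActor(const Node &n) { return n.is_kernel; }
    // Mirrors IsSkippedKernelActor: a kernel flagged "skip" by the inplace optimizer.
    bool IsSkippedKernelActor(const Node &n) { return IsKernelActor(n) && n.skip_attr; }

    // BuildKernelActor only creates actors for kernels that actually run, and
    // Link ignores skipped kernels symmetrically (see the GraphScheduler hunks).
    bool NeedsActor(const Node &n) { return IsKernelActor(n) && !IsSkippedKernelActor(n); }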
@@ -51,6 +51,8 @@ bool IsDeviceQueueDSActor(const AnfNodePtr &node);
 bool IsHostQueueDSActor(const AnfNodePtr &node, const KernelGraphPtr &graph = nullptr,
                         const TensorPtr &tensor = nullptr);
 bool IsKernelActor(const AnfNodePtr &node);
+// The skip kernel doesn't run, it exists in the inplace optimizer.
+bool IsSkippedKernelActor(const AnfNodePtr &node);
 
 // Internal parameter is not the origin parameter of func graph, it is the output of previous kernel graph which is
 // related to the input of this kernel graph.
@@ -185,7 +185,7 @@ void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *cont
   if (host_queue_->IsEmpty()) {
     SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "Host data queue is empty.");
   }
-  auto &host_tensors = host_queue_->PullData();
+  auto &host_tensors = host_queue_->Pull();
   auto &device_tensors = buffers_.back();
   if (host_tensors.size() != device_tensors.size()) {
     SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context),
@@ -204,7 +204,7 @@ void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *cont
       SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "SyncHostToDevice failed.");
     }
   }
-  host_queue_->PopData();
+  host_queue_->Pop();
 
   // Note that SendMemoryFreeReq must be in front of SendOutput, because SendOutput will trigger SendMemoryAllocReq of
   // the next actor and the actor is asynchronous execution. So it is necessary to ensure that SendMemoryFreeReq of
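Note: the renames (PushData/PullData/PopData to Push/Pull/Pop) do not change the protocol: Pull only borrows the front batch by reference, so Pop must come after the host-to-device copy has finished. A self-contained sketch of that order, with int standing in for TensorPtr:

    #include <queue>
    #include <vector>

    // Minimal stand-in for HostTensorQueue after the rename.
    class SimpleHostQueue {
     public:
      void Push(const std::vector<int> &batch) { buffers_.push(batch); }
      const std::vector<int> &Pull() { return buffers_.front(); }  // borrow, don't remove
      bool IsEmpty() const { return buffers_.empty(); }
      void Pop() { buffers_.pop(); }
     private:
      std::queue<std::vector<int>> buffers_;
    };

    // Mirrors the order in OnMemoryAllocFinish: check, Pull, copy, then Pop.
    void ConsumeOneBatch(SimpleHostQueue *queue) {
      if (queue->IsEmpty()) return;  // the real code fails with "Host data queue is empty."
      const auto &host_tensors = queue->Pull();
      for (int tensor : host_tensors) { (void)tensor; /* SyncHostToDevice per tensor */ }
      queue->Pop();  // release the batch only after the copy succeeded
    }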
@@ -187,6 +187,50 @@ void CreateKernelWorkspaceDeviceAddress(const DeviceContext *device_context, con
     }
   }
 }
+
+void UpdateDeviceAddressForInplaceNode(const KernelGraphPtr &graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  // Collect the inplace groups.
+  std::map<uint32_t, std::vector<CNodePtr>> inplace_groups;
+  const std::vector<CNodePtr> &kernels = graph->execution_order();
+  for (const auto &kernel : kernels) {
+    if (!AnfAlgo::IsInplaceNode(kernel, "inplace_algo")) {
+      continue;
+    }
+    auto primitive = AnfAlgo::GetCNodePrimitive(kernel);
+    MS_EXCEPTION_IF_NULL(primitive);
+    auto inplace_group_attr = primitive->GetAttr("inplace_group");
+    MS_EXCEPTION_IF_NULL(inplace_group_attr);
+    auto group_id = GetValue<uint32_t>(inplace_group_attr);
+    inplace_groups[group_id].emplace_back(kernel);
+  }
+
+  const size_t kMinInplaceGroupSize = 2;
+  for (const auto &inplace_group : inplace_groups) {
+    auto &group_nodes = inplace_group.second;
+    if (group_nodes.size() < kMinInplaceGroupSize) {
+      continue;
+    }
+    // Get the device address of the first node in the inplace group.
+    auto node_primitive = AnfAlgo::GetCNodePrimitive(group_nodes[0]);
+    MS_EXCEPTION_IF_NULL(node_primitive);
+    auto output_index = GetValue<uint32_t>(node_primitive->GetAttr("inplace_output_index"));
+    auto device_address = AnfAlgo::GetMutableOutputAddr(group_nodes[0], output_index, false);
+    MS_EXCEPTION_IF_NULL(device_address);
+
+    // Update the device address of other nodes using device address of the first node in the inplace group.
+    for (size_t i = 1; i < group_nodes.size(); ++i) {
+      auto &group_node = group_nodes[i];
+      auto prim = AnfAlgo::GetCNodePrimitive(group_node);
+      MS_EXCEPTION_IF_NULL(prim);
+      auto index = GetValue<uint32_t>(prim->GetAttr("inplace_output_index"));
+      AnfAlgo::SetOutputAddr(device_address, index, group_node.get());
+      // Update the reference count of device address.
+      device_address->IncreaseOriginalRefCount();
+      device_address->ResetRefCount();
+    }
+  }
+}
 } // namespace
 
 void GraphCompiler::set_device_context(DeviceContext *device_context) {
@@ -285,6 +329,7 @@ void GraphCompiler::CreateDeviceAddress(const KernelGraphPtr &graph) const {
   CreateValueNodeDeviceAddress(device_context_, graph);
   CreateKernelOutputDeviceAddress(device_context_, graph);
   CreateKernelWorkspaceDeviceAddress(device_context_, graph);
+  UpdateDeviceAddressForInplaceNode(graph);
 }
 
 void GraphCompiler::GetParamAndOutputIndex(
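Note: UpdateDeviceAddressForInplaceNode makes every member of an inplace group alias the output buffer of the group's first node, and bumps the original reference count once per extra member so the buffer is not freed until all of them have run. A toy model of the sharing step (DeviceAddress here is a stand-in; the real ref-count semantics live in device::DeviceAddress):

    #include <cstddef>
    #include <cstdint>
    #include <memory>
    #include <vector>

    struct DeviceAddress {
      int32_t original_ref_count{1};
      int32_t ref_count{0};
      void IncreaseOriginalRefCount() { ++original_ref_count; }
      void ResetRefCount() { ref_count = original_ref_count; }
    };
    using DeviceAddressPtr = std::shared_ptr<DeviceAddress>;

    struct Kernel { DeviceAddressPtr output; };

    // All members of one inplace group alias the first member's output buffer.
    void ShareAddressInGroup(std::vector<Kernel> *group) {
      const std::size_t kMinInplaceGroupSize = 2;
      if (group->size() < kMinInplaceGroupSize) return;
      auto shared = (*group)[0].output;
      for (std::size_t i = 1; i < group->size(); ++i) {
        (*group)[i].output = shared;         // AnfAlgo::SetOutputAddr in the real code
        shared->IncreaseOriginalRefCount();  // one more consumer of the same buffer
        shared->ResetRefCount();
      }
    }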
@@ -465,7 +465,7 @@ void GraphScheduler::PrepareRun(const ActorSet *actor_set, const GraphCompilerIn
   if (host_data_source_actor != nullptr) {
     const auto &host_tensor_queue = FetchHostQueue(actor_set->name_);
     MS_EXCEPTION_IF_NULL(host_tensor_queue);
-    host_tensor_queue->PushData(host_tensors);
+    host_tensor_queue->Push(host_tensors);
   }
 }
 
@@ -590,7 +590,7 @@ void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_co
     MS_EXCEPTION_IF_NULL(graph);
     auto execution_order = graph->execution_order();
     for (auto &kernel : execution_order) {
-      if (!IsKernelActor(kernel)) {
+      if (IsSkippedKernelActor(kernel) || (!IsKernelActor(kernel))) {
         continue;
       }
       const auto &kernel_actor = dynamic_cast<KernelActor *>(FetchActor(kernel->fullname_with_scope()));
@@ -717,7 +717,7 @@ std::vector<KernelActorPtr> GraphScheduler::BuildKernelActor(const GraphCompiler
     MS_EXCEPTION_IF_NULL(graph);
     auto execution_order = graph->execution_order();
     for (auto &kernel : execution_order) {
-      if (IsKernelActor(kernel)) {
+      if (IsKernelActor(kernel) && (!IsSkippedKernelActor(kernel))) {
         auto kernel_actor =
           std::make_shared<KernelActor>(kernel->fullname_with_scope(), kernel, device_context, memory_manager_aid_);
         MS_EXCEPTION_IF_NULL(kernel_actor);
@@ -786,13 +786,7 @@ void GraphScheduler::LinkDataArrow(KernelActor *to_actor, const ActorSet *actor_
     std::string actor_name = actor_set->name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id());
     const auto &from_actor = dynamic_cast<DeviceQueueDataSourceActor *>(FetchActor(actor_name));
     LinkDataArrowForDeviceDSActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
-  } else if (IsHostQueueDSActor(from_kernel, graph)) {
-    bool tensor_has_device_address =
-      tensor != nullptr && (std::dynamic_pointer_cast<DeviceTensor>(tensor->device_address()) != nullptr);
-    if (tensor_has_device_address) {
-      return;
-    }
-
+  } else if (IsHostQueueDSActor(from_kernel, graph, tensor)) {
     // Link the data arrows of host queue data source actor.
     std::string actor_name = actor_set->name_ + "_HostDSActor";
     const auto &from_actor = dynamic_cast<HostQueueDataSourceActor *>(FetchActor(actor_name));
@@ -808,7 +802,8 @@ void GraphScheduler::LinkDataArrow(KernelActor *to_actor, const ActorSet *actor_
     const auto devcie_tensor_store_key = FetchFrontNodeByBackendNode(from_kernel, graph);
     to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, devcie_tensor_store_key.get());
   } else {
-    MS_LOG(EXCEPTION) << "Invalid from kernel: " << from_kernel->fullname_with_scope();
+    // May exist the from kernel that no need link in the pynative mode.
+    MS_LOG(DEBUG) << "Invalid from kernel: " << from_kernel->fullname_with_scope();
   }
 }
 
@@ -910,9 +905,24 @@ void GraphScheduler::LinkDataArrowForHostDSActor(HostQueueDataSourceActor *from_
 void GraphScheduler::LinkDataArrowForKernelActor(KernelActor *from_actor, KernelActor *to_actor,
                                                  KernelWithIndex from_kernel_with_output_idx,
                                                  KernelWithIndex to_kernel_with_input_idx) {
-  MS_EXCEPTION_IF_NULL(from_actor);
   MS_EXCEPTION_IF_NULL(to_actor);
+  if (IsSkippedKernelActor(from_kernel_with_output_idx.first)) {
+    auto real_kernel_with_index = AnfAlgo::GetPrevNodeOutput(from_kernel_with_output_idx.first, 0);
+    MS_EXCEPTION_IF_NULL(real_kernel_with_index.first);
+    LinkControlArrowBySkippedNode(to_actor, from_kernel_with_output_idx.first);
+
+    // Update the from kernel info by the real node info.
+    MS_LOG(INFO) << "Link data arrow for inplace node, aggregate node: "
+                 << to_kernel_with_input_idx.first->fullname_with_scope()
+                 << ", aggregate input index: " << to_kernel_with_input_idx.second
+                 << ", skip node: " << from_kernel_with_output_idx.first->fullname_with_scope()
+                 << ", real node: " << real_kernel_with_index.first->fullname_with_scope();
+    from_kernel_with_output_idx.first = real_kernel_with_index.first;
+    from_kernel_with_output_idx.second = real_kernel_with_index.second;
+    from_actor = dynamic_cast<KernelActor *>(FetchActor(from_kernel_with_output_idx.first->fullname_with_scope()));
+  }
+
+  MS_EXCEPTION_IF_NULL(from_actor);
   auto from_kernel = from_kernel_with_output_idx.first;
   MS_EXCEPTION_IF_NULL(from_kernel);
   auto from_output_index = from_kernel_with_output_idx.second;
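Note: when the producer of a data arrow is a skipped kernel, the scheduler bypasses it in two steps: the data arrow is retargeted to the skipped node's real producer (via AnfAlgo::GetPrevNodeOutput), and control arrows are drawn from every input of the skipped node to the consumer, so execution order is still enforced even though the skipped node has no actor of its own. Schematically:

    before:  real_kernel --data--> skip_kernel --data--> consumer
    after:   real_kernel --------------data-----------> consumer
             each input of skip_kernel ----control----> consumer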
@@ -1029,7 +1039,8 @@ void GraphScheduler::LinkControlArrowByAutoMonad(KernelActor *to_actor, const An
     return;
   }
   // Find the real input node, include the monad node and make tuple node.
-  const std::vector<PrimitivePtr> &return_types = {prim::kPrimUpdateState, prim::kPrimLoad, prim::kPrimMakeTuple};
+  const std::vector<PrimitivePtr> return_types = {prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad,
+                                                  prim::kPrimMakeTuple};
   const auto &input_kernel_with_output_idx = AnfAlgo::VisitKernelWithReturnType(from_node, 0, false, return_types);
   MS_EXCEPTION_IF_NULL(input_kernel_with_output_idx.first);
   if (!input_kernel_with_output_idx.first->isa<CNode>()) {
@@ -1037,40 +1048,65 @@ void GraphScheduler::LinkControlArrowByAutoMonad(KernelActor *to_actor, const An
   }
   const auto &input_cnode = input_kernel_with_output_idx.first->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(input_cnode);
-
-  // Get the real depend input by monad node which needs to link the control arrow.
-  AnfNodePtr real_depend_input = nullptr;
-  if (AnfAlgo::CheckPrimitiveType(input_cnode, prim::kPrimUpdateState)) {
-    real_depend_input = input_cnode->input(kUpdateStateRealInput);
-  } else if (AnfAlgo::CheckPrimitiveType(input_cnode, prim::kPrimLoad)) {
-    real_depend_input = input_cnode->input(kLoadStateInput);
-  } else if (AnfAlgo::CheckPrimitiveType(input_cnode, prim::kPrimMakeTuple)) {
-    // Make tuple node needs to be expanded.
+  // Make tuple node needs to be expanded.
+  if (AnfAlgo::CheckPrimitiveType(input_cnode, prim::kPrimMakeTuple)) {
     for (size_t i = 1; i < input_cnode->inputs().size(); ++i) {
       LinkControlArrowByAutoMonad(to_actor, input_cnode->input(i));
     }
     return;
-  } else {
-    return;
   }
 
-  MS_EXCEPTION_IF_NULL(real_depend_input);
-  if (!real_depend_input->isa<CNode>()) {
-    return;
-  }
-  // The monad node and make tuple node need recursion.
-  if (AnfAlgo::CheckPrimitiveType(real_depend_input, prim::kPrimUpdateState) ||
-      AnfAlgo::CheckPrimitiveType(real_depend_input, prim::kPrimLoad) ||
-      AnfAlgo::CheckPrimitiveType(real_depend_input, prim::kPrimMakeTuple)) {
-    LinkControlArrowByAutoMonad(to_actor, real_depend_input);
-    return;
+  // Get the real depend input by monad node which needs to link the control arrow.
+  std::vector<AnfNodePtr> real_depend_inputs;
+  if (AnfAlgo::CheckPrimitiveType(input_cnode, prim::kPrimDepend)) {
+    real_depend_inputs.push_back(input_cnode->input(kDependAttachNodeIndex));
+  } else if (AnfAlgo::CheckPrimitiveType(input_cnode, prim::kPrimUpdateState)) {
+    for (size_t i = kUpdateStateRealInput; i < input_cnode->inputs().size(); ++i) {
+      real_depend_inputs.push_back(input_cnode->input(i));
+    }
+  } else if (AnfAlgo::CheckPrimitiveType(input_cnode, prim::kPrimLoad)) {
+    real_depend_inputs.push_back(input_cnode->input(kLoadStateInput));
   }
 
-  // Link the control arrow between the kernel actors.
-  const auto &from_actor = dynamic_cast<KernelActor *>(FetchActor(real_depend_input->fullname_with_scope()));
-  MS_EXCEPTION_IF_NULL(from_actor);
-  from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());
-  to_actor->input_controls_num_++;
+  const std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> recursion_prims = {
+    prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad, prim::kPrimMakeTuple};
+  for (const auto &real_depend_input : real_depend_inputs) {
+    // The monad node and make tuple node need recursion.
+    if (AnfAlgo::IsOneOfPrimitiveCNode(real_depend_input, recursion_prims)) {
+      LinkControlArrowByAutoMonad(to_actor, real_depend_input);
+      continue;
+    }
+
+    if (!IsKernelActor(real_depend_input)) {
+      continue;
+    }
+    // Link the control arrow between the kernel actors.
+    const auto &from_actor = dynamic_cast<KernelActor *>(FetchActor(real_depend_input->fullname_with_scope()));
+    MS_EXCEPTION_IF_NULL(from_actor);
+    MS_LOG(INFO) << "Link control arrow by auto monad, from actor: " << from_actor->GetAID().Name()
+                 << ", to actor: " << to_actor->GetAID().Name();
+    from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());
+    to_actor->input_controls_num_++;
+  }
 }
 
+void GraphScheduler::LinkControlArrowBySkippedNode(KernelActor *to_actor, const AnfNodePtr &skipped_node) {
+  MS_EXCEPTION_IF_NULL(to_actor);
+  MS_EXCEPTION_IF_NULL(skipped_node);
+  auto to_aid = to_actor->GetAID();
+
+  // Link the control arrow from all the inputs of skipped node to the user of skipped node.
+  auto input_num = AnfAlgo::GetInputTensorNum(skipped_node);
+  for (size_t i = 0; i < input_num; ++i) {
+    auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(skipped_node, i, false);
+    MS_EXCEPTION_IF_NULL(kernel_with_index.first);
+    auto from_actor = dynamic_cast<KernelActor *>(FetchActor(kernel_with_index.first->fullname_with_scope()));
+    MS_EXCEPTION_IF_NULL(from_actor);
+    MS_LOG(INFO) << "Link control arrow by skipped node: " << skipped_node->fullname_with_scope()
+                 << ", from actor: " << from_actor->GetAID().Name() << ", to actor: " << to_actor->GetAID().Name();
+    from_actor->output_control_arrows_.emplace_back(to_aid);
+    to_actor->input_controls_num_++;
+  }
+}
+
 void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set,
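Note: the auto-monad rewrite generalizes the traversal from a single real_depend_input to a vector, because an UpdateState node can carry several real inputs (from kUpdateStateRealInput onward) and kPrimDepend is now followed through its attached node as well; each collected input is either recursed into (monad and make-tuple nodes, via the new AnfAlgo::IsOneOfPrimitiveCNode) or linked with a control arrow (kernel nodes). A standalone skeleton of that shape (Node, the index positions, and the link callback are illustrative; MakeTuple expansion is elided):

    #include <cstddef>
    #include <functional>
    #include <string>
    #include <unordered_set>
    #include <vector>

    struct Node { std::string kind; std::vector<Node *> inputs; };

    void LinkByAutoMonad(Node *to, Node *from, const std::function<void(Node *, Node *)> &link) {
      static const std::unordered_set<std::string> recursion_kinds = {"Depend", "UpdateState", "Load", "MakeTuple"};
      // Gather all real dependencies instead of a single one.
      std::vector<Node *> real_depend_inputs;
      if (from->kind == "Depend") {
        real_depend_inputs.push_back(from->inputs[1]);  // the attached node (kDependAttachNodeIndex)
      } else if (from->kind == "UpdateState") {
        // UpdateState aggregates several real inputs; follow all of them.
        for (std::size_t i = 1; i < from->inputs.size(); ++i) {
          real_depend_inputs.push_back(from->inputs[i]);
        }
      } else if (from->kind == "Load") {
        real_depend_inputs.push_back(from->inputs[0]);  // kLoadStateInput
      }
      for (auto *input : real_depend_inputs) {
        if (recursion_kinds.count(input->kind) != 0) {
          LinkByAutoMonad(to, input, link);  // monad/make-tuple: keep unwrapping
        } else {
          link(input, to);  // kernel node: add the control arrow
        }
      }
    }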
@@ -22,6 +22,7 @@
 #include <memory>
 #include <utility>
 #include <unordered_map>
+#include <unordered_set>
 #include <map>
 #include <algorithm>
 #include <fstream>
@@ -190,6 +191,8 @@ class GraphScheduler {
   void LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set,
                                          GraphExecutionStrategy strategy);
   void LinkControlArrowByAutoMonad(KernelActor *to_actor, const AnfNodePtr &from_node);
+  // The skipped node doesn't run, so need link the control arrow between the inputs and user of skipped node.
+  void LinkControlArrowBySkippedNode(KernelActor *to_actor, const AnfNodePtr &skipped_node);
   void LinkOutputResultArrowForOutputActor(OutputActor *to_actor, const GraphCompilerInfo &graph_compiler_info);
   void LinkDeviceTensorStoreForAutoMonadActor(const std::vector<KernelActor *> &auto_monad_actors);
@@ -33,13 +33,13 @@ class HostTensorQueue {
   HostTensorQueue() = default;
   virtual ~HostTensorQueue() = default;
 
-  void PushData(const std::vector<TensorPtr> &tensors) { buffers_.push(tensors); }
+  void Push(const std::vector<TensorPtr> &tensors) { buffers_.push(tensors); }
 
-  const std::vector<TensorPtr> &PullData() { return buffers_.front(); }
+  const std::vector<TensorPtr> &Pull() { return buffers_.front(); }
 
   bool IsEmpty() const { return buffers_.empty(); }
 
-  void PopData() { buffers_.pop(); }
+  void Pop() { buffers_.pop(); }
 
  private:
   std::queue<std::vector<TensorPtr>> buffers_;