Actor runtime: support the inplace optimizer

limingqi107 2021-05-26 20:21:27 +08:00
parent a2dc98f972
commit f619e85647
10 changed files with 153 additions and 58 deletions

View File

@@ -18,7 +18,6 @@
#include <algorithm>
#include <map>
#include <set>
#include <unordered_set>
#include <functional>
#include <numeric>
#include "ir/anf.h"
@@ -50,8 +49,6 @@ constexpr size_t kNopNodeInputSize = 2;
constexpr size_t kNopNodeRealInputIndex = 1;
constexpr size_t kReturnDataIndex = 1;
using PrimitiveSet = std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual>;
const PrimitiveSet follow_first_input_prims = {prim::kPrimDepend, prim::kPrimLoad};
bool IsShapeDynamic(const abstract::ShapePtr &shape) {
@@ -68,15 +65,6 @@ bool IsOneOfPrimitive(const AnfNodePtr &node, const PrimitiveSet &prim_set) {
return (prim && prim_set.find(prim) != prim_set.end());
}
bool IsOneOfPrimitiveCNode(const AnfNodePtr &node, const PrimitiveSet &prim_set) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr || cnode->size() == 0) {
return false;
}
return IsOneOfPrimitive(cnode->inputs().at(kAnfPrimitiveIndex), prim_set);
}
bool IsRealKernelCNode(const CNodePtr &cnode) {
static const PrimitiveSet virtual_prims = {
prim::kPrimImageSummary, prim::kPrimScalarSummary, prim::kPrimTensorSummary, prim::kPrimHistogramSummary,
@@ -284,7 +272,7 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimUpdateState)) {
return VisitKernelWithReturnType(cnode->input(kUpdateStateStateInput), index, visit_nop_node, return_types);
}
if (IsOneOfPrimitiveCNode(cnode, follow_first_input_prims)) {
if (AnfAlgo::IsOneOfPrimitiveCNode(cnode, follow_first_input_prims)) {
return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), index, visit_nop_node, return_types);
}
if (opt::IsNopNode(cnode) && visit_nop_node) {
@@ -2087,5 +2075,13 @@ bool AnfRuntimeAlgorithm::IsTensorBroadcast(const std::vector<size_t> &lhs, cons
return false;
}
bool AnfRuntimeAlgorithm::IsOneOfPrimitiveCNode(const AnfNodePtr &node, const PrimitiveSet &prim_set) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr || cnode->size() == 0) {
return false;
}
return IsOneOfPrimitive(cnode->inputs().at(kAnfPrimitiveIndex), prim_set);
}
} // namespace session
} // namespace mindspore
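For background on the PrimitiveSet alias this change promotes into the header: it is an unordered_set keyed by PrimitivePtr with custom hash and equality functors, so membership tests match primitives by value rather than by pointer identity. A minimal std-only sketch of the same pattern (Prim and the names below are hypothetical stand-ins, not MindSpore types):

#include <iostream>
#include <memory>
#include <string>
#include <unordered_set>

struct Prim { std::string name; };
using PrimPtr = std::shared_ptr<Prim>;

struct PrimHasher {
  std::size_t operator()(const PrimPtr &prim) const { return std::hash<std::string>()(prim->name); }
};
struct PrimEqual {
  bool operator()(const PrimPtr &lhs, const PrimPtr &rhs) const { return lhs->name == rhs->name; }
};
using PrimSet = std::unordered_set<PrimPtr, PrimHasher, PrimEqual>;

int main() {
  // Mirrors follow_first_input_prims: nodes of these kinds forward the visit to their first real input.
  PrimSet follow_first_input = {std::make_shared<Prim>(Prim{"Depend"}), std::make_shared<Prim>(Prim{"Load"})};
  auto load = std::make_shared<Prim>(Prim{"Load"});  // a distinct pointer to an equal primitive
  // Found by value-based hashing and equality even though the pointers differ.
  std::cout << std::boolalpha << (follow_first_input.find(load) != follow_first_input.end()) << std::endl;
}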

View File

@@ -23,6 +23,7 @@
#include <tuple>
#include <utility>
#include <memory>
#include <unordered_set>
#include "ir/anf.h"
#include "ir/dtype.h"
#include "base/base.h"
@@ -37,6 +38,7 @@
namespace mindspore {
namespace session {
using PrimitiveSet = std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual>;
using AnfVisitFuncion = std::function<Any(const AnfNodePtr &node, int index)>;
using DeviceAddress = device::DeviceAddress;
using DeviceAddressPtr = device::DeviceAddressPtr;
@@ -293,6 +295,7 @@ class AnfRuntimeAlgorithm {
}
return result;
}
static bool IsOneOfPrimitiveCNode(const AnfNodePtr &node, const PrimitiveSet &prim_set);
};
} // namespace session
using AnfAlgo = session::AnfRuntimeAlgorithm;

View File

@@ -50,6 +50,8 @@
#include "utils/info.h"
#include "load_mindir/load_model.h"
#include "pipeline/jit/prim_bprop_optimizer.h"
#include "mindrt/src/actor/actormgr.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "ps/constants.h"
#include "ps/util.h"
@@ -1199,6 +1201,7 @@ void ClearResAtexit() {
parse::Parser::CleanParserResource();
parse::CleanDataClassToClassMap();
trace::ClearTraceStack();
ActorMgr::GetActorMgrRef()->TerminateAll();
}
} // namespace pipeline
} // namespace mindspore
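Why ClearResAtexit now calls ActorMgr::GetActorMgrRef()->TerminateAll(): the actor runtime owns worker threads, and a std::thread that is still joinable when it is destroyed during static destruction calls std::terminate. A minimal std-only analogue of the join-before-exit guarantee (ToyActorMgr and its methods are hypothetical; only the shutdown idea carries over from MindRT's ActorMgr):

#include <atomic>
#include <iostream>
#include <thread>
#include <vector>

class ToyActorMgr {
 public:
  static ToyActorMgr &GetRef() {
    static ToyActorMgr mgr;
    return mgr;
  }
  void Spawn() {
    workers_.emplace_back([this]() { while (running_) { std::this_thread::yield(); } });
  }
  void TerminateAll() {
    running_ = false;
    for (auto &worker : workers_) {
      if (worker.joinable()) worker.join();
    }
    workers_.clear();
  }

 private:
  std::atomic<bool> running_{true};
  std::vector<std::thread> workers_;
};

int main() {
  ToyActorMgr::GetRef().Spawn();
  // Without this join, the joinable std::thread would abort the process during static destruction.
  ToyActorMgr::GetRef().TerminateAll();
  std::cout << "clean exit" << std::endl;
}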

View File

@@ -87,6 +87,14 @@ bool IsKernelActor(const AnfNodePtr &node) {
return false;
}
bool IsSkippedKernelActor(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (IsKernelActor(node) && AnfAlgo::IsInplaceNode(node, "skip")) {
return true;
}
return false;
}
bool IsPersistentDeviceTensor(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<ValueNode>()) {
@@ -97,6 +105,5 @@ bool IsPersistentDeviceTensor(const AnfNodePtr &node) {
}
return false;
}
} // namespace runtime
} // namespace mindspore

View File

@@ -51,6 +51,8 @@ bool IsDeviceQueueDSActor(const AnfNodePtr &node);
bool IsHostQueueDSActor(const AnfNodePtr &node, const KernelGraphPtr &graph = nullptr,
const TensorPtr &tensor = nullptr);
bool IsKernelActor(const AnfNodePtr &node);
// A skipped kernel doesn't run; it exists only as part of the inplace optimizer.
bool IsSkippedKernelActor(const AnfNodePtr &node);
// An internal parameter is not an original parameter of the func graph; it is the output of a previous kernel graph
// that is related to the input of this kernel graph.

View File

@@ -185,7 +185,7 @@ void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *cont
if (host_queue_->IsEmpty()) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "Host data queue is empty.");
}
auto &host_tensors = host_queue_->PullData();
auto &host_tensors = host_queue_->Pull();
auto &device_tensors = buffers_.back();
if (host_tensors.size() != device_tensors.size()) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context),
@@ -204,7 +204,7 @@ void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *cont
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "SyncHostToDevice failed.");
}
}
host_queue_->PopData();
host_queue_->Pop();
// Note that SendMemoryFreeReq must come before SendOutput, because SendOutput triggers SendMemoryAllocReq of
// the next actor and the actors execute asynchronously. So it is necessary to ensure that SendMemoryFreeReq of

View File

@@ -187,6 +187,50 @@ void CreateKernelWorkspaceDeviceAddress(const DeviceContext *device_context, con
}
}
}
void UpdateDeviceAddressForInplaceNode(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
// Collect the inplace groups.
std::map<uint32_t, std::vector<CNodePtr>> inplace_groups;
const std::vector<CNodePtr> &kernels = graph->execution_order();
for (const auto &kernel : kernels) {
if (!AnfAlgo::IsInplaceNode(kernel, "inplace_algo")) {
continue;
}
auto primitive = AnfAlgo::GetCNodePrimitive(kernel);
MS_EXCEPTION_IF_NULL(primitive);
auto inplace_group_attr = primitive->GetAttr("inplace_group");
MS_EXCEPTION_IF_NULL(inplace_group_attr);
auto group_id = GetValue<uint32_t>(inplace_group_attr);
inplace_groups[group_id].emplace_back(kernel);
}
const size_t kMinInplaceGroupSize = 2;
for (const auto &inplace_group : inplace_groups) {
auto &group_nodes = inplace_group.second;
if (group_nodes.size() < kMinInplaceGroupSize) {
continue;
}
// Get the device address of the first node in the inplace group.
auto node_primitive = AnfAlgo::GetCNodePrimitive(group_nodes[0]);
MS_EXCEPTION_IF_NULL(node_primitive);
auto output_index = GetValue<uint32_t>(node_primitive->GetAttr("inplace_output_index"));
auto device_address = AnfAlgo::GetMutableOutputAddr(group_nodes[0], output_index, false);
MS_EXCEPTION_IF_NULL(device_address);
// Update the device addresses of the other nodes using the device address of the first node in the inplace group.
for (size_t i = 1; i < group_nodes.size(); ++i) {
auto &group_node = group_nodes[i];
auto prim = AnfAlgo::GetCNodePrimitive(group_node);
MS_EXCEPTION_IF_NULL(prim);
auto index = GetValue<uint32_t>(prim->GetAttr("inplace_output_index"));
AnfAlgo::SetOutputAddr(device_address, index, group_node.get());
// Update the reference count of the device address.
device_address->IncreaseOriginalRefCount();
device_address->ResetRefCount();
}
}
}
} // namespace
void GraphCompiler::set_device_context(DeviceContext *device_context) {
@@ -285,6 +329,7 @@ void GraphCompiler::CreateDeviceAddress(const KernelGraphPtr &graph) const {
CreateValueNodeDeviceAddress(device_context_, graph);
CreateKernelOutputDeviceAddress(device_context_, graph);
CreateKernelWorkspaceDeviceAddress(device_context_, graph);
UpdateDeviceAddressForInplaceNode(graph);
}
void GraphCompiler::GetParamAndOutputIndex(

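A minimal std-only model of what UpdateDeviceAddressForInplaceNode achieves, assuming a hypothetical two-kernel group tagged with the same "inplace_group" id: every member after the first reuses the first member's output device address and bumps its reference count, so the whole group writes into one shared buffer (the kernel names and struct fields here are illustrative stand-ins, not MindSpore types):

#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

struct DeviceAddress {
  std::string owner;       // which kernel originally allocated the buffer
  int original_ref_count;  // toy stand-in for IncreaseOriginalRefCount()
};

struct Kernel {
  std::string name;
  uint32_t inplace_group;  // toy stand-in for the "inplace_group" attribute
  std::shared_ptr<DeviceAddress> output_addr;
};

int main() {
  std::vector<Kernel> kernels = {{"Conv2D_a", 1, std::make_shared<DeviceAddress>(DeviceAddress{"Conv2D_a", 1})},
                                 {"Conv2D_b", 1, std::make_shared<DeviceAddress>(DeviceAddress{"Conv2D_b", 1})}};
  // Collect the inplace groups, as the pass above does from the kernel attributes.
  std::map<uint32_t, std::vector<Kernel *>> inplace_groups;
  for (auto &kernel : kernels) {
    inplace_groups[kernel.inplace_group].push_back(&kernel);
  }
  for (auto &group_entry : inplace_groups) {
    auto &group_nodes = group_entry.second;
    if (group_nodes.size() < 2) continue;  // kMinInplaceGroupSize
    auto shared_addr = group_nodes.front()->output_addr;  // the first node's address wins
    for (size_t i = 1; i < group_nodes.size(); ++i) {
      group_nodes[i]->output_addr = shared_addr;  // share instead of allocating a new buffer
      shared_addr->original_ref_count++;          // keep the buffer alive for every user
    }
  }
  for (const auto &kernel : kernels) {
    std::cout << kernel.name << " writes into the buffer of " << kernel.output_addr->owner << std::endl;
  }
}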
View File

@@ -465,7 +465,7 @@ void GraphScheduler::PrepareRun(const ActorSet *actor_set, const GraphCompilerIn
if (host_data_source_actor != nullptr) {
const auto &host_tensor_queue = FetchHostQueue(actor_set->name_);
MS_EXCEPTION_IF_NULL(host_tensor_queue);
host_tensor_queue->PushData(host_tensors);
host_tensor_queue->Push(host_tensors);
}
}
@@ -590,7 +590,7 @@ void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_co
MS_EXCEPTION_IF_NULL(graph);
auto execution_order = graph->execution_order();
for (auto &kernel : execution_order) {
if (!IsKernelActor(kernel)) {
if (IsSkippedKernelActor(kernel) || (!IsKernelActor(kernel))) {
continue;
}
const auto &kernel_actor = dynamic_cast<KernelActor *>(FetchActor(kernel->fullname_with_scope()));
@@ -717,7 +717,7 @@ std::vector<KernelActorPtr> GraphScheduler::BuildKernelActor(const GraphCompiler
MS_EXCEPTION_IF_NULL(graph);
auto execution_order = graph->execution_order();
for (auto &kernel : execution_order) {
if (IsKernelActor(kernel)) {
if (IsKernelActor(kernel) && (!IsSkippedKernelActor(kernel))) {
auto kernel_actor =
std::make_shared<KernelActor>(kernel->fullname_with_scope(), kernel, device_context, memory_manager_aid_);
MS_EXCEPTION_IF_NULL(kernel_actor);
@@ -786,13 +786,7 @@ void GraphScheduler::LinkDataArrow(KernelActor *to_actor, const ActorSet *actor_
std::string actor_name = actor_set->name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id());
const auto &from_actor = dynamic_cast<DeviceQueueDataSourceActor *>(FetchActor(actor_name));
LinkDataArrowForDeviceDSActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
} else if (IsHostQueueDSActor(from_kernel, graph)) {
bool tensor_has_device_address =
tensor != nullptr && (std::dynamic_pointer_cast<DeviceTensor>(tensor->device_address()) != nullptr);
if (tensor_has_device_address) {
return;
}
} else if (IsHostQueueDSActor(from_kernel, graph, tensor)) {
// Link the data arrows of host queue data source actor.
std::string actor_name = actor_set->name_ + "_HostDSActor";
const auto &from_actor = dynamic_cast<HostQueueDataSourceActor *>(FetchActor(actor_name));
@@ -808,7 +802,8 @@
const auto devcie_tensor_store_key = FetchFrontNodeByBackendNode(from_kernel, graph);
to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, devcie_tensor_store_key.get());
} else {
MS_LOG(EXCEPTION) << "Invalid from kernel: " << from_kernel->fullname_with_scope();
// In PyNative mode, there may be a from kernel that doesn't need to be linked.
MS_LOG(DEBUG) << "Invalid from kernel: " << from_kernel->fullname_with_scope();
}
}
@@ -910,9 +905,24 @@ void GraphScheduler::LinkDataArrowForHostDSActor(HostQueueDataSourceActor *from_
void GraphScheduler::LinkDataArrowForKernelActor(KernelActor *from_actor, KernelActor *to_actor,
KernelWithIndex from_kernel_with_output_idx,
KernelWithIndex to_kernel_with_input_idx) {
MS_EXCEPTION_IF_NULL(from_actor);
MS_EXCEPTION_IF_NULL(to_actor);
if (IsSkippedKernelActor(from_kernel_with_output_idx.first)) {
auto real_kernel_with_index = AnfAlgo::GetPrevNodeOutput(from_kernel_with_output_idx.first, 0);
MS_EXCEPTION_IF_NULL(real_kernel_with_index.first);
LinkControlArrowBySkippedNode(to_actor, from_kernel_with_output_idx.first);
// Update the from kernel info with the real node info.
MS_LOG(INFO) << "Link data arrow for inplace node, aggregate node: "
<< to_kernel_with_input_idx.first->fullname_with_scope()
<< ", aggregate input index: " << to_kernel_with_input_idx.second
<< ", skip node: " << from_kernel_with_output_idx.first->fullname_with_scope()
<< ", real node: " << real_kernel_with_index.first->fullname_with_scope();
from_kernel_with_output_idx.first = real_kernel_with_index.first;
from_kernel_with_output_idx.second = real_kernel_with_index.second;
from_actor = dynamic_cast<KernelActor *>(FetchActor(from_kernel_with_output_idx.first->fullname_with_scope()));
}
MS_EXCEPTION_IF_NULL(from_actor);
auto from_kernel = from_kernel_with_output_idx.first;
MS_EXCEPTION_IF_NULL(from_kernel);
auto from_output_index = from_kernel_with_output_idx.second;
@@ -1029,7 +1039,8 @@ void GraphScheduler::LinkControlArrowByAutoMonad(KernelActor *to_actor, const An
return;
}
// Find the real input node, including the monad node and the make tuple node.
const std::vector<PrimitivePtr> &return_types = {prim::kPrimUpdateState, prim::kPrimLoad, prim::kPrimMakeTuple};
const std::vector<PrimitivePtr> return_types = {prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad,
prim::kPrimMakeTuple};
const auto &input_kernel_with_output_idx = AnfAlgo::VisitKernelWithReturnType(from_node, 0, false, return_types);
MS_EXCEPTION_IF_NULL(input_kernel_with_output_idx.first);
if (!input_kernel_with_output_idx.first->isa<CNode>()) {
@@ -1037,40 +1048,65 @@
}
const auto &input_cnode = input_kernel_with_output_idx.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode);
// Get the real depend input by monad node which needs to link the control arrow.
AnfNodePtr real_depend_input = nullptr;
if (AnfAlgo::CheckPrimitiveType(input_cnode, prim::kPrimUpdateState)) {
real_depend_input = input_cnode->input(kUpdateStateRealInput);
} else if (AnfAlgo::CheckPrimitiveType(input_cnode, prim::kPrimLoad)) {
real_depend_input = input_cnode->input(kLoadStateInput);
} else if (AnfAlgo::CheckPrimitiveType(input_cnode, prim::kPrimMakeTuple)) {
// Make tuple node needs to be expanded.
// Make tuple node needs to be expanded.
if (AnfAlgo::CheckPrimitiveType(input_cnode, prim::kPrimMakeTuple)) {
for (size_t i = 1; i < input_cnode->inputs().size(); ++i) {
LinkControlArrowByAutoMonad(to_actor, input_cnode->input(i));
}
return;
} else {
return;
}
MS_EXCEPTION_IF_NULL(real_depend_input);
if (!real_depend_input->isa<CNode>()) {
return;
}
// The monad node and make tuple node need recursion.
if (AnfAlgo::CheckPrimitiveType(real_depend_input, prim::kPrimUpdateState) ||
AnfAlgo::CheckPrimitiveType(real_depend_input, prim::kPrimLoad) ||
AnfAlgo::CheckPrimitiveType(real_depend_input, prim::kPrimMakeTuple)) {
LinkControlArrowByAutoMonad(to_actor, real_depend_input);
return;
// Get the real depend inputs of the monad node that need control arrow links.
std::vector<AnfNodePtr> real_depend_inputs;
if (AnfAlgo::CheckPrimitiveType(input_cnode, prim::kPrimDepend)) {
real_depend_inputs.push_back(input_cnode->input(kDependAttachNodeIndex));
} else if (AnfAlgo::CheckPrimitiveType(input_cnode, prim::kPrimUpdateState)) {
for (size_t i = kUpdateStateRealInput; i < input_cnode->inputs().size(); ++i) {
real_depend_inputs.push_back(input_cnode->input(i));
}
} else if (AnfAlgo::CheckPrimitiveType(input_cnode, prim::kPrimLoad)) {
real_depend_inputs.push_back(input_cnode->input(kLoadStateInput));
}
// Link the control arrow between the kernel actors.
const auto &from_actor = dynamic_cast<KernelActor *>(FetchActor(real_depend_input->fullname_with_scope()));
MS_EXCEPTION_IF_NULL(from_actor);
from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());
to_actor->input_controls_num_++;
const std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> recursion_prims = {
prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad, prim::kPrimMakeTuple};
for (const auto &real_depend_input : real_depend_inputs) {
// Monad nodes and make tuple nodes need recursion.
if (AnfAlgo::IsOneOfPrimitiveCNode(real_depend_input, recursion_prims)) {
LinkControlArrowByAutoMonad(to_actor, real_depend_input);
continue;
}
if (!IsKernelActor(real_depend_input)) {
continue;
}
// Link the control arrow between the kernel actors.
const auto &from_actor = dynamic_cast<KernelActor *>(FetchActor(real_depend_input->fullname_with_scope()));
MS_EXCEPTION_IF_NULL(from_actor);
MS_LOG(INFO) << "Link control arrow by auto monad, from actor: " << from_actor->GetAID().Name()
<< ", to actor: " << to_actor->GetAID().Name();
from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());
to_actor->input_controls_num_++;
}
}
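
A simplified std-only model of the collection scheme above (the node kinds and graph are hypothetical, and unlike the real pass it recurses over all inputs of a transparent node instead of the per-kind real-input indices): starting from a monad node, it gathers the real kernels that must receive control arrow links to the consuming actor.

#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <vector>

struct Node {
  std::string kind;  // "Kernel", "Depend", "UpdateState", "Load" or "MakeTuple"
  std::string name;
  std::vector<std::shared_ptr<Node>> inputs;
};
using NodePtr = std::shared_ptr<Node>;

// Collect the real kernels reachable through transparent nodes; they are the
// producers that need a control arrow to the consuming actor.
void CollectRealDepends(const NodePtr &node, std::set<std::string> *real_depends) {
  static const std::set<std::string> kRecursionKinds = {"Depend", "UpdateState", "Load", "MakeTuple"};
  if (kRecursionKinds.count(node->kind) == 0) {
    if (node->kind == "Kernel") real_depends->insert(node->name);  // IsKernelActor analogue
    return;
  }
  for (const auto &input : node->inputs) {
    CollectRealDepends(input, real_depends);
  }
}

int main() {
  auto assign_a = std::make_shared<Node>(Node{"Kernel", "AssignAdd_a", {}});
  auto assign_b = std::make_shared<Node>(Node{"Kernel", "AssignAdd_b", {}});
  auto load = std::make_shared<Node>(Node{"Load", "Load_0", {assign_a}});
  auto update_state = std::make_shared<Node>(Node{"UpdateState", "UpdateState_0", {load, assign_b}});
  std::set<std::string> real_depends;
  CollectRealDepends(update_state, &real_depends);
  for (const auto &name : real_depends) std::cout << name << std::endl;  // prints both kernels
}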
void GraphScheduler::LinkControlArrowBySkippedNode(KernelActor *to_actor, const AnfNodePtr &skipped_node) {
MS_EXCEPTION_IF_NULL(to_actor);
MS_EXCEPTION_IF_NULL(skipped_node);
auto to_aid = to_actor->GetAID();
// Link the control arrows from all the inputs of the skipped node to the user of the skipped node.
auto input_num = AnfAlgo::GetInputTensorNum(skipped_node);
for (size_t i = 0; i < input_num; ++i) {
auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(skipped_node, i, false);
MS_EXCEPTION_IF_NULL(kernel_with_index.first);
auto from_actor = dynamic_cast<KernelActor *>(FetchActor(kernel_with_index.first->fullname_with_scope()));
MS_EXCEPTION_IF_NULL(from_actor);
MS_LOG(INFO) << "Link control arrow by skipped node: " << skipped_node->fullname_with_scope()
<< ", from actor: " << from_actor->GetAID().Name() << ", to actor: " << to_actor->GetAID().Name();
from_actor->output_control_arrows_.emplace_back(to_aid);
to_actor->input_controls_num_++;
}
}
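
How the skipped-node pieces compose: when the producer side of a data arrow is a skipped inplace kernel, LinkDataArrowForKernelActor re-points the arrow at the skipped kernel's real producer, and LinkControlArrowBySkippedNode preserves the ordering the skipped kernel used to enforce. A std-only sketch of the rewiring on a hypothetical A -> B(skipped) -> C chain:

#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

struct Node {
  std::vector<std::string> inputs;
  bool skipped;
};

int main() {
  // B is the skipped inplace kernel between producer A and consumer C.
  std::map<std::string, Node> graph = {{"A", {{}, false}}, {"B", {{"A"}, true}}, {"C", {{"B"}, false}}};
  std::vector<std::pair<std::string, std::string>> data_arrows;
  std::vector<std::pair<std::string, std::string>> control_arrows;
  for (const auto &entry : graph) {
    if (entry.second.skipped) continue;  // a skipped kernel gets no actor of its own
    const std::string &to = entry.first;
    for (const auto &input : entry.second.inputs) {
      const Node &from = graph.at(input);
      if (!from.skipped) {
        data_arrows.emplace_back(input, to);
        continue;
      }
      // Data bypasses the skipped node to its real (first) input...
      data_arrows.emplace_back(from.inputs.front(), to);
      // ...and every input of the skipped node gets a control arrow to the
      // consumer, as LinkControlArrowBySkippedNode does.
      for (const auto &real : from.inputs) control_arrows.emplace_back(real, to);
    }
  }
  for (const auto &arrow : data_arrows) std::cout << "data: " << arrow.first << " -> " << arrow.second << std::endl;
  for (const auto &arrow : control_arrows) std::cout << "ctrl: " << arrow.first << " -> " << arrow.second << std::endl;
}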
void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set,

View File

@@ -22,6 +22,7 @@
#include <memory>
#include <utility>
#include <unordered_map>
#include <unordered_set>
#include <map>
#include <algorithm>
#include <fstream>
@@ -190,6 +191,8 @@ class GraphScheduler {
void LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set,
GraphExecutionStrategy strategy);
void LinkControlArrowByAutoMonad(KernelActor *to_actor, const AnfNodePtr &from_node);
// The skipped node doesn't run, so control arrows must be linked between the inputs and the user of the skipped node.
void LinkControlArrowBySkippedNode(KernelActor *to_actor, const AnfNodePtr &skipped_node);
void LinkOutputResultArrowForOutputActor(OutputActor *to_actor, const GraphCompilerInfo &graph_compiler_info);
void LinkDeviceTensorStoreForAutoMonadActor(const std::vector<KernelActor *> &auto_monad_actors);

View File

@@ -33,13 +33,13 @@ class HostTensorQueue {
HostTensorQueue() = default;
virtual ~HostTensorQueue() = default;
void PushData(const std::vector<TensorPtr> &tensors) { buffers_.push(tensors); }
void Push(const std::vector<TensorPtr> &tensors) { buffers_.push(tensors); }
const std::vector<TensorPtr> &PullData() { return buffers_.front(); }
const std::vector<TensorPtr> &Pull() { return buffers_.front(); }
bool IsEmpty() const { return buffers_.empty(); }
void PopData() { buffers_.pop(); }
void Pop() { buffers_.pop(); }
private:
std::queue<std::vector<TensorPtr>> buffers_;
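
The renamed queue methods now read like the std::queue they wrap. A runnable usage sketch with the tensor type stubbed out (the Tensor stand-in is hypothetical; only the Push/Pull/Pop call sequence mirrors the actors above):

#include <cassert>
#include <memory>
#include <queue>
#include <vector>

struct Tensor {};  // stand-in for the real tensor type
using TensorPtr = std::shared_ptr<Tensor>;

// Trimmed copy of the class above, just for a self-contained demonstration.
class HostTensorQueue {
 public:
  void Push(const std::vector<TensorPtr> &tensors) { buffers_.push(tensors); }
  const std::vector<TensorPtr> &Pull() { return buffers_.front(); }
  bool IsEmpty() const { return buffers_.empty(); }
  void Pop() { buffers_.pop(); }

 private:
  std::queue<std::vector<TensorPtr>> buffers_;
};

int main() {
  HostTensorQueue queue;
  queue.Push({std::make_shared<Tensor>(), std::make_shared<Tensor>()});  // PrepareRun side
  assert(!queue.IsEmpty());
  const auto &host_tensors = queue.Pull();  // host queue data source actor side
  assert(host_tensors.size() == 2);
  queue.Pop();  // release the batch once it has been synced to device
  assert(queue.IsEmpty());
  return 0;
}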