!28655 unified runtime supports the ref node in the control flow

Merge pull request !28655 from limingqi107/bug_fix4
This commit is contained in:
i-robot 2022-01-07 09:29:20 +00:00 committed by Gitee
commit 1b08f35ef1
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 133 additions and 2 deletions

View File

@ -183,6 +183,26 @@ void DumpCopyActor(const CopyActor *actor, std::ofstream &ofs) {
ofs << "\n";
}
// Write the ref formal parameter device tensors held by a control actor to the dump stream. Nothing is written when
// the actor has none. Each device tensor produces one line holding the formal parameter position plus the full name
// and the debug string of the node returned by its GetNodeIndex().
void DumpFormalParameterDeviceTensor(const ControlActor *actor, std::ofstream &ofs) {
  MS_EXCEPTION_IF_NULL(actor);
  const auto &ref_device_tensors = actor->ref_formal_parameter_device_tensors();
  if (ref_device_tensors.empty()) {
    return;
  }

  // Header line: the number of formal parameter positions that own ref device tensors.
  ofs << "\t\tref_formal_parameter_device_tensors:" << ref_device_tensors.size() << "\n ";
  for (auto iter = ref_device_tensors.begin(); iter != ref_device_tensors.end(); ++iter) {
    const auto &tensors_of_position = iter->second;
    for (const auto &tensor : tensors_of_position) {
      MS_EXCEPTION_IF_NULL(tensor);
      auto node_with_index = tensor->GetNodeIndex();
      MS_EXCEPTION_IF_NULL(node_with_index.first);
      ofs << "\t\t\tref_position:" << iter->first
          << "\tref_node_name:" << node_with_index.first->fullname_with_scope()
          << "\tref_node_debug_name:" << node_with_index.first->DebugString() << "\n";
    }
  }
}
void DumpControlActor(const ControlActor *actor, std::ofstream &ofs) {
MS_EXCEPTION_IF_NULL(actor);
DumpAbstractActor(actor, ofs);
@ -229,6 +249,8 @@ void DumpControlActor(const ControlActor *actor, std::ofstream &ofs) {
ofs << "\t\t\tto_actor_name:" << aid.Name() << "\n";
}
}
DumpFormalParameterDeviceTensor(actor, ofs);
}
void DumpSwitchActor(const SwitchActor *actor, std::ofstream &ofs) {
@ -309,6 +331,14 @@ void DumpExitActor(const ExitActor *actor, std::ofstream &ofs) {
}
}
}
const auto &is_need_copy_device_tensors = actor->is_need_copy_device_tensors();
if (is_need_copy_device_tensors.size() > 0) {
ofs << "\t\twhether_need_copy_device_tensors:" << is_need_copy_device_tensors.size() << "\n ";
for (size_t i = 0; i < is_need_copy_device_tensors.size(); ++i) {
ofs << "\t\t\tdevice_tensor_position:" << i << "\tis_need_copy:" << is_need_copy_device_tensors[i] << "\n";
}
}
}
void DumpStackActor(const StackActor *actor, std::ofstream &ofs) {
@ -449,11 +479,11 @@ void DumpCopyActors(const std::vector<CopyActorPtr> &actors, std::ofstream &ofs)
}
void DumpControlActors(const ControlActorSetPtr &control_actor_set, std::ofstream &ofs) {
ofs << "\n\n[Control actors]\n";
if (control_actor_set == nullptr) {
return;
}
ofs << "\n\n[Control actors]\n";
DumpEntranceActors(control_actor_set->entrance_actors_, ofs);
DumpSwitchActors(control_actor_set->switch_actors_, ofs);
DumpGatherActors(control_actor_set->gather_actors_, ofs);

View File

@ -333,6 +333,54 @@ void ControlActor::EraseInput(const OpContext<DeviceTensor> *context) {
}
}
// Before an output is sent, propagate the address of the outgoing data to the device tensors registered for the same
// ref formal parameter position.
//   output_data: the data about to be sent; its device tensor (data_) supplies the address.
//   data_arrow:  the arrow the data is sent along; from_output_index_ is used as the formal parameter position.
//   (unnamed):   the output node, not used here.
//   context:     receives the error information when one of the checks below fails.
// Positions without registered ref device tensors are left untouched.
void ControlActor::UpdateOutputData(OpData<DeviceTensor> *const output_data, const DataArrowPtr &data_arrow,
                                    const AnfNodePtr &, OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(output_data);
  MS_EXCEPTION_IF_NULL(data_arrow);
  MS_EXCEPTION_IF_NULL(context);
  const auto &data = output_data->data_;
  MS_EXCEPTION_IF_NULL(data);

  const auto formal_parameter_position = data_arrow->from_output_index_;
  // A single lookup both answers "is this a ref formal parameter" and yields the device tensors to update.
  const auto iter = ref_formal_parameter_device_tensors_.find(formal_parameter_position);
  if (iter == ref_formal_parameter_device_tensors_.end()) {
    // This position is not a ref formal parameter.
    return;
  }

  if (data->GetMutablePtr() == nullptr) {
    std::string error_info =
      "The address of the " + std::to_string(formal_parameter_position) + " position formal parameter is nullptr.";
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
  }
  // The ref count of the incoming data is expected to be SIZE_MAX; any other value is reported as an error.
  if (data->ref_count() != SIZE_MAX) {
    std::string error_info = "The ref count of the " + std::to_string(formal_parameter_position) +
                             " position formal parameter is wrong:" + std::to_string(data->ref_count());
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
  }

  // Point every registered device tensor at the address carried by the data.
  for (const auto &device_tensor : iter->second) {
    MS_EXCEPTION_IF_NULL(device_tensor);
    // Nothing to do when it is the very same device tensor or it already holds the same address.
    if ((device_tensor.get() == data) || (device_tensor->GetMutablePtr() == data->GetMutablePtr())) {
      continue;
    }
    auto real_parameter = device_tensor->GetNodeIndex();
    MS_EXCEPTION_IF_NULL(real_parameter.first);
    // The address may only be shared when size, format and type id all match.
    if ((device_tensor->GetSize() != data->GetSize()) || (device_tensor->format() != data->format()) ||
        (device_tensor->type_id() != data->type_id())) {
      std::string error_info =
        "The address of the " + std::to_string(formal_parameter_position) +
        " position formal parameter can not be set to real parameter:" + real_parameter.first->DebugString();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    device_tensor->set_ptr(data->GetMutablePtr());
    MS_LOG(DEBUG) << "Set the ptr: " << data->GetMutablePtr()
                  << " for the ref real parameter: " << real_parameter.first->DebugString()
                  << " in the actor: " << GetAID().Name();
  }
}
void ControlActor::SendOutput(OpContext<DeviceTensor> *const context) {
// Send branch id.
for (const auto &branch_id_arrow : output_branch_id_arrows_) {

View File

@ -21,6 +21,7 @@
#include <string>
#include <memory>
#include <map>
#include <set>
#include <unordered_map>
#include <stack>
#include <queue>
@ -70,6 +71,9 @@ class ControlActor : public MemoryAwareActor {
const std::unordered_map<size_t, OpPartialPtr> &local_partials() const { return local_partials_; }
const std::vector<AID> &input_partial_arrow_aids() const { return input_partial_arrow_aids_; }
const std::vector<AID> &input_branch_id_arrow_aids() const { return input_branch_id_arrow_aids_; }
const std::map<size_t, std::set<DeviceTensorPtr>> &ref_formal_parameter_device_tensors() const {
return ref_formal_parameter_device_tensors_;
}
size_t branch_id() const { return output_branch_id_; }
protected:
@ -89,6 +93,8 @@ class ControlActor : public MemoryAwareActor {
virtual void FetchInput(OpContext<DeviceTensor> *const context);
void Run(OpContext<DeviceTensor> *const context) override;
bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const override;
void UpdateOutputData(OpData<DeviceTensor> *const output_data, const DataArrowPtr &data_arrow,
const AnfNodePtr &output_node, OpContext<DeviceTensor> *const context) override;
void SendOutput(OpContext<DeviceTensor> *const context) override;
void EraseInput(const OpContext<DeviceTensor> *context) override;
@ -144,6 +150,10 @@ class ControlActor : public MemoryAwareActor {
// Formal parameters for control actor.
std::vector<KernelWithIndex> formal_parameters_;
// The device tensors of backend input nodes corresponding to ref formal parameters; the key is the position index of
// the formal parameter. Used to update the ptr of the device tensors when the real parameters for ref nodes are
// received.
std::map<size_t, std::set<DeviceTensorPtr>> ref_formal_parameter_device_tensors_;
// local node for control actor, such as return node for exit actor, switch node for switch actor.
AnfNodePtr node_;
};

View File

@ -51,6 +51,7 @@ class ExitActor : public ControlActor {
const mindspore::HashMap<int, std::vector<DataArrowPtr>> &output_branch_partial_arrows() const {
return output_branch_partial_arrows_;
}
const std::vector<bool> &is_need_copy_device_tensors() const { return is_need_copy_device_tensors_; }
protected:
void FetchInput(OpContext<DeviceTensor> *const context) override;

View File

@ -90,6 +90,26 @@ bool IsControlFlowArrow(const ControlNodeParserPtr &parser, const KernelGraphPtr
(from_node != nullptr && IsPersistentDeviceTensor(from_node)) ||
(from_node != nullptr && parser->IsSameKernelGraphGroup(from_node, graph));
}
// Decide whether the device tensor of the given backend node output needs to be copied.
// Returns false when the node is not a CNode, when the node has a ref abstract, or when {backend_node, index} is
// recorded in the ref output map of its kernel graph; returns true otherwise.
bool is_need_copy_device_tensor(const AnfNodePtr &backend_node, size_t index) {
  MS_EXCEPTION_IF_NULL(backend_node);
  // Non-CNode nodes and ref nodes are never copied.
  if (!backend_node->isa<CNode>() || HasAbstractRef(backend_node)) {
    return false;
  }

  auto graph = FetchKernelGraph(backend_node);
  MS_EXCEPTION_IF_NULL(graph);
  // An output recorded in the ref output map must not be copied either.
  return !graph->IsInRefOutputMap({backend_node, index});
}
} // namespace
ControlActorSetPtr ControlNodeScheduler::Build(const GraphCompilerInfo &graph_compiler_info,
@ -274,7 +294,8 @@ std::vector<ExitActorPtr> ControlNodeScheduler::BuildExitActor(const GraphCompil
// Get the device contexts of the exit actor's cnode inputs.
const AnfNodePtr &backend_node = node_with_context.second.first.first;
MS_EXCEPTION_IF_NULL(backend_node);
is_need_copy_device_tensors.emplace_back(backend_node->isa<CNode>() ? true : false);
is_need_copy_device_tensors.emplace_back(
is_need_copy_device_tensor(backend_node, node_with_context.second.first.second));
device_contexts.emplace_back(node_with_context.second.second);
}
@ -1155,6 +1176,7 @@ void ControlNodeScheduler::LinkDataArrowByKernelGraph(const KernelGraphPtr &grap
to_index = super_kernel_actor->FetchInputNodePosition(input);
(void)sink_input_node_linked.insert(input);
}
AddFormalParameterDeviceTensor(from_actor, from_index, input);
LinkDataArrow(from_actor, to_actor, from_index, to_index);
}
}
@ -1231,6 +1253,22 @@ void ControlNodeScheduler::LinkArrowForRootGraphEntranceActor(const GraphCompile
}
}
// Register the device tensor of a backend input node under the formal parameter position `from_index` of
// `from_actor`. Only input nodes with a ref abstract are registered; other nodes are ignored.
void ControlNodeScheduler::AddFormalParameterDeviceTensor(ControlActor *const from_actor, size_t from_index,
                                                          const AnfNodePtr &input_node) {
  MS_EXCEPTION_IF_NULL(from_actor);
  MS_EXCEPTION_IF_NULL(input_node);
  const bool is_ref_input = HasAbstractRef(input_node);
  if (!is_ref_input) {
    return;
  }

  auto ref_address = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
  MS_EXCEPTION_IF_NULL(ref_address);
  // operator[] creates the set for this position on first use.
  auto &tensors_of_position = from_actor->ref_formal_parameter_device_tensors_[from_index];
  (void)tensors_of_position.insert(ref_address);
  UpdateRefCount(ref_address.get(), true);
  // Remember which node the device tensor belongs to.
  ref_address->SetNodeIndex(input_node, 0);
}
void ControlNodeScheduler::LinkDataArrow(AbstractActor *const from_actor, AbstractActor *const to_actor,
size_t from_index, size_t to_index, const AnfNodePtr &from_kernel) {
MS_EXCEPTION_IF_NULL(from_actor);

View File

@ -25,6 +25,7 @@
#include <map>
#include <set>
#include <algorithm>
#include <queue>
#include "runtime/framework/actor/actor_set.h"
#include "runtime/framework/graph_compiler.h"
@ -107,6 +108,9 @@ class ControlNodeScheduler {
size_t to_index, int branch_id);
bool IsNoInputActor(const ControlActor *control_actor);
// Fill the device tensors of backend input nodes corresponding to ref formal parameters.
void AddFormalParameterDeviceTensor(ControlActor *const from_actor, size_t from_index, const AnfNodePtr &input_node);
// The id of memory manager actor.
AID memory_manager_aid_;
};