Unified runtime supports mixed precision.

gaoyong10 2021-06-08 14:50:15 +08:00
parent f86d707126
commit bf6528645c
15 changed files with 172 additions and 63 deletions

View File

@@ -612,7 +612,12 @@ AnfNodePtr SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, Kernel
   if (!pre_graph_out.empty() && !AnfAlgo::IsRealKernel(node)) {
     pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem, prim::kPrimUpdateState});
   }
-  for (const auto &parameter : parameters) {
+  for (size_t i = 0; i < parameters.size(); ++i) {
+    const auto &parameter = parameters[i];
+    // In control flow, if the input of the cnode is a call node, it will be processed as a make_tuple input,
+    // which needs to be linked when processing the internal node.
+    graph->CacheInternalParameterToFrontNode(parameter, {node, i});
     auto valid_inputs = graph->MutableValidInputs();
     MS_EXCEPTION_IF_NULL(valid_inputs);
     auto graph_inputs = graph->MutableInputs();

View File

@@ -44,6 +44,8 @@ bool IsDeviceQueueDSActor(const AnfNodePtr &node) {
   return false;
 }
 
+bool IsSwitchActor(const AnfNodePtr &node) { return AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch); }
+
 bool IsHostQueueDSActor(const AnfNodePtr &node, const KernelGraphPtr &graph, const TensorPtr &tensor,
                         const std::vector<AnfNodePtr> &host_parameters) {
   MS_EXCEPTION_IF_NULL(node);

View File

@@ -59,6 +59,7 @@ bool IsDeviceQueueDSActor(const AnfNodePtr &node);
 bool IsHostQueueDSActor(const AnfNodePtr &node, const KernelGraphPtr &graph = nullptr,
                         const TensorPtr &tensor = nullptr, const std::vector<AnfNodePtr> &host_parameters = {});
 bool IsKernelActor(const AnfNodePtr &node);
+bool IsSwitchActor(const AnfNodePtr &node);
 // The skip kernel doesn't run, it exists in the inplace optimizer.
 bool IsSkippedKernelActor(const AnfNodePtr &node);

View File

@@ -58,8 +58,8 @@ void GatherActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTe
   if (CheckLaunchCondition(context)) {
     FetchInputDeviceTensor(context);
+    EraseInput(context);
     SendOutput(context);
-    input_op_datas_.erase(context->sequential_num_);
   }
 }
@@ -86,8 +86,8 @@ void GatherActor::FetchBackendInputNode(const FuncGraphPtr &func_graph,
   // 2. Output the kernel actor.
   for (const auto parameters : func_iter->second) {
     if (parameters.size() != graph_inputs.size()) {
-      MS_LOG(EXCEPTION) << "Parameters num is invalid, current:" << parameters.size()
-                        << " need:" << graph_inputs.size();
+      MS_LOG(EXCEPTION) << "Parameters num is invalid, current:" << parameters.size() << " need:" << graph_inputs.size()
+                        << " func_graph:" << func_iter->first->ToString();
     }
     for (size_t i = 0; i < parameters.size(); ++i) {
@@ -133,20 +133,11 @@ void GatherActor::FetchBackendInputNode(const FuncGraphPtr &func_graph,
 void GatherActor::SendOutput(OpContext<DeviceTensor> *context) const {
   MS_EXCEPTION_IF_NULL(context);
-  // Send output data.
-  for (auto &output_data : output_data_) {
-    MS_EXCEPTION_IF_NULL(output_data);
-    Async(output_data->op_id_, &OpActor::RunOpData, output_data, context);
-  }
-  // Send output control.
-  auto source_aid = const_cast<AID *>(&GetAID());
-  for (auto &output_control : output_control_arrows_) {
-    Async(output_control, &OpActor::RunOpControl, source_aid, context);
-  }
+  // Branch arrow and result arrow must be executed before the data arrow and control arrow, otherwise the output
+  // actor may receive the loop count message first and cause the output to be abnormal.
   if (branch_id_ > kInvalidBranchID) {
     Async(loop_count_aid_, &LoopCountActor::CollectBranchId, branch_id_, context);
+    Async(output_aid_, &OutputActor::CollectBranchId, branch_id_, context);
   }
 
   // Send graph output result.
@@ -160,9 +151,22 @@ void GatherActor::SendOutput(OpContext<DeviceTensor> *context) const {
           input_device_tensors_[from_index]) {
         Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, backend_node.first, backend_node.second,
               result_arrow->to_input_index_, context);
+        break;
       }
     }
   }
+
+  // Send output data.
+  for (auto &output_data : output_data_) {
+    MS_EXCEPTION_IF_NULL(output_data);
+    Async(output_data->op_id_, &OpActor::RunOpData, output_data, context);
+  }
+
+  // Send output control.
+  auto source_aid = const_cast<AID *>(&GetAID());
+  for (auto &output_control : output_control_arrows_) {
+    Async(output_control, &OpActor::RunOpControl, source_aid, context);
+  }
 }
 
 void GatherActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
@@ -209,5 +213,23 @@ bool GatherActor::CheckLaunchCondition(OpContext<DeviceTensor> *context) const {
   return true;
 }
 
+void GatherActor::EraseInput(OpContext<DeviceTensor> *context) {
+  MS_EXCEPTION_IF_NULL(context);
+  if (input_datas_num_ != 0) {
+    auto ret = input_op_datas_.erase(context->sequential_num_);
+    if (ret == 0) {
+      std::string error_info = "Erase input data failed: " + GetAID().Name();
+      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
+    }
+  }
+
+  if (input_controls_num_ != 0) {
+    auto ret = input_op_controls_.erase(context->sequential_num_);
+    if (ret == 0) {
+      std::string error_info = "Erase input controls failed: " + GetAID().Name();
+      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
+    }
+  }
+}
 }  // namespace runtime
 }  // namespace mindspore

View File

@@ -37,8 +37,9 @@ namespace runtime {
 // Gather actor is the entrance of sub funcgraph. Graph input is sent to it and sent to other actors by gather actor.
 class GatherActor : public OpActor<DeviceTensor> {
  public:
-  GatherActor(const std::string &name, const std::vector<AnfNodePtr> &parameters, const AID loop_count_aid)
-      : OpActor(name), data_nodes_(parameters), loop_count_aid_(loop_count_aid) {}
+  GatherActor(const std::string &name, const std::vector<AnfNodePtr> &parameters, const AID loop_count_aid,
+              const AID output_aid)
+      : OpActor(name), data_nodes_(parameters), loop_count_aid_(loop_count_aid), output_aid_(output_aid) {}
   ~GatherActor() override = default;
 
   // Get the index of the parameter, the data_node needs to be the front node.
@@ -61,6 +62,8 @@ class GatherActor : public OpActor<DeviceTensor> {
   // Check whether satisfy the condition for launch.
   bool CheckLaunchCondition(OpContext<DeviceTensor> *context) const;
   void SendOutput(OpContext<DeviceTensor> *context) const;
+  // Erase input data and input controls when finish gather launch.
+  void EraseInput(OpContext<DeviceTensor> *context);
 
   // The device tensors for launch.
   std::vector<DeviceTensor *> input_device_tensors_;
@@ -83,6 +86,7 @@ class GatherActor : public OpActor<DeviceTensor> {
   size_t input_controls_num_{0};
 
   const AID loop_count_aid_;
+  const AID output_aid_;
 
   // Cache unique output data by output index to modify the output data effectively.
   std::vector<std::vector<OpDataUniquePtr<DeviceTensor>>> output_data_by_output_index_;

View File

@@ -30,8 +30,8 @@ void LoopCountActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *c
   auto sequential_num = context->sequential_num_;
   input_op_controls_[sequential_num].emplace_back(input_control);
-  if (CheckExecuteCondition(context)) {
-    Execute(context);
+  if (CheckLoopCountIncreaseCondition(context)) {
+    IncreaseLoopCount(context);
   }
 }
@@ -39,8 +39,8 @@ void LoopCountActor::CollectBranchId(const int branch_id, OpContext<DeviceTensor
   MS_EXCEPTION_IF_NULL(context);
   branch_id_ = branch_id;
-  if (CheckExecuteCondition(context)) {
-    Execute(context);
+  if (CheckLoopCountIncreaseCondition(context)) {
+    IncreaseLoopCount(context);
   }
 }
@@ -53,7 +53,7 @@ void LoopCountActor::OnDebugFinish(OpContext<DeviceTensor> *context) {
   SendOutput(context);
 }
 
-void LoopCountActor::Execute(OpContext<DeviceTensor> *context) {
+void LoopCountActor::IncreaseLoopCount(OpContext<DeviceTensor> *context) {
   MS_EXCEPTION_IF_NULL(context);
   auto sequential_num = context->sequential_num_;
   auto ret = input_op_controls_.erase(sequential_num);
@@ -100,7 +100,7 @@ void LoopCountActor::SendOutput(OpContext<DeviceTensor> *context) {
   }
 }
 
-bool LoopCountActor::CheckExecuteCondition(OpContext<DeviceTensor> *context) {
+bool LoopCountActor::CheckLoopCountIncreaseCondition(OpContext<DeviceTensor> *context) {
   MS_EXCEPTION_IF_NULL(context);
   auto sequential_num = context->sequential_num_;
   if (branch_id_ == kInvalidBranchID) {
@@ -108,8 +108,7 @@ bool LoopCountActor::CheckExecuteCondition(OpContext<DeviceTensor> *context) {
   }
   if (branch_id_ >= SizeToInt(branch_id_to_input_controls_num_.size())) {
-    MS_LOG(ERROR) << "Branch id is invalid, id:" << branch_id_
-                  << " total branch num:" << branch_id_to_input_controls_num_.size();
+    MS_LOG(ERROR) << "Branch id is invalid, id:" << branch_id_;
   }
   return input_op_controls_[sequential_num].size() == branch_id_to_input_controls_num_[branch_id_];
 }

View File

@@ -60,10 +60,10 @@ class LoopCountActor : public DebugAwareActor {
  private:
   friend class GraphScheduler;
 
-  void Execute(OpContext<DeviceTensor> *context);
+  void IncreaseLoopCount(OpContext<DeviceTensor> *context);
   void SendOutput(OpContext<DeviceTensor> *context);
-  bool CheckExecuteCondition(OpContext<DeviceTensor> *context);
+  bool CheckLoopCountIncreaseCondition(OpContext<DeviceTensor> *context);
 
   // The loop count is constant, the current count is increased after each step running finished.
   size_t loop_count_;
   size_t current_count_;
@@ -87,7 +87,7 @@ class LoopCountActor : public DebugAwareActor {
   // When the result of the graph is sent to the output actor, the gather actor of the graph needs
   // to send branch_id to the output actor to determine the corresponding weight.
-  int branch_id_{kInvalidBranchID};
+  int branch_id_{kMainBranchID};
 };
 
 using LoopCountActorPtr = std::shared_ptr<LoopCountActor>;

View File

@@ -46,17 +46,21 @@ TensorPtr CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index,
 void OutputActor::CollectLoopCount(size_t loop_count, OpContext<DeviceTensor> *context) {
   MS_EXCEPTION_IF_NULL(context);
+  if (branch_id_ == kInvalidBranchID) {
+    MS_LOG(EXCEPTION) << "Invalid branch id for output actor.";
+  }
   current_count_ = loop_count;
   if (loop_count_ == current_count_) {
-    if (current_outputs_num_ + device_tensor_store_keys_.size() != outputs_num_) {
-      std::string error_info = "The outputs num is wrong, the total outputs num: " + std::to_string(outputs_num_) +
-                               ", the current outputs num: " + std::to_string(current_outputs_num_) +
-                               ", the device tensor store num: " + std::to_string(device_tensor_store_keys_.size());
+    if (current_outputs_num_ + device_tensor_store_keys_[branch_id_].size() != outputs_num_) {
+      std::string error_info =
+        "The outputs num is wrong, the total outputs num: " + std::to_string(outputs_num_) +
+        ", the current outputs num: " + std::to_string(current_outputs_num_) +
+        ", the device tensor store num: " + std::to_string(device_tensor_store_keys_[branch_id_].size());
       SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
     }
 
     // Because device tensor store can't send data, so fetch the output result of device tensor store in running end.
-    for (const auto &device_tensor_store_key : device_tensor_store_keys_) {
+    for (const auto &device_tensor_store_key : device_tensor_store_keys_[branch_id_]) {
       if (device_tensor_store_key.first >= outputs_.size()) {
         SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The input index is of range.");
       }
@@ -96,6 +100,11 @@ void OutputActor::CollectLoopCount(size_t loop_count, OpContext<DeviceTensor> *c
   }
 }
 
+void OutputActor::CollectBranchId(const int branch_id, OpContext<DeviceTensor> *context) {
+  MS_EXCEPTION_IF_NULL(context);
+  branch_id_ = branch_id;
+}
+
 void OutputActor::CollectOutput(const AnfNodePtr &output_node, size_t output_index, size_t output_position,
                                 OpContext<DeviceTensor> *context) {
   MS_EXCEPTION_IF_NULL(output_node);
@@ -115,7 +124,7 @@ void OutputActor::CollectOutput(const AnfNodePtr &output_node, size_t output_ind
   // Save the output nodes to clear the device tensor in the running end.
   output_nodes_[output_position] = KernelWithIndex(output_node, output_index);
 
   // There is no loop count actor in step mode, need trigger call CollectLoopCount to replace old output device tensors.
-  if (!need_loop_count_ && (current_outputs_num_ + device_tensor_store_keys_.size() == outputs_num_)) {
+  if (!need_loop_count_ && (current_outputs_num_ + device_tensor_store_keys_[branch_id_].size() == outputs_num_)) {
     CollectLoopCount(++current_count_, context);
   }
 }

View File

@@ -22,6 +22,8 @@
 #include <memory>
 #include <utility>
 #include <algorithm>
+#include <unordered_map>
+#include "runtime/framework/control_node_parser.h"
 #include "runtime/framework/device_tensor_store.h"
 #include "runtime/framework/actor/actor_common.h"
 #include "runtime/hardware/device_context.h"
@@ -47,6 +49,7 @@ class OutputActor : public OpActor<DeviceTensor> {
     outputs_.resize(outputs_num);
     output_nodes_.resize(outputs_num);
     device_contexts_.resize(outputs_num);
+    device_tensor_store_keys_[kMainBranchID] = std::vector<std::pair<size_t, AnfNodePtr>>();
   }
   ~OutputActor() override = default;
@@ -57,6 +60,8 @@ class OutputActor : public OpActor<DeviceTensor> {
   void CollectOutput(const AnfNodePtr &output_node, size_t output_index, size_t output_position,
                      OpContext<DeviceTensor> *context);
 
+  void CollectBranchId(const int branch_id, OpContext<DeviceTensor> *context);
+
   std::vector<TensorPtr> &outputs() { return outputs_; }
 
  private:
@@ -74,9 +79,13 @@ class OutputActor : public OpActor<DeviceTensor> {
   size_t outputs_num_;
   size_t current_outputs_num_;
   bool need_loop_count_;
+  int branch_id_{kMainBranchID};
 
-  // Pair<index, anfNode> points to the dependent device tensor store, anfNode is the key of the device tensor store.
-  std::vector<std::pair<size_t, AnfNodePtr>> device_tensor_store_keys_;
+  // Pair<branch_id, <index, node>> points to the dependent device tensor store, branch_id is the output branch id.
+  // In general, the branch id is 0, which means there is only one output branch in the actor set. When there are
+  // multiple possible output branches in the actor set, different branch ids correspond to their own related nodes.
+  // The index is the position of node in the output, node is the key of the device tensor store.
+  std::unordered_map<size_t, std::vector<std::pair<size_t, AnfNodePtr>>> device_tensor_store_keys_;
 };
 
 using OutputActorPtr = std::shared_ptr<OutputActor>;
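The branch-keyed container above replaces the old flat vector: the output actor first learns which branch actually produced the graph output, then only walks that branch's store keys. A minimal stand-alone sketch of this lookup pattern, using only std containers and a string as a stand-in for AnfNodePtr (the names kMainBranch and store_keys are illustrative, not MindSpore API):

    #include <cstddef>
    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <utility>
    #include <vector>

    int main() {
      // Key: branch id. Value: (output position, node name) pairs, mirroring device_tensor_store_keys_.
      std::unordered_map<size_t, std::vector<std::pair<size_t, std::string>>> store_keys;
      const size_t kMainBranch = 0;
      store_keys[kMainBranch] = {{1, "weight_a"}};          // the single-branch case
      store_keys[1] = {{0, "weight_b"}, {2, "weight_c"}};   // a control-flow sub branch

      // At running end, only the keys of the branch reported by the gather actor are fetched.
      const size_t current_branch = 1;
      for (const auto &key : store_keys[current_branch]) {
        std::cout << "output position " << key.first << " <- " << key.second << "\n";
      }
      return 0;
    }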

View File

@@ -44,6 +44,7 @@ void SwitchActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTe
   // When all the inputs are collected, then allocate memory and callback launch.
   if (CheckLaunchCondition(context)) {
     FetchInputDeviceTensor(context);
+    EraseInput(context);
     SendOutput(context);
   }
 }
@@ -238,6 +239,25 @@ void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
   }
 }
 
+void SwitchActor::EraseInput(OpContext<DeviceTensor> *context) {
+  MS_EXCEPTION_IF_NULL(context);
+  if (input_datas_num_ != 0) {
+    auto ret = input_op_datas_.erase(context->sequential_num_);
+    if (ret == 0) {
+      std::string error_info = "Erase input data failed: " + GetAID().Name();
+      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
+    }
+  }
+
+  if (input_controls_num_ != 0) {
+    auto ret = input_op_controls_.erase(context->sequential_num_);
+    if (ret == 0) {
+      std::string error_info = "Erase input controls failed: " + GetAID().Name();
+      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
+    }
+  }
+}
+
 void SwitchActor::SendMemoryFreeReq(OpContext<DeviceTensor> *context) {
   Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &input_device_tensors_, device_context_, context);
 }

View File

@@ -87,6 +87,8 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
   // Fetch the args of switch branch.
   void FetchInputDeviceTensor(OpContext<DeviceTensor> *context);
   void SendOutput(OpContext<DeviceTensor> *context);
+  // Erase input data and input controls when finish switch launch.
+  void EraseInput(OpContext<DeviceTensor> *context);
   void SendMemoryFreeReq(OpContext<DeviceTensor> *context);
 
   // All inputs of the switch actor, excluding weight and tensor.
@@ -107,6 +109,8 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
   const AID memory_manager_aid_;
   // The dependent input data number.
   size_t input_datas_num_{0};
+  // The dependent input controls number.
+  size_t input_controls_num_{0};
 
   CNodePtr node_;
 
   // The output_data_ corresponds to the output_data_arrows_ one by one.

View File

@@ -983,12 +983,15 @@ std::vector<GatherActorPtr> GraphScheduler::BuildGatherActor(const GraphCompiler
     }
   }
 
-  auto loop_count_actor_name = graph_compiler_info.name_ + "_LoopCountActor";
-  auto actor = FetchActor(loop_count_actor_name);
-  if (actor == nullptr) {
-    MS_LOG(EXCEPTION) << "Cannot find loop count actor by name:" << loop_count_actor_name;
-  }
-  auto gather_actor = std::make_shared<GatherActor>(actor_name, parameters, actor->GetAID());
+  const auto &loop_count_actor_name = graph_compiler_info.name_ + "_LoopCountActor";
+  const auto &loop_count_actor = FetchActor(loop_count_actor_name);
+  MS_EXCEPTION_IF_NULL(loop_count_actor);
+  const auto &output_actor_name = graph_compiler_info.name_ + "_" + "OutputActor";
+  const auto &output_actor = FetchActor(output_actor_name);
+  MS_EXCEPTION_IF_NULL(output_actor);
+  auto gather_actor =
+    std::make_shared<GatherActor>(actor_name, parameters, loop_count_actor->GetAID(), output_actor->GetAID());
   gather_actor->FetchBackendInputNode(func_graph, graph_compiler_info.origin_parameters_order_,
                                       graph_compiler_info.front_to_backend_parameters_,
                                       graph_compiler_info.func_graph_to_parameters_, front_to_backend_kernel);
@@ -1066,7 +1069,7 @@ void GraphScheduler::LinkDataArrowForInternalParameter(const AnfNodePtr &interna
     to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, front_output_node.get());
     return;
   }
-  if (graph_output_to_actor_.count(front_output_with_index) == 0) {
+  if (graph_output_to_actor_.count(front_output_with_index) == 0 && (!IsSwitchActor(front_output_node))) {
     MS_LOG(EXCEPTION) << "Can't find actor by front node:" << front_output_node->fullname_with_scope()
                       << ", internal parameter:" << internal_parameter->fullname_with_scope();
   }
@@ -1076,6 +1079,9 @@ void GraphScheduler::LinkDataArrowForInternalParameter(const AnfNodePtr &interna
     auto from_actor = dynamic_cast<DeviceQueueDataSourceActor *>(actor_pair.first);
     auto from_kernel_with_output_idx = KernelWithIndex(from_actor->data_kernel_, actor_pair.second);
     LinkDataArrowForDeviceDSActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
+  } else if (IsSwitchActor(front_output_node)) {
+    const auto &from_actor = dynamic_cast<SwitchActor *>(FetchActor(front_output_node->fullname_with_scope()));
+    MS_LOG(ERROR) << "Need link to switch actor:" << from_actor->GetAID();
   } else if (IsKernelActor(front_output_node)) {
     auto from_actor = dynamic_cast<KernelActor *>(actor_pair.first);
     auto from_kernel_with_output_idx = KernelWithIndex(from_actor->kernel_, actor_pair.second);
@@ -1447,13 +1453,15 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
       if (iter == graph_compiler_info.origin_outputs_order_.end()) {
         continue;
       }
-      to_actor->device_contexts_[iter->second] = graph_compiler_info.device_contexts_[number - 1];
+      to_actor->device_contexts_[iter->second.second] = graph_compiler_info.device_contexts_[number - 1];
 
       // The device tensor of graph out need be taken over by host tensor, so set the max reference count.
       UpdateRefCount(output_with_index.first, output_with_index.second, true);
 
       // The graph output is from device tensor store.
       if (IsPersistentDeviceTensor(output_with_index.first)) {
-        to_actor->device_tensor_store_keys_.emplace_back(iter->second, output_with_index.first);
+        to_actor->device_tensor_store_keys_[iter->second.first].emplace_back(iter->second.second,
+                                                                             output_with_index.first);
         continue;
       }
@@ -1462,7 +1470,7 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
         const auto &from_actor =
           dynamic_cast<KernelActor *>(FetchActor(output_with_index.first->fullname_with_scope()));
         MS_EXCEPTION_IF_NULL(from_actor);
-        auto op_arrow = std::make_shared<DataArrow>(output_with_index.second, to_actor->GetAID(), iter->second);
+        auto op_arrow = std::make_shared<DataArrow>(output_with_index.second, to_actor->GetAID(), iter->second.second);
         from_actor->output_result_arrows_.emplace_back(op_arrow);
         continue;
       }
@@ -1492,7 +1500,7 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
         }
       }
       MS_EXCEPTION_IF_NULL(from_actor);
-      auto op_arrow = std::make_shared<DataArrow>(from_actor_output_index, to_actor->GetAID(), iter->second);
+      auto op_arrow = std::make_shared<DataArrow>(from_actor_output_index, to_actor->GetAID(), iter->second.second);
       from_actor->output_result_arrows_.emplace_back(op_arrow);
     }
   }
@@ -1644,7 +1652,8 @@ void GraphScheduler::LinkDataArrowByCallInput(const GraphCompilerInfo &graph_com
       auto from_actor = iter->second;
       auto op_arrow = std::make_shared<DataArrow>(output_with_index.second, to_actor->GetAID(), to_index);
       from_actor->output_data_arrows_.emplace_back(op_arrow);
+      auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->kernel_, output_with_index.second, false);
+      UpdateRefCount(device_tensor.get(), true);
     } else if (output_with_index.first->isa<Parameter>()) {
       // Input is a parameter from gather actor.
       const auto &actor_name = func_graph->ToString();
@@ -1693,7 +1702,8 @@ void GraphScheduler::LinkDataArrowByControlNode(const GraphCompilerInfo &graph_c
     auto op_arrow = std::make_shared<DataArrow>(input_witch_index.second, to_actor->GetAID(), to_index);
     auto from_actor = front_node_to_actor_[input_witch_index.first];
     from_actor->output_data_arrows_.emplace_back(op_arrow);
-    UpdateRefCount(from_actor->kernel_, input_witch_index.second);
+    auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->kernel_, input_witch_index.second, false);
+    UpdateRefCount(device_tensor.get(), true);
   } else if (find(parameters.begin(), parameters.end(), input_node) != parameters.end()) {
     // The actor input is a parameter in host data source actor.
     std::string actor_name = graph_compiler_info.name_ + "_HostDSActor";
@@ -1717,7 +1727,8 @@ void GraphScheduler::LinkDataArrowByControlNode(const GraphCompilerInfo &graph_c
     auto op_arrow = std::make_shared<DataArrow>(iter->second, to_actor->GetAID(), to_index);
     from_actor->output_data_arrows_.emplace_back(op_arrow);
-    UpdateRefCount(from_actor->data_nodes_[iter->second], 0);
+    auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->data_nodes_[iter->second], 0, false);
+    UpdateRefCount(device_tensor.get(), true);
   } else {
     MS_LOG(EXCEPTION) << "Cannot find actor of switch input_node:" << AnfAlgo::GetNodeDebugString(input_node);
   }
@@ -1737,9 +1748,15 @@ void GraphScheduler::LinkDataArrowForSwitchActor(const GraphCompilerInfo &graph_
     if (func_graph == nullptr) {
       continue;
     }
+    if (func_graph->output()->isa<ValueNode>()) {
+      actor->AddInput(func_graph->output(), 0);
+    }
+
     auto gather_name = func_graph->ToString();
     if (actor_name_to_actor_.find(gather_name) == actor_name_to_actor_.end()) {
-      MS_LOG(EXCEPTION) << "Cannot find gather actor for funcgraph:" << gather_name;
+      MS_LOG(EXCEPTION) << "Cannot find gather actor for funcgraph:" << gather_name
+                        << ",switch input size:" << actor->input_nodes_.size();
     }
     auto to_actor = dynamic_cast<GatherActor *>(actor_name_to_actor_[gather_name]);
     for (size_t j = 0; j < actor->branch_inputs_pos_[i].size(); ++j) {
@@ -1771,17 +1788,20 @@ void GraphScheduler::LinkBranchArrowForGatherActor(const GraphCompilerInfo &grap
   if (graph_compiler_info.control_nodes_.empty()) {
     return;
   }
 
   const auto func_graph = graph_compiler_info.control_nodes_[0]->func_graph();
   const auto &loop_count_actor = actor_set->loop_count_actor_.get();
+  const auto &output_actor = actor_set->output_actor_.get();
 
   // If there is only one branch output, set the branch id of the loop count to 0, no need to send the branch id.
   auto outputs = graph_compiler_info.front_output_nodes_;
   if (outputs.size() == 1) {
+    loop_count_actor->branch_id_ = kMainBranchID;
     return;
   }
+  loop_count_actor->branch_id_ = kInvalidBranchID;
+  output_actor->branch_id_ = kInvalidBranchID;
 
   std::vector<FuncGraphPtr> output_func_graphs;
   for_each(outputs.begin(), outputs.end(),
            [&output_func_graphs](const AnfNodePtr &output) { output_func_graphs.push_back(output->func_graph()); });
@@ -1866,8 +1886,11 @@ void GraphScheduler::LinkOutputResultArrowForGatherActor(const GraphCompilerInfo
       if (iter == graph_compiler_info.origin_outputs_order_.end()) {
         continue;
       }
-      auto op_arrow = std::make_shared<DataArrow>(i, to_actor->GetAID(), iter->second);
+      MS_LOG(INFO) << "Link output node:" << AnfAlgo::GetNodeDebugString(origin_output_with_index.first)
+                   << " branch id:" << iter->second.first << " index:" << iter->second.second
+                   << " for gather actor:" << gather_actor->GetAID();
+      auto op_arrow = std::make_shared<DataArrow>(i, to_actor->GetAID(), iter->second.second);
       gather_actor->output_result_arrows_.emplace_back(op_arrow);
       const auto &backend_nodes = gather_actor->front_to_backend_parameter_[front_node];
       if (backend_nodes.empty()) {
@@ -1882,13 +1905,14 @@ void GraphScheduler::LinkOutputResultArrowForGatherActor(const GraphCompilerInfo
           MS_LOG(EXCEPTION) << "Cannot find backend node in host data source actor, node:"
                             << AnfAlgo::GetNodeDebugString(backend_node);
         }
-        to_actor->device_contexts_[iter->second] = host_ds_actor->device_contexts_[node_iter - data_nodes.begin()];
+        to_actor->device_contexts_[iter->second.second] =
+          host_ds_actor->device_contexts_[node_iter - data_nodes.begin()];
       } else {
         auto actor_base = FetchActor(backend_node->fullname_with_scope());
         MS_EXCEPTION_IF_NULL(actor_base);
         auto kernel_actor = dynamic_cast<KernelActor *>(actor_base);
         MS_EXCEPTION_IF_NULL(kernel_actor);
-        to_actor->device_contexts_[iter->second] = kernel_actor->device_context_;
+        to_actor->device_contexts_[iter->second.second] = kernel_actor->device_context_;
       }
     }
   }
@@ -2212,8 +2236,8 @@ void GraphScheduler::DumpOutputActor(const OutputActor *actor, std::ofstream &of
   ofs << "\tactor_name:" << actor->GetAID().Name() << "\tloop_count:" << actor->loop_count_
       << "\toutputs_num:" << actor->outputs_num_ << "\n";
 
-  ofs << "\t\tdevice_tensor_store_keys:" << actor->device_tensor_store_keys_.size() << "\n ";
-  for (const auto &device_tensor_store_key : actor->device_tensor_store_keys_) {
+  ofs << "\t\tdevice_tensor_store_keys:" << actor->device_tensor_store_keys_.at(kMainBranchID).size() << "\n ";
+  for (const auto &device_tensor_store_key : actor->device_tensor_store_keys_.at(kMainBranchID)) {
     MS_EXCEPTION_IF_NULL(device_tensor_store_key.second);
     ofs << "\t\t\toutput_node_position:" << device_tensor_store_key.first
         << "\toutput_node_name:" << device_tensor_store_key.second->fullname_with_scope() << "\n";

View File

@@ -42,7 +42,10 @@ namespace runtime {
 using mindspore::device::DeviceContext;
 using mindspore::session::KernelGraph;
 using mindspore::session::KernelWithIndex;
-using KernelMapPosition = std::map<KernelWithIndex, size_t, session::KernelWithIndexCmp>;
+// Position of kernel with index, the value pair<branch_id, pos> means the branch id of the kernel and the pos of
+// the kernel. Generally, there is only one branch, and the branch id is 0 at this time. In control flow, there
+// are multiple branch scenarios, and pos represents the position of the kernel in the branch.
+using KernelMapPosition = std::map<KernelWithIndex, std::pair<int, size_t>, session::KernelWithIndexCmp>;
 using ActorInfo = std::string;
 
 // The second element of pair represents the output index of op actor corresponding to the graph output node.
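Since the linking code now reads iter->second.first as the branch id and iter->second.second as the position, a minimal stand-alone sketch may help fix the shape of this map. It uses only std types, with a (name, output index) pair standing in for session::KernelWithIndex; all names are illustrative, not the real backend API:

    #include <cstddef>
    #include <iostream>
    #include <map>
    #include <string>
    #include <utility>
    #include <vector>

    // Stand-in for session::KernelWithIndex: (node name, output index).
    using NodeWithIndex = std::pair<std::string, size_t>;
    // Value pair<branch_id, pos>, mirroring the new KernelMapPosition.
    using OutputsOrder = std::map<NodeWithIndex, std::pair<int, size_t>>;

    int main() {
      const std::vector<std::vector<NodeWithIndex>> branch_outputs = {
          {{"Add-op0", 0}, {"Mul-op1", 0}},  // branch 0 (main branch)
          {{"Sub-op2", 0}, {"Div-op3", 0}},  // branch 1 (control-flow sub branch)
      };

      OutputsOrder outputs_order;
      for (int branch_id = 0; branch_id < static_cast<int>(branch_outputs.size()); ++branch_id) {
        size_t position = 0;
        for (const auto &output : branch_outputs[branch_id]) {
          outputs_order[output] = {branch_id, position++};
        }
      }

      for (const auto &item : outputs_order) {
        std::cout << item.first.first << ":" << item.first.second << " -> branch "
                  << item.second.first << ", pos " << item.second.second << "\n";
      }
      return 0;
    }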

View File

@@ -512,12 +512,18 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(con
   runtime::KernelMapPosition outputs_order;
   size_t outputs_num = 0;
   const auto &all_branch_output = ControlNodeParser::FetchAllBranchOutputs(root_graph);
-  for (const auto &branch_output : all_branch_output) {
+  for (int j = 0; j < SizeToInt(all_branch_output.size()); ++j) {
+    // In general, there is only one output branch, and the branch id is 0 at this time. In the control flow,
+    // there are multi-branch output scenarios. Different branches may have different weight nodes. When output
+    // actor run, the corresponding weight node needs to be obtained according to different branches. Therefore,
+    // the branch of the output nodes needs to be recorded.
+    const int branch_id = ((all_branch_output.size() == 1 ? runtime::kMainBranchID : (j + runtime::kSubBranchStartID)));
+    const auto &branch_output = all_branch_output[j];
     size_t position = 0;
     auto outputs = AnfAlgo::GetAllOutputWithIndex(branch_output);
     outputs_num = outputs.size();
     for (const auto &output : outputs) {
-      outputs_order.emplace(output, position++);
+      outputs_order[output] = {branch_id, position++};
     }
   }
@@ -562,7 +568,7 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(
     auto outputs = AnfAlgo::GetAllOutputWithIndex(graph->output());
     for (const auto &output : outputs) {
-      outputs_order.emplace(output, position++);
+      outputs_order[output] = {runtime::kMainBranchID, position++};
     }
   }

View File

@@ -43,6 +43,7 @@ using ControlNodeParser = runtime::ControlNodeParser;
 using FrontToBackendNodeWithContext = runtime::FrontToBackendNodeWithContext;
 using FuncGraphToParameter = runtime::FuncGraphToParameter;
 using HostParameterToWeight = runtime::HostParameterToWeight;
 
 enum SwitchCondStatus {
   kCondOk = 0,
   kCondAlreadyRun,