!18941 Fix mulit call in control flow

Merge pull request !18941 from gaoyong10/new_runtime13
This commit is contained in:
i-robot 2021-06-29 06:42:01 +00:00 committed by Gitee
commit 29e7da4c3e
14 changed files with 903 additions and 467 deletions

View File

@ -16,6 +16,7 @@
#include "runtime/framework/actor/gather_actor.h"
#include "runtime/framework/actor/output_actor.h"
#include "runtime/framework/actor/switch_actor.h"
#include "runtime/framework/actor/memory_manager_actor.h"
#include "runtime/framework/actor/loop_count_actor.h"
#include "mindrt/include/async/async.h"
@ -44,7 +45,7 @@ void GatherActor::Init() {
size_t GatherActor::FetchDataNodePosition(const AnfNodePtr &data_node) const {
const auto &iter = find(data_nodes_.begin(), data_nodes_.end(), data_node);
if (iter == data_nodes_.end()) {
MS_LOG(EXCEPTION) << "Data node: " << data_node->fullname_with_scope()
MS_LOG(EXCEPTION) << "Data node: " << AnfAlgo::GetNodeDebugString(data_node)
<< " is not exist in gather actor:" << GetAID();
}
return iter - data_nodes_.begin();
@ -52,9 +53,8 @@ size_t GatherActor::FetchDataNodePosition(const AnfNodePtr &data_node) const {
void GatherActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
auto sequential_num = context->sequential_num_;
input_op_datas_[sequential_num].emplace_back(input_data);
input_data_[sequential_num][input_data->index_].push(input_data->data_);
if (CheckLaunchCondition(context)) {
FetchInputDeviceTensor(context);
@ -63,6 +63,29 @@ void GatherActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTe
}
}
void GatherActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
auto &sequential_num = context->sequential_num_;
input_op_controls_[sequential_num].emplace_back(input_control);
if (CheckLaunchCondition(context)) {
FetchInputDeviceTensor(context);
EraseInput(context);
SendOutput(context);
}
}
void GatherActor::CollectBranchId(const int branch_id, OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
auto &sequential_num = context->sequential_num_;
input_branch_ids_[sequential_num] = branch_id;
if (CheckLaunchCondition(context)) {
FetchInputDeviceTensor(context);
EraseInput(context);
SendOutput(context);
}
}
void GatherActor::FetchBackendInputNode(const FuncGraphPtr &func_graph, const ControlNodeParserPtr &parser) {
for (const auto &input : func_graph->get_inputs()) {
// Monad input would not send to gather actor.
@ -76,20 +99,20 @@ void GatherActor::FetchBackendInputNode(const FuncGraphPtr &func_graph, const Co
void GatherActor::SendOutput(OpContext<DeviceTensor> *context) const {
MS_EXCEPTION_IF_NULL(context);
// Branch arrow and result arrow must be executed before the data arrow and control arrow, otherwise the output
// actor may receive the loop count message first and cause the output to be abnormal.
if (branch_id_ > kInvalidBranchID) {
Async(loop_count_aid_, &LoopCountActor::CollectBranchId, branch_id_, context);
Async(output_aid_, &OutputActor::CollectBranchId, branch_id_, context);
// Send output branch id.
if (find(output_branch_arrows_.begin(), output_branch_arrows_.end(), switch_aid_) != output_branch_arrows_.end()) {
int branch_id = input_branch_id_;
Async(switch_aid_, &SwitchActor::CollectBranchId, branch_id, context);
}
if (find(output_branch_arrows_.begin(), output_branch_arrows_.end(), gather_aid_) != output_branch_arrows_.end()) {
Async(gather_aid_, &GatherActor::CollectBranchId, local_branch_id_, context);
}
// Send graph output result.
// Send output result.
for (const auto &result_arrow : output_result_arrows_) {
MS_EXCEPTION_IF_NULL(result_arrow);
size_t from_index = result_arrow->from_output_index_;
const auto &front_node = data_nodes_[from_index];
for (const auto &backend_node : front_to_backend_parameter_.at(front_node)) {
if (AnfAlgo::GetMutableOutputAddr(backend_node.first, backend_node.second).get() ==
input_device_tensors_[from_index]) {
@ -115,15 +138,28 @@ void GatherActor::SendOutput(OpContext<DeviceTensor> *context) const {
void GatherActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
auto data_iter = input_op_datas_.find(context->sequential_num_);
if (data_iter != input_op_datas_.end()) {
auto data_iter = input_data_.find(context->sequential_num_);
if (data_iter != input_data_.end()) {
for (auto &input_data : data_iter->second) {
MS_EXCEPTION_IF_NULL(input_data);
input_device_tensors_[input_data->index_] = input_data->data_;
input_device_tensors_[input_data.first] = input_data.second.top();
input_data.second.pop();
}
}
for (const auto &device_tensor_store_key : device_tensor_store_keys_) {
const auto &device_context = device_contexts_[device_tensor_store_key.first];
MS_EXCEPTION_IF_NULL(device_context);
auto device_tensor =
DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second, device_context->GetDeviceAddressType());
if (device_tensor == nullptr) {
std::string error_info =
GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->DebugString() +
", device type:" + std::to_string(static_cast<int>(device_context->GetDeviceAddressType()));
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
input_device_tensors_[device_tensor_store_key.first] = device_tensor;
}
for (size_t i = 0; i < output_data_by_output_index_.size(); ++i) {
const auto &data = input_device_tensors_[i];
for (auto &output_data : output_data_by_output_index_[i]) {
@ -131,20 +167,31 @@ void GatherActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
output_data->data_ = data;
}
}
if (need_branch_id_input_) {
input_branch_id_ = input_branch_ids_[context->sequential_num_];
}
}
bool GatherActor::CheckLaunchCondition(OpContext<DeviceTensor> *context) const {
MS_EXCEPTION_IF_NULL(context);
// Fetch input data.
if (input_datas_num_ != 0) {
auto data_iter = input_op_datas_.find(context->sequential_num_);
if (data_iter == input_op_datas_.end()) {
auto data_iter = input_data_.find(context->sequential_num_);
if (data_iter == input_data_.end()) {
return false;
}
if (data_iter->second.size() != input_datas_num_) {
if (data_iter->second.size() + device_tensor_store_keys_.size() != input_datas_num_) {
return false;
}
if (std::any_of(data_iter->second.begin(), data_iter->second.end(),
[](const auto &input_stack) { return input_stack.second.empty(); })) {
return false;
}
}
// Fetch input control.
if (input_controls_num_ != 0) {
auto control_iter = input_op_controls_.find(context->sequential_num_);
if (control_iter == input_op_controls_.end()) {
@ -154,19 +201,32 @@ bool GatherActor::CheckLaunchCondition(OpContext<DeviceTensor> *context) const {
return false;
}
}
// Fetch input branch id.
if (need_branch_id_input_) {
auto branch_id_iter = input_branch_ids_.find(context->sequential_num_);
if (branch_id_iter == input_branch_ids_.end()) {
return false;
}
}
return true;
}
void GatherActor::EraseInput(OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
if (input_datas_num_ != 0) {
auto ret = input_op_datas_.erase(context->sequential_num_);
// Erase input data.
auto data_iter = input_data_.find(context->sequential_num_);
if (data_iter != input_data_.end() && std::all_of(data_iter->second.begin(), data_iter->second.end(),
[](const auto &input_data) { return input_data.second.empty(); })) {
auto ret = input_data_.erase(context->sequential_num_);
if (ret == 0) {
std::string error_info = "Erase input data failed: " + GetAID().Name();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
}
// Erase input control.
if (input_controls_num_ != 0) {
auto ret = input_op_controls_.erase(context->sequential_num_);
if (ret == 0) {
@ -174,6 +234,15 @@ void GatherActor::EraseInput(OpContext<DeviceTensor> *context) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
}
// Erase input branch id.
if (need_branch_id_input_) {
auto ret = input_branch_ids_.erase(context->sequential_num_);
if (ret == 0) {
std::string error_info = "Erase input branch id failed: " + GetAID().Name();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
}
}
} // namespace runtime
} // namespace mindspore

View File

@ -21,6 +21,7 @@
#include <string>
#include <memory>
#include <unordered_map>
#include <stack>
#include <utility>
#include <algorithm>
#include "runtime/framework/device_tensor_store.h"
@ -36,20 +37,37 @@ namespace runtime {
constexpr size_t kReturnInputPos = 1;
// Gather actor is the entrance of sub funcgraph. Graph input is sent to it and sent to other actors by gather actor.
// Gather actor is used in three places:
// 1. Entrance of sub funcgraph
// 2. call node which input0 is a funcgraph
// 3. There is some call nodes in the inputs of kernel graph.
// Gather actor will be used in the control flow. When the subgraph is called, the real parameters need to be put
// together and sent to the subgraph. At the same time, the entry of the subgraph needs to accept input data.
// Special in recursion, general inputs and call inputs of the kernel graph are used in stack mode, it needs to be
// collected at the entrance of the kernel graph.
class GatherActor : public OpActor<DeviceTensor> {
public:
GatherActor(const std::string &name, const std::vector<AnfNodePtr> &parameters, const AID loop_count_aid,
const AID output_aid)
: OpActor(name), data_nodes_(parameters), loop_count_aid_(loop_count_aid), output_aid_(output_aid) {}
GatherActor(const std::string &name, const std::vector<AnfNodePtr> &parameters, const bool need_branch_id_input,
const AID switch_aid, const AID gather_aid, const int branch_id)
: OpActor(name),
data_nodes_(parameters),
need_branch_id_input_(need_branch_id_input),
switch_aid_(switch_aid),
gather_aid_(gather_aid),
local_branch_id_(branch_id) {
device_contexts_.resize(parameters.size());
}
~GatherActor() override = default;
// Get the index of the parameter, the data_node needs to be the front node.
size_t FetchDataNodePosition(const AnfNodePtr &data_node) const;
// The kernel actor run when receive the input data.
// The gather actor run when receive the input data.
void RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) override;
// The gather actor run when receive the input control.
void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) override;
// The gather actor run when receive the input branch id.
void CollectBranchId(const int branch_id, OpContext<DeviceTensor> *context);
void Init() override;
private:
@ -66,13 +84,33 @@ class GatherActor : public OpActor<DeviceTensor> {
// The device tensors for launch.
std::vector<DeviceTensor *> input_device_tensors_;
// The branch if for current step.
int input_branch_id_;
DeviceContext *device_contexts_;
// Input data.
std::unordered_map<uuids::uuid *, std::unordered_map<size_t, std::stack<DeviceTensor *>>> input_data_;
// Input branch ids is used to record the id corresponding receive from gather actor.
// In control flow, sub funcgraph may be called in multiple places, and the output must be return to different
// places. Therefore, the output of each subgraph will be connected to a switch actor, and the caller will send
// its branch id to the gather actor of the subgraph. Then branch id will be sent by the gather actor to the
// switch actor connected to the output.
std::unordered_map<uuids::uuid *, int> input_branch_ids_;
// Output data.
// Cache unique output data by output index to modify the output data effectively.
std::vector<std::vector<OpDataUniquePtr<DeviceTensor>>> output_data_by_output_index_;
// The output_data_ corresponds to the output_data_arrows_ one by one.
std::vector<OpData<DeviceTensor> *> output_data_;
// Output arrows.
std::vector<DataArrowPtr> output_result_arrows_;
std::vector<AID> output_branch_arrows_;
// Parameters of sub funcgraph, which is the front node.
std::vector<AnfNodePtr> data_nodes_;
std::vector<DeviceContext *> device_contexts_;
// Pair<index, anfNode> points to the dependent device tensor store, anfNode is the key of the device tensor store.
std::vector<std::pair<size_t, AnfNode *>> device_tensor_store_keys_;
// When the output is a parameter of the subgraph, the gather actor needs to send the anfnode to the output actor,
// so all the nodes that may send the device tensor to gather actor are recorded. When the anfnode needs to be sent
@ -83,18 +121,19 @@ class GatherActor : public OpActor<DeviceTensor> {
size_t input_datas_num_{0};
// The dependent input controls number.
size_t input_controls_num_{0};
// Whether it needs to accept the branch id. When the gather actor is the input of the subgraph, it needs to receive
// branch id sent by the subgraph caller, which will be true at this time.
bool need_branch_id_input_;
const AID loop_count_aid_;
const AID output_aid_;
// Actor id that needs to send the branch id to it.
// When the actor is corresponding to call node, the branch id needs to be sent to the input gather actor and output
// switch actor of the called funcgraph. When the actor is the entrance of the funcgraph, the gather actor id is
// empty, just need to send branch id to its output switch actor.
const AID switch_aid_;
const AID gather_aid_;
// Cache unique output data by output index to modify the output data effectively.
std::vector<std::vector<OpDataUniquePtr<DeviceTensor>>> output_data_by_output_index_;
// The output_data_ corresponds to the output_data_arrows_ one by one.
std::vector<OpData<DeviceTensor> *> output_data_;
// When the result of the graph is sent to the output actor, the gather actor of the graph needs
// to send branch_id to the output actor to determine the corresponding weight.
int branch_id_{kInvalidBranchID};
// The branch id corresponding to the funcgraph to which the gather actor belongs.
int local_branch_id_;
};
using GatherActorPtr = std::shared_ptr<GatherActor>;

View File

@ -77,6 +77,9 @@ void KernelActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTe
MS_EXCEPTION_IF_NULL(context);
auto &sequential_num = context->sequential_num_;
input_op_datas_[sequential_num].emplace_back(input_data);
if (input_data->data_ == nullptr) {
MS_LOG(EXCEPTION) << "Input data of actor:" << GetAID() << " num:" << input_data->index_ << " is empty";
}
// When all the inputs are collected, then allocate memory and callback launch.
if (CheckLaunchCondition(context)) {
// Infer kernel shape and update abstract info for dynamic shape kernel.
@ -245,7 +248,7 @@ void KernelActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second, device_context_->GetDeviceAddressType());
if (device_tensor == nullptr) {
std::string error_info =
GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->fullname_with_scope() +
GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->DebugString() +
", device type:" + std::to_string(static_cast<int>(device_context_->GetDeviceAddressType()));
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}

View File

@ -86,16 +86,6 @@ void LoopCountActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *c
MS_EXCEPTION_IF_NULL(context);
auto sequential_num = context->sequential_num_;
input_op_controls_[sequential_num].emplace_back(input_control);
if (CheckLoopCountIncreaseCondition(context)) {
IncreaseLoopCount(context);
}
}
void LoopCountActor::CollectBranchId(const int branch_id, OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
branch_id_ = branch_id;
if (CheckLoopCountIncreaseCondition(context)) {
IncreaseLoopCount(context);
}
@ -138,7 +128,6 @@ void LoopCountActor::SendOutput(OpContext<DeviceTensor> *context) {
if (recorder_aid_ != nullptr) {
Async(*recorder_aid_, &RecorderActor::RecordOnStepEnd, context);
}
SendMemoryAllocReq(context);
}
@ -180,14 +169,8 @@ void LoopCountActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {
bool LoopCountActor::CheckLoopCountIncreaseCondition(OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
auto sequential_num = context->sequential_num_;
if (branch_id_ == kInvalidBranchID) {
return false;
}
if (branch_id_ >= SizeToInt(branch_id_to_input_controls_num_.size())) {
MS_LOG(ERROR) << "Branch id is invalid, id:" << branch_id_;
}
return input_op_controls_[sequential_num].size() == branch_id_to_input_controls_num_[branch_id_];
return input_op_controls_[sequential_num].size() == input_controls_num_;
}
} // namespace runtime
} // namespace mindspore

View File

@ -40,11 +40,10 @@ class LoopCountActor : public DebugAwareActor {
loop_count_(loop_count),
current_count_(0),
total_running_count_(0),
input_controls_num_(0),
memory_manager_aid_(memory_manager_aid),
debug_aid_(debug_aid),
recorder_aid_(recorder_aid) {
branch_id_to_input_controls_num_[kMainBranchID] = 0;
}
recorder_aid_(recorder_aid) {}
~LoopCountActor() override = default;
@ -63,11 +62,6 @@ class LoopCountActor : public DebugAwareActor {
// The callback after debug finished.
void OnDebugFinish(OpContext<DeviceTensor> *context) override;
// In control flow, there are multi-branch output situations. In this case, the gather actor will be numbered
// branch id, and the branch id will be sent to the loop count actor during operation. The interface is used
// to receive the branch id message.
void CollectBranchId(const int branch_id_, OpContext<DeviceTensor> *context);
private:
friend class GraphScheduler;
@ -84,7 +78,7 @@ class LoopCountActor : public DebugAwareActor {
// The dependent input controls number.
// In the multi-branch output scenario of the control flow, the control of each branch needs to be recorded
// separately with the branch id as the key. When the output has only one branch, the branch id is 0.
std::unordered_map<int, size_t> branch_id_to_input_controls_num_;
size_t input_controls_num_;
// The output controls contain the data source actors and the no input kernel actors and output actor.
std::vector<AID> data_source_aids_;
@ -98,10 +92,6 @@ class LoopCountActor : public DebugAwareActor {
// The id of recorder actor. Send message to it for clearing recorder info before loop count actor exits.
const AID *recorder_aid_;
// When the result of the graph is sent to the output actor, the gather actor of the graph needs
// to send branch_id to the output actor to determine the corresponding weight.
int branch_id_{kMainBranchID};
// The nodes need continuous memory, which must allocate in the begin of step running. The first bool of pair
// expresses the inputs of node need continuous memory, the second bool of pair expresses the outputs of node need
// continuous memory.

View File

@ -44,28 +44,25 @@ TensorPtr CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index,
void OutputActor::Init() {
// Set the number of actor running dependent messages.
if ((!need_loop_count_) && (device_tensor_store_keys_.size() == 1)) {
running_dependent_msg_num_ = SizeToInt(outputs_num_ - device_tensor_store_keys_[kMainBranchID].size());
if ((!need_loop_count_)) {
running_dependent_msg_num_ = SizeToInt(outputs_num_ - device_tensor_store_keys_.size());
}
}
void OutputActor::CollectLoopCount(size_t loop_count, OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
if (branch_id_ == kInvalidBranchID) {
MS_LOG(EXCEPTION) << "Invalid branch id for output actor.";
}
current_count_ = loop_count;
if (loop_count_ == current_count_) {
if (current_outputs_num_ + device_tensor_store_keys_[branch_id_].size() != outputs_num_) {
std::string error_info =
"The outputs num is wrong, the total outputs num: " + std::to_string(outputs_num_) +
", the current outputs num: " + std::to_string(current_outputs_num_) +
", the device tensor store num: " + std::to_string(device_tensor_store_keys_[branch_id_].size());
if (current_outputs_num_ + device_tensor_store_keys_.size() != outputs_num_) {
std::string error_info = "The outputs num is wrong, the total outputs num: " + std::to_string(outputs_num_) +
", the current outputs num: " + std::to_string(current_outputs_num_) +
", the device tensor store num: " + std::to_string(device_tensor_store_keys_.size());
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
// Because device tensor store can't send data, so fetch the output result of device tensor store in running end.
for (const auto &device_tensor_store_key : device_tensor_store_keys_[branch_id_]) {
for (const auto &device_tensor_store_key : device_tensor_store_keys_) {
if (device_tensor_store_key.first >= outputs_.size()) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The input index is of range.");
}
@ -108,16 +105,10 @@ void OutputActor::UpdateOutputDeviceAddress() {
output_nodes_.resize(outputs_num_);
}
void OutputActor::CollectBranchId(const int branch_id, OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
branch_id_ = branch_id;
}
void OutputActor::CollectOutput(const AnfNodePtr &output_node, size_t output_index, size_t output_position,
OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(output_node);
MS_EXCEPTION_IF_NULL(context);
// Collect the output result in the last loop which is represented by "loop_count_ - current_count_ == 1".
if (loop_count_ - current_count_ != 1) {
return;
@ -132,7 +123,7 @@ void OutputActor::CollectOutput(const AnfNodePtr &output_node, size_t output_ind
// Save the output nodes to clear the device tensor in the running end.
output_nodes_[output_position] = KernelWithIndex(output_node, output_index);
// There is no loop count actor in step mode, need trigger call CollectLoopCount to replace old output device tensors.
if (!need_loop_count_ && (current_outputs_num_ + device_tensor_store_keys_[branch_id_].size() == outputs_num_)) {
if (!need_loop_count_ && (current_outputs_num_ + device_tensor_store_keys_.size() == outputs_num_)) {
CollectLoopCount(++current_count_, context);
}
}

View File

@ -46,12 +46,10 @@ class OutputActor : public OpActor<DeviceTensor> {
outputs_num_(outputs_num),
current_outputs_num_(0),
need_loop_count_(need_loop_count),
branch_id_(kMainBranchID),
running_dependent_msg_num_(1) {
outputs_.resize(outputs_num);
output_nodes_.resize(outputs_num);
device_contexts_.resize(outputs_num);
device_tensor_store_keys_[kMainBranchID] = std::vector<std::pair<size_t, AnfNodePtr>>();
}
~OutputActor() override = default;
@ -65,8 +63,6 @@ class OutputActor : public OpActor<DeviceTensor> {
void CollectOutput(const AnfNodePtr &output_node, size_t output_index, size_t output_position,
OpContext<DeviceTensor> *context);
void CollectBranchId(const int branch_id, OpContext<DeviceTensor> *context);
// The graph output need be set new device address every step or loop, to avoid that the device address
// context of tensor be rewritten in the next step or next loop.
void UpdateOutputDeviceAddress();
@ -88,16 +84,11 @@ class OutputActor : public OpActor<DeviceTensor> {
size_t outputs_num_;
size_t current_outputs_num_;
bool need_loop_count_;
int branch_id_;
// The dependent messages number of actor running.
int running_dependent_msg_num_;
// Pair<branch_id, <index, node>> points to the dependent device tensor store, branch_id is the output branch id.
// In general, the branch id is 0, which means there is only one output branch in the actor set. When there are
// multiple possible output branches in the actor set, different branch ids correspond to their own related nodes.
// The index is the position of node in the output, node is the key of the device tensor store.
std::unordered_map<size_t, std::vector<std::pair<size_t, AnfNodePtr>>> device_tensor_store_keys_;
std::vector<std::pair<size_t, AnfNodePtr>> device_tensor_store_keys_;
};
using OutputActorPtr = std::shared_ptr<OutputActor>;

View File

@ -16,6 +16,7 @@
#include "runtime/framework/actor/switch_actor.h"
#include "runtime/framework/actor/output_actor.h"
#include "runtime/framework/actor/gather_actor.h"
#include "runtime/framework/actor/memory_manager_actor.h"
#include "mindrt/include/async/async.h"
#include "abstract/utils.h"
@ -39,10 +40,10 @@ void SwitchActor::Init() {
void SwitchActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
auto sequential_num = context->sequential_num_;
input_op_datas_[sequential_num].emplace_back(input_data);
const auto &sequential_num = context->sequential_num_;
auto &input_datas = input_data_[sequential_num];
input_datas[input_data->index_].push(input_data->data_);
// When all the inputs are collected, then allocate memory and callback launch.
if (CheckLaunchCondition(context)) {
FetchInputDeviceTensor(context);
EraseInput(context);
@ -50,14 +51,38 @@ void SwitchActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTe
}
}
void SwitchActor::Initialize() {
void SwitchActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
auto &sequential_num = context->sequential_num_;
if (input_controls_[sequential_num].find(input_control) == input_controls_[sequential_num].end()) {
input_controls_[sequential_num][input_control] = 0;
}
input_controls_[sequential_num][input_control]++;
if (CheckLaunchCondition(context)) {
FetchInputDeviceTensor(context);
EraseInput(context);
SendOutput(context);
}
}
void SwitchActor::CollectBranchId(const int branch_id, OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
auto &sequential_num = context->sequential_num_;
input_branch_ids_[sequential_num].push(branch_id);
}
void SwitchActor::Initialize(const ControlNodeParserPtr &parser) {
std::vector<AnfNodePtr> inputs = node_->inputs();
if (IsPrimitive(inputs[0], prim::kPrimSwitch)) {
InitSwitch();
} else if (IsPrimitive(inputs[0], prim::kPrimReturn)) {
InitReturn(parser);
} else {
InitSwitchLayer();
}
backend_parameters_.resize(input_nodes_.size());
}
void SwitchActor::InitPartial(const AnfNodePtr &node, const size_t branch_id) {
@ -88,6 +113,23 @@ void SwitchActor::InitPartial(const AnfNodePtr &node, const size_t branch_id) {
}
}
void SwitchActor::InitVectorSize(const size_t num) {
branch_inputs_pos_.resize(num);
branch_func_graph_.resize(num);
output_branch_arrows_.resize(num);
output_branch_result_arrows_.resize(num);
output_branch_control_arrows_.resize(num);
output_branch_branch_arrows_.resize(num);
}
void SwitchActor::InitReturn(const ControlNodeParserPtr &parser) {
const auto &func_graph = node_->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
const auto &call_num = parser->GetCallNumByFuncGraph(func_graph);
InitVectorSize(call_num);
AddCommonInput(func_graph->output());
}
void SwitchActor::InitSwitch() {
// The inputs of the switch node:
// [0] ValueNode<Primitive> kSwitch.
@ -99,14 +141,10 @@ void SwitchActor::InitSwitch() {
MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitch->name() << " is not equal 4";
}
branch_total_inputs_.resize(kSwitchPartialNum);
branch_inputs_pos_.resize(kSwitchPartialNum);
branch_func_graph_.resize(kSwitchPartialNum);
output_branch_arrows_.resize(kSwitchPartialNum);
output_branch_result_arrows_.resize(kSwitchPartialNum);
output_branch_control_arrows_.resize(kSwitchPartialNum);
InitVectorSize(kSwitchPartialNum);
input_nodes_.push_back(inputs[kSwitchCondPos]);
const auto cond_node = AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchCondPos], 0);
input_nodes_.push_back(cond_node);
input_datas_num_++;
// Init the two branches of switch node.
InitPartial(inputs[kSwitchFalseBranchPos], static_cast<size_t>(false));
@ -123,16 +161,13 @@ void SwitchActor::InitSwitchLayer() {
MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitchLayer->name() << " is not equal 3";
}
input_nodes_.push_back(inputs[kSwitchLayerCondPos]);
const auto cond_node = AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchLayerCondPos], 0);
input_nodes_.push_back(cond_node);
input_datas_num_++;
// The second input of SwitchLayer is maketuple node, which includes all branches.
auto branch_nodes = inputs[kSwitchLayerBranchPos]->cast<CNodePtr>()->inputs();
branch_total_inputs_.resize(branch_nodes.size() - 1);
branch_inputs_pos_.resize(branch_nodes.size() - 1);
branch_func_graph_.resize(branch_nodes.size() - 1);
output_branch_arrows_.resize(branch_nodes.size() - 1);
output_branch_result_arrows_.resize(branch_nodes.size() - 1);
output_branch_control_arrows_.resize(branch_nodes.size() - 1);
InitVectorSize(branch_nodes.size() - 1);
// Parse all branches.
for (size_t i = 1; i < branch_nodes.size(); ++i) {
@ -151,41 +186,92 @@ void SwitchActor::AddCommonInput(const AnfNodePtr &node) {
}
size_t SwitchActor::FetchDataNodePosition(const AnfNodePtr &data_node) const {
const auto &iter = find(input_nodes_.begin(), input_nodes_.end(), data_node);
const auto data_node_with_index = AnfAlgo::VisitKernelWithReturnType(data_node, 0);
const auto &iter = find(input_nodes_.begin(), input_nodes_.end(), data_node_with_index);
if (iter == input_nodes_.end()) {
MS_LOG(EXCEPTION) << "Data node: " << data_node->fullname_with_scope()
<< " is not exist in gather actor:" << GetAID();
MS_LOG(EXCEPTION) << "Data node: " << AnfAlgo::GetNodeDebugString(data_node)
<< " is not exist in switch actor:" << GetAID();
}
return iter - input_nodes_.begin();
}
void SwitchActor::AddInput(const AnfNodePtr &node, const size_t branch) {
branch_total_inputs_[branch].push_back(node);
void SwitchActor::AddInput(const KernelWithIndex node_with_index, const size_t branch) {
const auto &node = node_with_index.first;
if (node->isa<ValueNode>() && (!HasAbstractMonad(node))) {
// Add weight and value node.
if ((AnfAlgo::CheckPrimitiveType(node_, prim::kPrimReturn) && HasAbstractRef(node)) || node->isa<ValueNode>()) {
const auto iter = find(input_nodes_.begin(), input_nodes_.end(), node_with_index);
if (iter != input_nodes_.end()) {
branch_inputs_pos_[branch].push_back(iter - input_nodes_.begin());
return;
}
device_tensor_store_keys_.push_back({input_nodes_.size(), node.get()});
branch_inputs_pos_[branch].push_back(input_nodes_.size());
input_nodes_.push_back(node);
input_nodes_.push_back(node_with_index);
return;
}
// Switch actor only receives parameter, updatestate node output is U, need to be skipped.
if (IsPersistentDeviceTensor(node) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) {
// Output of updatestate node is U, need to be skipped.
if (HasAbstractRef(node)) {
return;
}
auto iter = find(input_nodes_.begin(), input_nodes_.end(), node);
// Add parameter.
auto iter = find(input_nodes_.begin(), input_nodes_.end(), node_with_index);
if (iter == input_nodes_.end()) {
branch_inputs_pos_[branch].push_back(input_nodes_.size());
input_nodes_.push_back(node);
input_nodes_.push_back(node_with_index);
++input_datas_num_;
} else {
branch_inputs_pos_[branch].push_back(iter - input_nodes_.begin());
}
}
size_t SwitchActor::GetIndex() {
void SwitchActor::AddInput(const AnfNodePtr &node, const size_t branch) {
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState) || HasAbstractMonad(node)) {
return;
}
const auto &real_input = AnfAlgo::VisitKernelWithReturnType(node, 0);
if (AnfAlgo::CheckPrimitiveType(real_input.first, prim::kPrimMakeTuple)) {
const auto &inputs = real_input.first->cast<CNodePtr>()->inputs();
for (size_t i = kMakeTupleInputStartPos; i < inputs.size(); ++i) {
AddInput(inputs[i], branch);
}
} else if (IsCallNode(real_input.first)) {
std::vector<AnfNodePtr> call_nodes;
const auto call_output_num = FetchOutputSizebyCallNode(real_input.first, &call_nodes);
if (call_output_num <= 0) {
MS_LOG(EXCEPTION) << "Invalid output num for call input:" << AnfAlgo::GetNodeDebugString(real_input.first);
}
for (size_t i = 0; i < call_output_num; ++i) {
AddInput({real_input.first, i}, branch);
}
} else {
AddInput(real_input, branch);
}
}
size_t SwitchActor::GetIndex(OpContext<DeviceTensor> *context) {
if (need_branch_id_input_) {
if (input_branch_ids_.find(context->sequential_num_) == input_branch_ids_.end() ||
input_branch_ids_[context->sequential_num_].empty()) {
MS_LOG(EXCEPTION) << "Invalid branch id for actor:" << GetAID();
}
size_t branch_id = input_branch_ids_[context->sequential_num_].top();
input_branch_ids_[context->sequential_num_].pop();
if (branch_id_to_index_.find(branch_id) == branch_id_to_index_.end()) {
MS_LOG(EXCEPTION) << "Invalid branch id for switch actor:" << GetAID() << " branch id:" << branch_id;
}
return branch_id_to_index_[branch_id];
}
DeviceTensor *device_tensor = input_device_tensors_[0];
if (device_tensor == nullptr) {
MS_LOG(EXCEPTION) << "Index of switch actor is empty:" << GetAID();
}
auto inputs = node_->inputs();
TypeId type_id = AnfAlgo::GetOutputInferDataType(inputs[kSwitchCondPos], 0);
size_t size = abstract::TypeIdSize(type_id);
@ -219,28 +305,46 @@ size_t SwitchActor::GetIndex() {
bool SwitchActor::CheckLaunchCondition(OpContext<DeviceTensor> *context) const {
MS_EXCEPTION_IF_NULL(context);
if (input_datas_num_ != 0) {
auto data_iter = input_op_datas_.find(context->sequential_num_);
if (data_iter == input_op_datas_.end()) {
auto data_iter = input_data_.find(context->sequential_num_);
if (data_iter == input_data_.end()) {
return false;
}
if (data_iter->second.size() != input_datas_num_) {
return false;
}
if (std::any_of(data_iter->second.begin(), data_iter->second.end(),
[](const auto &input_stack) { return input_stack.second.empty(); })) {
return false;
}
}
if (input_controls_num_ != 0) {
auto data_iter = input_controls_.find(context->sequential_num_);
if (data_iter == input_controls_.end()) {
return false;
}
if (data_iter->second.size() != input_controls_num_) {
return false;
}
if (std::any_of(data_iter->second.begin(), data_iter->second.end(),
[](const auto &input_stack) { return input_stack.second == 0; })) {
return false;
}
}
return true;
}
void SwitchActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
input_device_tensors_.resize(input_nodes_.size());
auto data_iter = input_op_datas_.find(context->sequential_num_);
if (data_iter != input_op_datas_.end()) {
auto data_iter = input_data_.find(context->sequential_num_);
if (data_iter != input_data_.end()) {
for (auto &input_data : data_iter->second) {
MS_EXCEPTION_IF_NULL(input_data);
input_device_tensors_[input_data->index_] = input_data->data_;
input_device_tensors_[input_data.first] = input_data.second.top();
input_data.second.pop();
}
}
data_iter->second.clear();
for (const auto &device_tensor_store_key : device_tensor_store_keys_) {
auto device_tensor =
@ -253,15 +357,28 @@ void SwitchActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
}
input_device_tensors_[device_tensor_store_key.first] = device_tensor;
}
auto control_iter = input_controls_.find(context->sequential_num_);
if (control_iter != input_controls_.end()) {
for_each(control_iter->second.begin(), control_iter->second.end(),
[](auto &input_control) { input_control.second--; });
}
}
void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
auto index = GetIndex();
auto index = GetIndex(context);
if (index >= output_branch_arrows_.size()) {
MS_LOG(EXCEPTION) << "Switch actor invalid index:" << index;
}
if (local_branch_id_ >= 0) {
const auto &branch_arrows = output_branch_branch_arrows_[index];
for (const auto &branch_arrow : branch_arrows) {
Async(branch_arrow, &GatherActor::CollectBranchId, local_branch_id_, context);
}
}
auto &output_branch_arrow = output_branch_arrows_[index];
auto &output_data = output_data_[index];
for (size_t i = 0; i < output_branch_arrow.size(); ++i) {
@ -270,6 +387,7 @@ void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(data_arrow);
MS_EXCEPTION_IF_NULL(data);
data->data_ = input_device_tensors_[data_arrow->from_output_index_];
Async(data_arrow->to_op_id_, &OpActor::RunOpData, data.get(), context);
}
@ -279,7 +397,7 @@ void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
auto &result_arrow = output_branch_result_arrow[i];
MS_EXCEPTION_IF_NULL(result_arrow);
size_t from_index = result_arrow->from_output_index_;
for (const auto &backend_node : front_to_backend_parameter_[from_index]) {
for (const auto &backend_node : backend_parameters_[from_index]) {
if (AnfAlgo::GetMutableOutputAddr(backend_node.first, backend_node.second).get() ==
input_device_tensors_[from_index]) {
Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, backend_node.first, backend_node.second,
@ -298,8 +416,10 @@ void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
void SwitchActor::EraseInput(OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
if (input_datas_num_ != 0) {
auto ret = input_op_datas_.erase(context->sequential_num_);
auto data_iter = input_data_.find(context->sequential_num_);
if (data_iter != input_data_.end() && std::all_of(data_iter->second.begin(), data_iter->second.end(),
[](const auto &input_data) { return input_data.second.empty(); })) {
auto ret = input_data_.erase(context->sequential_num_);
if (ret == 0) {
std::string error_info = "Erase input data failed: " + GetAID().Name();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
@ -307,10 +427,15 @@ void SwitchActor::EraseInput(OpContext<DeviceTensor> *context) {
}
if (input_controls_num_ != 0) {
auto ret = input_op_controls_.erase(context->sequential_num_);
if (ret == 0) {
std::string error_info = "Erase input controls failed: " + GetAID().Name();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
auto control_iter = input_controls_.find(context->sequential_num_);
if (control_iter != input_controls_.end() &&
std::all_of(control_iter->second.begin(), control_iter->second.end(),
[](const auto &input_control) { return input_control.second == 0; })) {
auto ret = input_controls_.erase(context->sequential_num_);
if (ret == 0) {
std::string error_info = "Erase input control failed: " + GetAID().Name();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
}
}
}
@ -319,37 +444,20 @@ void SwitchActor::SendMemoryFreeReq(OpContext<DeviceTensor> *context) {
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &input_device_tensors_, device_context_, context);
}
void SwitchActor::FetchInputNode(const std::vector<AnfNodePtr> &origin_parameters_order,
const FrontToBackendNodeWithContext &front_to_backend_parameters,
const std::unordered_map<AnfNodePtr, AnfNodePtr> &front_to_backend_kernel) {
front_to_backend_parameter_.resize(input_nodes_.size());
void SwitchActor::FetchInputNode(const ControlNodeParserPtr &parser) {
for (size_t i = 0; i < input_nodes_.size(); ++i) {
const auto &input_node = input_nodes_[i];
if (input_node->isa<ValueNode>()) {
front_to_backend_parameter_[i].push_back({input_node, 0});
} else if (input_node->isa<Parameter>()) {
if (front_to_backend_parameters.find(input_node) != front_to_backend_parameters.end()) {
const auto backend_node = front_to_backend_parameters.at(input_node).first;
front_to_backend_parameter_[i].push_back({backend_node, 0});
}
} else if (input_node->isa<CNode>()) {
if (IsCallNode(input_node)) {
const auto func_graphs = FetchFuncGraphbyCallNode(input_node->cast<CNodePtr>());
for (const auto func_graph : func_graphs) {
if (func_graph->output()->isa<ValueNode>()) {
front_to_backend_parameter_[i].push_back({func_graph->output(), 0});
}
}
} else {
const auto &kernel_with_index = AnfAlgo::VisitKernelWithReturnType(input_node, 0);
if (front_to_backend_kernel.find(input_node) != front_to_backend_kernel.end()) {
front_to_backend_parameter_[i].emplace_back(kernel_with_index);
}
}
const auto &input_node = input_nodes_[i].first;
if (!HasAbstractRef(input_node)) {
backend_parameters_[i] = parser->FetchBackendInputNodeByFrontNode(input_node);
continue;
}
const auto &backend_weight = parser->FetchBackendNodebyWeightNode(input_node);
if (backend_weight == nullptr) {
MS_LOG(EXCEPTION) << "Cannot find backend node for weight node:" << AnfAlgo::GetNodeDebugString(input_node);
}
backend_parameters_[i].push_back({backend_weight, 0});
}
}
} // namespace runtime
} // namespace mindspore

View File

@ -58,16 +58,25 @@ constexpr size_t kMakeTupleInputStartPos = 1;
// 5. Free Memory
class SwitchActor : public SwitchActorBase<DeviceTensor> {
public:
SwitchActor(const std::string &name, DeviceContext *device_context, const CNodePtr &node)
: SwitchActorBase(name), device_context_(device_context), node_(node) {}
SwitchActor(const std::string &name, DeviceContext *device_context, const CNodePtr &node, const int branch_id,
const bool need_branch_id_input)
: SwitchActorBase(name),
device_context_(device_context),
node_(node),
local_branch_id_(branch_id),
need_branch_id_input_(need_branch_id_input) {}
~SwitchActor() override = default;
void Init() override;
// The switch actor run when receive the input data.
void RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context);
// The switch actor run when receive the input control.
void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context);
// The switch actor run when receive the input branch id.
void CollectBranchId(const int branch_id, OpContext<DeviceTensor> *context);
// Initialize the input and output information of the switch actor According to node_.
void Initialize();
void Initialize(const ControlNodeParserPtr &parser);
// Add input for all branches.
void AddCommonInput(const AnfNodePtr &node);
// Fetch the input position of the data node.
@ -79,11 +88,16 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
void InitPartial(const AnfNodePtr &node, const size_t branch_id);
void InitSwitch();
void InitSwitchLayer();
// In control flow, the output of each subgraph is connected to a switch actor, and the switch actor is
// initialized with the return node of the subgraph.
void InitReturn(const ControlNodeParserPtr &parser);
// Initialize the size of the vector members.
void InitVectorSize(const size_t num);
// Get index from DeviceTensor.
size_t GetIndex();
size_t GetIndex(OpContext<DeviceTensor> *context);
// Add input for the branch.
void AddInput(const AnfNodePtr &node, size_t branch);
void AddInput(const KernelWithIndex node_with_index, const size_t branch);
// Check whether satisfy the condition for send outputs.
bool CheckLaunchCondition(OpContext<DeviceTensor> *context) const;
@ -95,12 +109,10 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
void SendMemoryFreeReq(OpContext<DeviceTensor> *context);
// Collect all the backend inputs of switch actor.
void FetchInputNode(const std::vector<AnfNodePtr> &origin_parameters_order,
const FrontToBackendNodeWithContext &front_to_backend_parameters,
const std::unordered_map<AnfNodePtr, AnfNodePtr> &front_to_backend_kernel);
// All inputs of the switch actor, excluding weight and tensor.
void FetchInputNode(const ControlNodeParserPtr &parser);
// All inputs of the switch actor, include weight and tensor.
// Used to receive input data, the first input is the condition of switch.
std::vector<AnfNodePtr> input_nodes_;
std::vector<KernelWithIndex> input_nodes_;
// The position of the branch output in the input_nodes_.
std::vector<std::vector<size_t>> branch_inputs_pos_;
@ -126,9 +138,9 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
// When the output is a value node from switch actor, the actor needs to send the anfnode to the output actor,
// so all the nodes that may send the device tensor to switch actor are recorded.
std::vector<std::vector<KernelWithIndex>> front_to_backend_parameter_;
std::vector<std::vector<KernelWithIndex>> backend_parameters_;
std::vector<std::vector<AnfNodePtr>> branch_total_inputs_;
std::vector<FuncGraphPtr> branch_func_graph_;
std::unordered_map<int, size_t> branch_id_to_index_;
@ -148,8 +160,12 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
// The dependent input controls number.
size_t input_controls_num_{0};
CNodePtr node_;
// The branch id corresponding to the funcgraph to which the gather actor belongs.
int local_branch_id_;
size_t input_branch_id_num_;
// Whether it needs to accept the branch id. When the switch actor is the output of the subgraph, it needs to receive
// branch id sent by the gather actor of subgraph, which will be true at this time.
bool need_branch_id_input_;
// The output_data_ corresponds to the output_data_arrows_ one by one.
std::vector<std::vector<OpDataUniquePtr<DeviceTensor>>> output_data_;

View File

@ -547,6 +547,8 @@ FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node) {
void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts, const FuncGraphPtr &root_graph) {
root_func_graph_ = root_graph;
root_graph_parameters_ = root_graph->parameters();
CreateBranchIDForFuncGraph(control_nodes);
@ -598,6 +600,27 @@ bool ControlNodeParser::IsCallInputKernelGraph(const KernelGraphPtr &graph) {
return true;
}
bool ControlNodeParser::IsKernelInRootFuncGraph(const AnfNodePtr &kernel) {
if (kernel == nullptr) {
return true;
}
const auto &graph = kernel->func_graph();
if (kernel != nullptr && graph != nullptr) {
const auto &kernel_graph = dynamic_cast<KernelGraph *>(graph.get());
if (kernel_graph == nullptr) {
return true;
}
const auto func_graph = kernel_graph->GetFuncGraph();
if (func_graph != nullptr && func_graph != root_func_graph_) {
return false;
}
}
return true;
}
size_t ControlNodeParser::GetCallNumByFuncGraph(const FuncGraphPtr &func_graph) {
if (func_graph_to_call_num_.find(func_graph) == func_graph_to_call_num_.end()) {
MS_LOG(EXCEPTION) << "Invalid funcgraph:" << func_graph->ToString();
@ -622,6 +645,21 @@ DeviceContext *ControlNodeParser::GetFrontValueNodeDeviceContext(const AnfNodePt
return nullptr;
}
AnfNodePtr ControlNodeParser::FetchBackendNodebyWeightNode(const AnfNodePtr &node) {
for (const auto &host_parameter_to_weight : host_parameter_to_weights_) {
for (const auto &front_weight : host_parameter_to_weight.second) {
if (front_weight == node) {
const auto &iter = front_to_backend_parameters_.find(front_weight);
if (iter != front_to_backend_parameters_.end()) {
return iter->second.first;
}
}
}
}
return nullptr;
}
void ControlNodeParser::FetchValueNodeBySwitchNode(const AnfNodePtr &switch_node,
std::vector<AnfNodePtr> *value_nodes) {
const auto &cnode = switch_node->cast<CNodePtr>();
@ -928,6 +966,40 @@ std::vector<AnfNodePtr> FetchInputParameterbyControlNode(const AnfNodePtr &node,
return parameters;
}
std::vector<AnfNodePtr> FetchParameterbyKernelGraph(const KernelGraphPtr &graph) {
std::vector<AnfNodePtr> parameters;
const auto &graph_parameters = graph->input_nodes();
for (const auto &graph_parameter : graph_parameters) {
const auto &front_node = graph->GetFrontAnfByBackendAnf(graph_parameter);
if (front_node != nullptr) {
parameters.emplace_back(front_node);
continue;
}
const auto &front_node_with_index = graph->GetFrontNodeByInternalParameter(graph_parameter);
if (front_node_with_index.first == nullptr) {
MS_LOG(WARNING) << "Invalid parameter of kernel graph, parameter :"
<< AnfAlgo::GetNodeDebugString(graph_parameter);
continue;
}
if (HasAbstractRef(AnfAlgo::VisitKernelWithReturnType(front_node_with_index.first, 0).first) ||
HasAbstractMonad(front_node_with_index.first)) {
continue;
}
if (AnfAlgo::CheckPrimitiveType(front_node_with_index.first, prim::kPrimMakeTuple)) {
const auto &sub_parameters = FetchInputsByMakeTuple(front_node_with_index.first);
parameters.insert(parameters.end(), sub_parameters.begin(), sub_parameters.end());
continue;
}
parameters.emplace_back(front_node_with_index.first);
}
return parameters;
}
void ControlNodeParser::FetchFrontToBackendParameter(const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts,
const std::vector<AnfNodePtr> &control_nodes) {
@ -939,7 +1011,7 @@ void ControlNodeParser::FetchFrontToBackendParameter(const std::vector<KernelGra
for (size_t i = 0; i < graphs.size(); ++i) {
const auto &graph = graphs[i];
auto device_context = device_contexts[i];
for (const auto &parameter : graph->parameters()) {
for (const auto &parameter : graph->input_nodes()) {
auto front_node = graph->GetFrontAnfByBackendAnf(parameter);
if (front_node != nullptr && front_node->isa<Parameter>() &&

View File

@ -71,8 +71,8 @@ FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node);
// Find all funcgraphs that the call node will call.
std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const AnfNodePtr &node);
// Recursive interface, get all input of make tuple node.
std::vector<AnfNodePtr> FetchInputsByMakeTuple(const AnfNodePtr &node);
// Get parameters in kernel graph.
std::vector<AnfNodePtr> FetchParameterbyKernelGraph(const KernelGraphPtr &graph);
// ControlNodeParser is used to parse control nodes, and get the edges between nodes.
class ControlNodeParser {
@ -107,6 +107,15 @@ class ControlNodeParser {
// Check whether there is a call node in the front input nodes of the kernel graph.
bool IsCallInputKernelGraph(const KernelGraphPtr &graph);
// Check whether the kernel actor belongs to the root graph.
// In general, all no output nodes belong to the root funcgraph, and the corresponding switch actor for output should
// be empty. In control flow, the control arrow of the no output node in the sub funcgraph should be sent to the
// output switch actor.
bool IsKernelInRootFuncGraph(const AnfNodePtr &kernel);
// Get the backend node corresponding to the weight node in the subgraph.
AnfNodePtr FetchBackendNodebyWeightNode(const AnfNodePtr &node);
private:
friend class GraphScheduler;
@ -195,10 +204,12 @@ class ControlNodeParser {
std::vector<AnfNodePtr> control_node_parameters_;
// The number of calls to func_graph.
std::unordered_map<FuncGraphPtr, size_t> func_graph_to_call_num_;
// The kernel graph of call exists in the front-end input node.
std::unordered_map<KernelGraphPtr, DeviceContext *> call_input_kernel_graphs_;
// The kernel graph of call exists in the front input node.
// In the scene of funcgrarph recursive call, general input and call input are passed recursively, so a gather actor
// is created for kernel graph which has a call input.
std::unordered_map<KernelGraphPtr, DeviceContext *> call_input_kernel_graphs_;
// Root funcgraph and its parameters.
FuncGraphPtr root_func_graph_;
std::vector<AnfNodePtr> root_graph_parameters_;
};

File diff suppressed because it is too large Load Diff

View File

@ -46,7 +46,7 @@ using mindspore::session::KernelWithIndex;
// Position of kernel with index, the value pair<branch_id, vector<pos>> means the branch id of the kernel and the pos
// of the kernel. Generally, there is only one branch, and the branch id is 0 at this time. In control flow, there are
// multiple branch scenarios, and pos represents the position of the kernel in the branch.
using KernelMapPosition = std::map<KernelWithIndex, std::pair<int, std::vector<size_t>>, session::KernelWithIndexCmp>;
using KernelMapPosition = std::map<KernelWithIndex, std::vector<size_t>, session::KernelWithIndexCmp>;
using ActorInfo = std::string;
// The second element of pair represents the output index of op actor corresponding to the graph output node.
@ -212,7 +212,8 @@ class GraphScheduler {
KernelWithIndex to_kernel_with_input_idx);
// 2. The processing of linking control arrows.
void LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set);
void LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set,
const ControlNodeParserPtr &parser);
void LinkControlArrowByAutoMonad(KernelActor *to_actor, const AnfNodePtr &from_node);
// The skipped node doesn't run, so need link the control arrow between the inputs and user of skipped node.
void LinkControlArrowBySkippedNode(KernelActor *to_actor, const AnfNodePtr &skipped_node);
@ -227,20 +228,24 @@ class GraphScheduler {
// 4. The processing of control flow linking.
void LinkArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, ActorSet *actor_set);
void LinkDataArrowForGatherActor(GatherActor *from_actor, KernelActor *to_actor,
KernelWithIndex from_kernel_with_output_idx,
KernelWithIndex to_kernel_with_input_idx);
void LinkDataArrowForGatherActor(GatherActor *from_actor, const AnfNodePtr &front_node, KernelActor *to_actor,
const size_t to_index);
void LinkDataArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info, SwitchActor *actor);
// Connect the input of the actor.
void LinkDataArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, const AnfNodePtr &input_node,
OpActor<DeviceTensor> *to_actor, const size_t to_index);
void LinkDataArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, const KernelWithIndex &input_node,
const FuncGraphPtr &from_func_graph, OpActor<DeviceTensor> *to_actor,
const size_t to_index);
// When the input of the actor is a call node, the output of the funcgraph called by the call node needs to be
// connected.
void LinkDataArrowByCallInput(const GraphCompilerInfo &graph_compiler_info, const AnfNodePtr &call_node,
OpActor<DeviceTensor> *to_actor, const size_t to_index);
void LinkDataArrowForSwitchActor(SwitchActor *from_actor, KernelActor *to_actor, const size_t to_index);
void LinkControlArrowForGatherActor(std::vector<GatherActorPtr> *from_actors, LoopCountActor *to_actor,
const std::vector<KernelGraphPtr> &graphs);
void LinkDataArrowByCallInput(const KernelWithIndex &call_node_with_index, const ControlNodeParserPtr &parser,
const FuncGraphPtr &from_func_graph, OpActor<DeviceTensor> *to_actor,
const size_t to_index);
void LinkDataArrowForSwitchActor(SwitchActor *from_actor, const size_t from_index, OpActor<DeviceTensor> *to_actor,
const size_t to_index, const size_t branch_index = SIZE_MAX);
void LinkControlArrowForGatherActor(std::vector<GatherActorPtr> *from_actors,
std::vector<KernelActorPtr> *kernel_actors, LoopCountActor *to_actor,
const std::vector<KernelGraphPtr> &graphs, const ControlNodeParserPtr &parser);
void LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr> *switch_actors, LoopCountActor *to_actor,
const KernelMapPosition &origin_outputs_order);
// In control flow, there are scenarios where there are multi-branch outputs, and the gather actor needs to

View File

@ -712,26 +712,25 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(con
auto parser = std::make_shared<ControlNodeParser>();
parser->Parse(control_nodes_, graphs, device_contexts, root_graph);
// Get all the outputs. In control flow, there may be multiple branch output.
runtime::KernelMapPosition outputs_order;
size_t outputs_num = 0;
const auto &all_branch_output = parser->FetchAllBranchOutputs(root_graph);
for (int j = 0; j < SizeToInt(all_branch_output.size()); ++j) {
// In general, there is only one output branch, and the branch id is 0 at this time. In the control flow,
// there are multi-branch output scenarios. Different branches may have different weight nodes. When output
// actor run, the corresponding weight node needs to be obtained according to different branches. Therefore,
// the branch of the output nodes needs to be recorded.
const int branch_id = ((all_branch_output.size() == 1 ? runtime::kMainBranchID : (j + runtime::kSubBranchStartID)));
const auto &branch_output = all_branch_output[j];
size_t position = 0;
auto outputs = AnfAlgo::GetAllOutputWithIndex(branch_output);
outputs_num = outputs.size();
for (const auto &output : outputs) {
if (outputs_order.count(output) == 0) {
outputs_order[output] = {branch_id, {position++}};
} else {
outputs_order[output].second.emplace_back(position++);
}
const auto &root_output =
AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0, false, {prim::kPrimTupleGetItem}).first;
size_t position = 0;
auto outputs = AnfAlgo::GetAllOutputWithIndex(root_output);
if (runtime::IsCallNode(root_output)) {
std::vector<AnfNodePtr> call_nodes;
size_t call_output_num = runtime::FetchOutputSizebyCallNode(root_output, &call_nodes);
for (size_t i = 0; i < call_output_num; ++i) {
outputs.push_back({root_output, i});
}
}
outputs_num = outputs.size();
for (const auto &output : outputs) {
if (outputs_order.count(output) == 0) {
outputs_order[output] = {position++};
} else {
outputs_order[output].emplace_back(position++);
}
}
@ -759,9 +758,9 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(
auto outputs = AnfAlgo::GetAllOutputWithIndex(graph->output());
for (const auto &output : outputs) {
if (outputs_order.count(output) == 0) {
outputs_order[output] = {runtime::kMainBranchID, {position++}};
outputs_order[output] = {position++};
} else {
outputs_order[output].second.emplace_back(position++);
outputs_order[output].emplace_back(position++);
}
}
}