!18941 Fix multi call in control flow
Merge pull request !18941 from gaoyong10/new_runtime13
commit 29e7da4c3e
@@ -16,6 +16,7 @@
#include "runtime/framework/actor/gather_actor.h"
#include "runtime/framework/actor/output_actor.h"
#include "runtime/framework/actor/switch_actor.h"
#include "runtime/framework/actor/memory_manager_actor.h"
#include "runtime/framework/actor/loop_count_actor.h"
#include "mindrt/include/async/async.h"
@@ -44,7 +45,7 @@ void GatherActor::Init() {
size_t GatherActor::FetchDataNodePosition(const AnfNodePtr &data_node) const {
  const auto &iter = find(data_nodes_.begin(), data_nodes_.end(), data_node);
  if (iter == data_nodes_.end()) {
-    MS_LOG(EXCEPTION) << "Data node: " << data_node->fullname_with_scope()
+    MS_LOG(EXCEPTION) << "Data node: " << AnfAlgo::GetNodeDebugString(data_node)
                      << " is not exist in gather actor:" << GetAID();
  }
  return iter - data_nodes_.begin();
@@ -52,9 +53,8 @@ size_t GatherActor::FetchDataNodePosition(const AnfNodePtr &data_node) const {

void GatherActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) {
  MS_EXCEPTION_IF_NULL(context);

  auto sequential_num = context->sequential_num_;
-  input_op_datas_[sequential_num].emplace_back(input_data);
+  input_data_[sequential_num][input_data->index_].push(input_data->data_);

  if (CheckLaunchCondition(context)) {
    FetchInputDeviceTensor(context);
@@ -63,6 +63,29 @@ void GatherActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTe
  }
}

+void GatherActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) {
+  MS_EXCEPTION_IF_NULL(context);
+  auto &sequential_num = context->sequential_num_;
+  input_op_controls_[sequential_num].emplace_back(input_control);
+
+  if (CheckLaunchCondition(context)) {
+    FetchInputDeviceTensor(context);
+    EraseInput(context);
+    SendOutput(context);
+  }
+}
+
+void GatherActor::CollectBranchId(const int branch_id, OpContext<DeviceTensor> *context) {
+  MS_EXCEPTION_IF_NULL(context);
+  auto &sequential_num = context->sequential_num_;
+  input_branch_ids_[sequential_num] = branch_id;
+  if (CheckLaunchCondition(context)) {
+    FetchInputDeviceTensor(context);
+    EraseInput(context);
+    SendOutput(context);
+  }
+}
+
void GatherActor::FetchBackendInputNode(const FuncGraphPtr &func_graph, const ControlNodeParserPtr &parser) {
  for (const auto &input : func_graph->get_inputs()) {
    // Monad inputs are not sent to the gather actor.
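The data-structure change above is the heart of the fix: pending inputs are keyed first by the
step's sequential number and then by input index, with a stack per index, so a second (possibly
recursive) call into the same actor within one step pushes on top instead of clobbering the first
call's inputs. A minimal sketch of the idea, detached from the actor framework (SeqId, Tensor and
the function names are illustrative placeholders, not MindSpore APIs):

  #include <algorithm>
  #include <cstddef>
  #include <stack>
  #include <unordered_map>

  using SeqId = int;   // stands in for OpContext::sequential_num_
  struct Tensor {};    // stands in for DeviceTensor

  // Per step, per input slot, a stack of pending inputs. Nested calls unwind
  // in LIFO order: each launch pops exactly one tensor per slot.
  std::unordered_map<SeqId, std::unordered_map<std::size_t, std::stack<Tensor *>>> pending;

  void OnInput(SeqId seq, std::size_t index, Tensor *t) {
    pending[seq][index].push(t);  // never overwrites an earlier call's input
  }

  bool ReadyToLaunch(SeqId seq, std::size_t expected_slots) {
    auto it = pending.find(seq);
    if (it == pending.end() || it->second.size() != expected_slots) {
      return false;
    }
    // Every slot must hold at least one unconsumed input.
    return std::none_of(it->second.begin(), it->second.end(),
                        [](const auto &slot) { return slot.second.empty(); });
  }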
@@ -76,20 +99,20 @@ void GatherActor::FetchBackendInputNode(const FuncGraphPtr &func_graph, const Co

void GatherActor::SendOutput(OpContext<DeviceTensor> *context) const {
  MS_EXCEPTION_IF_NULL(context);

  // Branch arrow and result arrow must be executed before the data arrow and control arrow, otherwise the output
  // actor may receive the loop count message first and cause the output to be abnormal.
-  if (branch_id_ > kInvalidBranchID) {
-    Async(loop_count_aid_, &LoopCountActor::CollectBranchId, branch_id_, context);
-    Async(output_aid_, &OutputActor::CollectBranchId, branch_id_, context);
+  // Send output branch id.
+  if (find(output_branch_arrows_.begin(), output_branch_arrows_.end(), switch_aid_) != output_branch_arrows_.end()) {
+    int branch_id = input_branch_id_;
+    Async(switch_aid_, &SwitchActor::CollectBranchId, branch_id, context);
+  }
+  if (find(output_branch_arrows_.begin(), output_branch_arrows_.end(), gather_aid_) != output_branch_arrows_.end()) {
+    Async(gather_aid_, &GatherActor::CollectBranchId, local_branch_id_, context);
+  }

-  // Send graph output result.
+  // Send output result.
  for (const auto &result_arrow : output_result_arrows_) {
    MS_EXCEPTION_IF_NULL(result_arrow);
    size_t from_index = result_arrow->from_output_index_;
    const auto &front_node = data_nodes_[from_index];

    for (const auto &backend_node : front_to_backend_parameter_.at(front_node)) {
      if (AnfAlgo::GetMutableOutputAddr(backend_node.first, backend_node.second).get() ==
          input_device_tensors_[from_index]) {
@@ -115,15 +138,28 @@ void GatherActor::SendOutput(OpContext<DeviceTensor> *context) const {

void GatherActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
  MS_EXCEPTION_IF_NULL(context);

-  auto data_iter = input_op_datas_.find(context->sequential_num_);
-  if (data_iter != input_op_datas_.end()) {
+  auto data_iter = input_data_.find(context->sequential_num_);
+  if (data_iter != input_data_.end()) {
    for (auto &input_data : data_iter->second) {
-      MS_EXCEPTION_IF_NULL(input_data);
-      input_device_tensors_[input_data->index_] = input_data->data_;
+      input_device_tensors_[input_data.first] = input_data.second.top();
+      input_data.second.pop();
    }
  }

  for (const auto &device_tensor_store_key : device_tensor_store_keys_) {
    const auto &device_context = device_contexts_[device_tensor_store_key.first];
    MS_EXCEPTION_IF_NULL(device_context);
    auto device_tensor =
      DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second, device_context->GetDeviceAddressType());
    if (device_tensor == nullptr) {
      std::string error_info =
        GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->DebugString() +
        ", device type:" + std::to_string(static_cast<int>(device_context->GetDeviceAddressType()));
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    input_device_tensors_[device_tensor_store_key.first] = device_tensor;
  }

  for (size_t i = 0; i < output_data_by_output_index_.size(); ++i) {
    const auto &data = input_device_tensors_[i];
    for (auto &output_data : output_data_by_output_index_[i]) {
@@ -131,20 +167,31 @@ void GatherActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
      output_data->data_ = data;
    }
  }

+  if (need_branch_id_input_) {
+    input_branch_id_ = input_branch_ids_[context->sequential_num_];
+  }
}

bool GatherActor::CheckLaunchCondition(OpContext<DeviceTensor> *context) const {
  MS_EXCEPTION_IF_NULL(context);

  // Fetch input data.
  if (input_datas_num_ != 0) {
-    auto data_iter = input_op_datas_.find(context->sequential_num_);
-    if (data_iter == input_op_datas_.end()) {
+    auto data_iter = input_data_.find(context->sequential_num_);
+    if (data_iter == input_data_.end()) {
      return false;
    }
-    if (data_iter->second.size() != input_datas_num_) {
+    if (data_iter->second.size() + device_tensor_store_keys_.size() != input_datas_num_) {
      return false;
    }
+    if (std::any_of(data_iter->second.begin(), data_iter->second.end(),
+                    [](const auto &input_stack) { return input_stack.second.empty(); })) {
+      return false;
+    }
  }

  // Fetch input control.
  if (input_controls_num_ != 0) {
    auto control_iter = input_op_controls_.find(context->sequential_num_);
    if (control_iter == input_op_controls_.end()) {
@@ -154,19 +201,32 @@ bool GatherActor::CheckLaunchCondition(OpContext<DeviceTensor> *context) const {
      return false;
    }
  }

+  // Fetch input branch id.
+  if (need_branch_id_input_) {
+    auto branch_id_iter = input_branch_ids_.find(context->sequential_num_);
+    if (branch_id_iter == input_branch_ids_.end()) {
+      return false;
+    }
+  }
  return true;
}

void GatherActor::EraseInput(OpContext<DeviceTensor> *context) {
  MS_EXCEPTION_IF_NULL(context);
-  if (input_datas_num_ != 0) {
-    auto ret = input_op_datas_.erase(context->sequential_num_);

+  // Erase input data.
+  auto data_iter = input_data_.find(context->sequential_num_);
+  if (data_iter != input_data_.end() && std::all_of(data_iter->second.begin(), data_iter->second.end(),
+                                                    [](const auto &input_data) { return input_data.second.empty(); })) {
+    auto ret = input_data_.erase(context->sequential_num_);
    if (ret == 0) {
      std::string error_info = "Erase input data failed: " + GetAID().Name();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
  }

  // Erase input control.
  if (input_controls_num_ != 0) {
    auto ret = input_op_controls_.erase(context->sequential_num_);
    if (ret == 0) {
@@ -174,6 +234,15 @@ void GatherActor::EraseInput(OpContext<DeviceTensor> *context) {
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
  }

+  // Erase input branch id.
+  if (need_branch_id_input_) {
+    auto ret = input_branch_ids_.erase(context->sequential_num_);
+    if (ret == 0) {
+      std::string error_info = "Erase input branch id failed: " + GetAID().Name();
+      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
+    }
+  }
}
}  // namespace runtime
}  // namespace mindspore
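Two invariants keep that stacked bookkeeping consistent: launch only when every expected slot
holds an unconsumed tensor (CheckLaunchCondition), and drop a step's record only once every stack
has been drained (EraseInput), so the inputs of an outer call that has not launched yet survive
an inner call's completion. Continuing the illustrative sketch from above (same placeholder
names, not MindSpore APIs):

  // Erase the step's record only when every stack is empty; a non-empty stack
  // belongs to an outer call that is still waiting to launch.
  void EraseStep(SeqId seq) {
    auto it = pending.find(seq);
    if (it != pending.end() &&
        std::all_of(it->second.begin(), it->second.end(),
                    [](const auto &slot) { return slot.second.empty(); })) {
      pending.erase(it);
    }
  }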
@@ -21,6 +21,7 @@
#include <string>
#include <memory>
#include <unordered_map>
#include <stack>
#include <utility>
#include <algorithm>
#include "runtime/framework/device_tensor_store.h"
@@ -36,20 +37,37 @@ namespace runtime {

constexpr size_t kReturnInputPos = 1;

-// Gather actor is the entrance of sub funcgraph. Graph input is sent to it and sent to other actors by gather actor.
+// The gather actor is used in three places:
+// 1. The entrance of a sub funcgraph.
+// 2. A call node whose input 0 is a funcgraph.
+// 3. A kernel graph that has call nodes among its inputs.
+// The gather actor is used in control flow: when a subgraph is called, the real parameters need to be gathered
+// and sent to the subgraph together, and the entrance of the subgraph needs to accept that input data. In
+// recursion especially, the general inputs and call inputs of the kernel graph are used in stack mode, so they
+// need to be collected at the entrance of the kernel graph.
class GatherActor : public OpActor<DeviceTensor> {
 public:
-  GatherActor(const std::string &name, const std::vector<AnfNodePtr> &parameters, const AID loop_count_aid,
-              const AID output_aid)
-      : OpActor(name), data_nodes_(parameters), loop_count_aid_(loop_count_aid), output_aid_(output_aid) {}
+  GatherActor(const std::string &name, const std::vector<AnfNodePtr> &parameters, const bool need_branch_id_input,
+              const AID switch_aid, const AID gather_aid, const int branch_id)
+      : OpActor(name),
+        data_nodes_(parameters),
+        need_branch_id_input_(need_branch_id_input),
+        switch_aid_(switch_aid),
+        gather_aid_(gather_aid),
+        local_branch_id_(branch_id) {
+    device_contexts_.resize(parameters.size());
+  }
  ~GatherActor() override = default;

  // Get the index of the parameter; the data_node needs to be the front node.
  size_t FetchDataNodePosition(const AnfNodePtr &data_node) const;

-  // The kernel actor runs when it receives the input data.
+  // The gather actor runs when it receives the input data.
  void RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) override;
+  // The gather actor runs when it receives the input control.
+  void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) override;
+  // The gather actor runs when it receives the input branch id.
+  void CollectBranchId(const int branch_id, OpContext<DeviceTensor> *context);
  void Init() override;

 private:
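To make the new constructor concrete, here is a hypothetical construction site; every variable
and the branch id below are placeholders for illustration, not values taken from this PR:

  // Gather actor at the entrance of a called funcgraph: it waits for the
  // caller's branch id (need_branch_id_input = true), knows the AID of the
  // funcgraph's output switch actor and of the caller-side gather actor, and
  // carries the branch id assigned to the funcgraph it belongs to.
  auto gather_actor = std::make_shared<GatherActor>(
      "Gather_fg_add", fg_add_parameters, /*need_branch_id_input=*/true,
      /*switch_aid=*/fg_add_output_switch_aid, /*gather_aid=*/caller_gather_aid,
      /*branch_id=*/1);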
@@ -66,13 +84,33 @@ class GatherActor : public OpActor<DeviceTensor> {

  // The device tensors for launch.
  std::vector<DeviceTensor *> input_device_tensors_;
+  // The branch id for the current step.
+  int input_branch_id_;

-  DeviceContext *device_contexts_;
+  // Input data.
+  std::unordered_map<uuids::uuid *, std::unordered_map<size_t, std::stack<DeviceTensor *>>> input_data_;
+  // Input branch ids record the branch id received from a gather actor.
+  // In control flow, a sub funcgraph may be called in multiple places, and the output must be returned to different
+  // places. Therefore, the output of each subgraph is connected to a switch actor, and the caller sends
+  // its branch id to the gather actor of the subgraph. The branch id is then sent by the gather actor to the
+  // switch actor connected to the output.
+  std::unordered_map<uuids::uuid *, int> input_branch_ids_;

+  // Output data.
+  // Cache unique output data by output index to modify the output data effectively.
+  std::vector<std::vector<OpDataUniquePtr<DeviceTensor>>> output_data_by_output_index_;
+  // The output_data_ corresponds to the output_data_arrows_ one by one.
+  std::vector<OpData<DeviceTensor> *> output_data_;

  // Output arrows.
  std::vector<DataArrowPtr> output_result_arrows_;
+  std::vector<AID> output_branch_arrows_;

  // Parameters of the sub funcgraph, which are the front nodes.
  std::vector<AnfNodePtr> data_nodes_;
+  std::vector<DeviceContext *> device_contexts_;
  // Pair<index, anfNode> points to the dependent device tensor store; anfNode is the key of the device tensor store.
  std::vector<std::pair<size_t, AnfNode *>> device_tensor_store_keys_;

  // When the output is a parameter of the subgraph, the gather actor needs to send the anfnode to the output actor,
  // so all the nodes that may send the device tensor to the gather actor are recorded. When the anfnode needs to be sent
@@ -83,18 +121,19 @@ class GatherActor : public OpActor<DeviceTensor> {
  size_t input_datas_num_{0};
  // The dependent input controls number.
  size_t input_controls_num_{0};
+  // Whether the actor needs to accept a branch id. When the gather actor is the input of a subgraph, it needs to
+  // receive the branch id sent by the subgraph caller, and this is true in that case.
+  bool need_branch_id_input_;

-  const AID loop_count_aid_;
-  const AID output_aid_;
+  // Actor ids that the branch id needs to be sent to.
+  // When the actor corresponds to a call node, the branch id needs to be sent to the input gather actor and output
+  // switch actor of the called funcgraph. When the actor is the entrance of a funcgraph, the gather actor id is
+  // empty, and the branch id only needs to be sent to its output switch actor.
+  const AID switch_aid_;
+  const AID gather_aid_;

-  // Cache unique output data by output index to modify the output data effectively.
-  std::vector<std::vector<OpDataUniquePtr<DeviceTensor>>> output_data_by_output_index_;
-  // The output_data_ corresponds to the output_data_arrows_ one by one.
-  std::vector<OpData<DeviceTensor> *> output_data_;

-  // When the result of the graph is sent to the output actor, the gather actor of the graph needs
-  // to send branch_id to the output actor to determine the corresponding weight.
-  int branch_id_{kInvalidBranchID};
+  // The branch id corresponding to the funcgraph to which the gather actor belongs.
+  int local_branch_id_;
};

using GatherActorPtr = std::shared_ptr<GatherActor>;
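The members above implement a branch-id hand-off: each caller of a subgraph owns a branch id, the
caller's gather actor sends it with the call, the subgraph's entrance gather actor forwards it,
and the subgraph's output switch actor finally uses it to pick the output branch that leads back
to the right caller. A self-contained toy of that last lookup (plain C++ with assumed ids, not
framework code):

  #include <cstddef>
  #include <cstdio>
  #include <unordered_map>

  int main() {
    // The output switch actor of a subgraph maps each caller's branch id to one
    // of its output branches (branch_id_to_index_ in switch_actor.h).
    std::unordered_map<int, std::size_t> branch_id_to_index = {{1, 0}, {2, 1}};

    int caller_branch_id = 2;  // delivered through the gather actors
    std::printf("result flows out through branch %zu\n",
                branch_id_to_index.at(caller_branch_id));
    return 0;
  }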
@@ -77,6 +77,9 @@ void KernelActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTe
  MS_EXCEPTION_IF_NULL(context);
  auto &sequential_num = context->sequential_num_;
  input_op_datas_[sequential_num].emplace_back(input_data);
+  if (input_data->data_ == nullptr) {
+    MS_LOG(EXCEPTION) << "Input data of actor:" << GetAID() << " num:" << input_data->index_ << " is empty";
+  }
  // When all the inputs are collected, then allocate memory and callback launch.
  if (CheckLaunchCondition(context)) {
    // Infer kernel shape and update abstract info for dynamic shape kernel.
@@ -245,7 +248,7 @@ void KernelActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
      DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second, device_context_->GetDeviceAddressType());
    if (device_tensor == nullptr) {
      std::string error_info =
-        GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->fullname_with_scope() +
+        GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->DebugString() +
        ", device type:" + std::to_string(static_cast<int>(device_context_->GetDeviceAddressType()));
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
@@ -86,16 +86,6 @@ void LoopCountActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *c
  MS_EXCEPTION_IF_NULL(context);
  auto sequential_num = context->sequential_num_;
  input_op_controls_[sequential_num].emplace_back(input_control);

  if (CheckLoopCountIncreaseCondition(context)) {
    IncreaseLoopCount(context);
  }
}

-void LoopCountActor::CollectBranchId(const int branch_id, OpContext<DeviceTensor> *context) {
-  MS_EXCEPTION_IF_NULL(context);
-  branch_id_ = branch_id;
-
-  if (CheckLoopCountIncreaseCondition(context)) {
-    IncreaseLoopCount(context);
-  }
@@ -138,7 +128,6 @@ void LoopCountActor::SendOutput(OpContext<DeviceTensor> *context) {
  if (recorder_aid_ != nullptr) {
    Async(*recorder_aid_, &RecorderActor::RecordOnStepEnd, context);
  }

  SendMemoryAllocReq(context);
}
@@ -180,14 +169,8 @@ void LoopCountActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {
bool LoopCountActor::CheckLoopCountIncreaseCondition(OpContext<DeviceTensor> *context) {
  MS_EXCEPTION_IF_NULL(context);
  auto sequential_num = context->sequential_num_;
-  if (branch_id_ == kInvalidBranchID) {
-    return false;
-  }
-
-  if (branch_id_ >= SizeToInt(branch_id_to_input_controls_num_.size())) {
-    MS_LOG(ERROR) << "Branch id is invalid, id:" << branch_id_;
-  }
-  return input_op_controls_[sequential_num].size() == branch_id_to_input_controls_num_[branch_id_];
+  return input_op_controls_[sequential_num].size() == input_controls_num_;
}
}  // namespace runtime
}  // namespace mindspore
@@ -40,11 +40,10 @@ class LoopCountActor : public DebugAwareActor {
        loop_count_(loop_count),
        current_count_(0),
        total_running_count_(0),
+        input_controls_num_(0),
        memory_manager_aid_(memory_manager_aid),
        debug_aid_(debug_aid),
-        recorder_aid_(recorder_aid) {
-    branch_id_to_input_controls_num_[kMainBranchID] = 0;
-  }
+        recorder_aid_(recorder_aid) {}

  ~LoopCountActor() override = default;
@@ -63,11 +62,6 @@ class LoopCountActor : public DebugAwareActor {
  // The callback after debug finished.
  void OnDebugFinish(OpContext<DeviceTensor> *context) override;

-  // In control flow, there are multi-branch output situations. In this case, the gather actor will be numbered
-  // branch id, and the branch id will be sent to the loop count actor during operation. The interface is used
-  // to receive the branch id message.
-  void CollectBranchId(const int branch_id_, OpContext<DeviceTensor> *context);

 private:
  friend class GraphScheduler;
@@ -84,7 +78,7 @@ class LoopCountActor : public DebugAwareActor {
  // The dependent input controls number.
-  // In the multi-branch output scenario of the control flow, the control of each branch needs to be recorded
-  // separately with the branch id as the key. When the output has only one branch, the branch id is 0.
-  std::unordered_map<int, size_t> branch_id_to_input_controls_num_;
+  size_t input_controls_num_;

  // The output controls contain the data source actors and the no input kernel actors and output actor.
  std::vector<AID> data_source_aids_;
@@ -98,10 +92,6 @@ class LoopCountActor : public DebugAwareActor {
  // The id of recorder actor. Send message to it for clearing recorder info before loop count actor exits.
  const AID *recorder_aid_;

-  // When the result of the graph is sent to the output actor, the gather actor of the graph needs
-  // to send branch_id to the output actor to determine the corresponding weight.
-  int branch_id_{kMainBranchID};

  // The nodes need continuous memory, which must allocate in the begin of step running. The first bool of pair
  // expresses the inputs of node need continuous memory, the second bool of pair expresses the outputs of node need
  // continuous memory.
@@ -44,28 +44,25 @@ TensorPtr CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index,

void OutputActor::Init() {
  // Set the number of actor running dependent messages.
-  if ((!need_loop_count_) && (device_tensor_store_keys_.size() == 1)) {
-    running_dependent_msg_num_ = SizeToInt(outputs_num_ - device_tensor_store_keys_[kMainBranchID].size());
+  if ((!need_loop_count_)) {
+    running_dependent_msg_num_ = SizeToInt(outputs_num_ - device_tensor_store_keys_.size());
  }
}

void OutputActor::CollectLoopCount(size_t loop_count, OpContext<DeviceTensor> *context) {
  MS_EXCEPTION_IF_NULL(context);
-  if (branch_id_ == kInvalidBranchID) {
-    MS_LOG(EXCEPTION) << "Invalid branch id for output actor.";
-  }

  current_count_ = loop_count;
  if (loop_count_ == current_count_) {
-    if (current_outputs_num_ + device_tensor_store_keys_[branch_id_].size() != outputs_num_) {
-      std::string error_info =
-        "The outputs num is wrong, the total outputs num: " + std::to_string(outputs_num_) +
-        ", the current outputs num: " + std::to_string(current_outputs_num_) +
-        ", the device tensor store num: " + std::to_string(device_tensor_store_keys_[branch_id_].size());
+    if (current_outputs_num_ + device_tensor_store_keys_.size() != outputs_num_) {
+      std::string error_info = "The outputs num is wrong, the total outputs num: " + std::to_string(outputs_num_) +
+                               ", the current outputs num: " + std::to_string(current_outputs_num_) +
+                               ", the device tensor store num: " + std::to_string(device_tensor_store_keys_.size());
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }

    // Because device tensor store can't send data, so fetch the output result of device tensor store in running end.
-    for (const auto &device_tensor_store_key : device_tensor_store_keys_[branch_id_]) {
+    for (const auto &device_tensor_store_key : device_tensor_store_keys_) {
      if (device_tensor_store_key.first >= outputs_.size()) {
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The input index is of range.");
      }
@@ -108,16 +105,10 @@ void OutputActor::UpdateOutputDeviceAddress() {
  output_nodes_.resize(outputs_num_);
}

-void OutputActor::CollectBranchId(const int branch_id, OpContext<DeviceTensor> *context) {
-  MS_EXCEPTION_IF_NULL(context);
-  branch_id_ = branch_id;
-}

void OutputActor::CollectOutput(const AnfNodePtr &output_node, size_t output_index, size_t output_position,
                                OpContext<DeviceTensor> *context) {
  MS_EXCEPTION_IF_NULL(output_node);
  MS_EXCEPTION_IF_NULL(context);

  // Collect the output result in the last loop which is represented by "loop_count_ - current_count_ == 1".
  if (loop_count_ - current_count_ != 1) {
    return;
@@ -132,7 +123,7 @@ void OutputActor::CollectOutput(const AnfNodePtr &output_node, size_t output_ind
  // Save the output nodes to clear the device tensor in the running end.
  output_nodes_[output_position] = KernelWithIndex(output_node, output_index);
  // There is no loop count actor in step mode, so CollectLoopCount must be triggered here to replace the old
  // output device tensors.
-  if (!need_loop_count_ && (current_outputs_num_ + device_tensor_store_keys_[branch_id_].size() == outputs_num_)) {
+  if (!need_loop_count_ && (current_outputs_num_ + device_tensor_store_keys_.size() == outputs_num_)) {
    CollectLoopCount(++current_count_, context);
  }
}
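The step-completion rule the output actor applies is plain arithmetic: outputs delivered through
data arrows plus outputs served by the device tensor store must account for every graph output.
Restated standalone with assumed numbers (not taken from this PR):

  #include <cassert>
  #include <cstddef>

  bool OutputsComplete(std::size_t collected, std::size_t store_backed, std::size_t total) {
    return collected + store_backed == total;
  }

  int main() {
    // Assume 5 graph outputs of which 2 are weights/constants fetched from the
    // device tensor store at the end of the run: exactly 3 CollectOutput
    // messages are expected before the outputs are complete.
    assert(!OutputsComplete(2, 2, 5));
    assert(OutputsComplete(3, 2, 5));
    return 0;
  }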
@@ -46,12 +46,10 @@ class OutputActor : public OpActor<DeviceTensor> {
        outputs_num_(outputs_num),
        current_outputs_num_(0),
        need_loop_count_(need_loop_count),
-        branch_id_(kMainBranchID),
        running_dependent_msg_num_(1) {
    outputs_.resize(outputs_num);
    output_nodes_.resize(outputs_num);
    device_contexts_.resize(outputs_num);
-    device_tensor_store_keys_[kMainBranchID] = std::vector<std::pair<size_t, AnfNodePtr>>();
  }
  ~OutputActor() override = default;
@@ -65,8 +63,6 @@ class OutputActor : public OpActor<DeviceTensor> {
  void CollectOutput(const AnfNodePtr &output_node, size_t output_index, size_t output_position,
                     OpContext<DeviceTensor> *context);

-  void CollectBranchId(const int branch_id, OpContext<DeviceTensor> *context);

  // The graph output needs to be set to a new device address in every step or loop, to avoid the device
  // address of the tensor being rewritten in the next step or loop.
  void UpdateOutputDeviceAddress();
@@ -88,16 +84,11 @@ class OutputActor : public OpActor<DeviceTensor> {
  size_t outputs_num_;
  size_t current_outputs_num_;
  bool need_loop_count_;
-  int branch_id_;

  // The dependent messages number of actor running.
  int running_dependent_msg_num_;

-  // Pair<branch_id, <index, node>> points to the dependent device tensor store, branch_id is the output branch id.
-  // In general, the branch id is 0, which means there is only one output branch in the actor set. When there are
-  // multiple possible output branches in the actor set, different branch ids correspond to their own related nodes.
  // The index is the position of node in the output, node is the key of the device tensor store.
-  std::unordered_map<size_t, std::vector<std::pair<size_t, AnfNodePtr>>> device_tensor_store_keys_;
+  std::vector<std::pair<size_t, AnfNodePtr>> device_tensor_store_keys_;
};

using OutputActorPtr = std::shared_ptr<OutputActor>;
@@ -16,6 +16,7 @@

#include "runtime/framework/actor/switch_actor.h"
#include "runtime/framework/actor/output_actor.h"
#include "runtime/framework/actor/gather_actor.h"
#include "runtime/framework/actor/memory_manager_actor.h"
#include "mindrt/include/async/async.h"
#include "abstract/utils.h"
@@ -39,10 +40,10 @@ void SwitchActor::Init() {

void SwitchActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) {
  MS_EXCEPTION_IF_NULL(context);
-  auto sequential_num = context->sequential_num_;
-  input_op_datas_[sequential_num].emplace_back(input_data);
+  const auto &sequential_num = context->sequential_num_;
+  auto &input_datas = input_data_[sequential_num];
+  input_datas[input_data->index_].push(input_data->data_);

  // When all the inputs are collected, then allocate memory and callback launch.
  if (CheckLaunchCondition(context)) {
    FetchInputDeviceTensor(context);
    EraseInput(context);
@@ -50,14 +51,38 @@ void SwitchActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTe
  }
}

-void SwitchActor::Initialize() {
+void SwitchActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) {
+  MS_EXCEPTION_IF_NULL(context);
+  auto &sequential_num = context->sequential_num_;
+  if (input_controls_[sequential_num].find(input_control) == input_controls_[sequential_num].end()) {
+    input_controls_[sequential_num][input_control] = 0;
+  }
+  input_controls_[sequential_num][input_control]++;
+
+  if (CheckLaunchCondition(context)) {
+    FetchInputDeviceTensor(context);
+    EraseInput(context);
+    SendOutput(context);
+  }
+}
+
+void SwitchActor::CollectBranchId(const int branch_id, OpContext<DeviceTensor> *context) {
+  MS_EXCEPTION_IF_NULL(context);
+  auto &sequential_num = context->sequential_num_;
+  input_branch_ids_[sequential_num].push(branch_id);
+}
+
+void SwitchActor::Initialize(const ControlNodeParserPtr &parser) {
  std::vector<AnfNodePtr> inputs = node_->inputs();

  if (IsPrimitive(inputs[0], prim::kPrimSwitch)) {
    InitSwitch();
+  } else if (IsPrimitive(inputs[0], prim::kPrimReturn)) {
+    InitReturn(parser);
  } else {
    InitSwitchLayer();
  }
+  backend_parameters_.resize(input_nodes_.size());
}

void SwitchActor::InitPartial(const AnfNodePtr &node, const size_t branch_id) {
@@ -88,6 +113,23 @@ void SwitchActor::InitPartial(const AnfNodePtr &node, const size_t branch_id) {
  }
}

+void SwitchActor::InitVectorSize(const size_t num) {
+  branch_inputs_pos_.resize(num);
+  branch_func_graph_.resize(num);
+  output_branch_arrows_.resize(num);
+  output_branch_result_arrows_.resize(num);
+  output_branch_control_arrows_.resize(num);
+  output_branch_branch_arrows_.resize(num);
+}
+
+void SwitchActor::InitReturn(const ControlNodeParserPtr &parser) {
+  const auto &func_graph = node_->func_graph();
+  MS_EXCEPTION_IF_NULL(func_graph);
+  const auto &call_num = parser->GetCallNumByFuncGraph(func_graph);
+  InitVectorSize(call_num);
+  AddCommonInput(func_graph->output());
+}
+
void SwitchActor::InitSwitch() {
  // The inputs of the switch node:
  // [0] ValueNode<Primitive> kSwitch.
@@ -99,14 +141,10 @@ void SwitchActor::InitSwitch() {
    MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitch->name() << " is not equal 4";
  }

-  branch_total_inputs_.resize(kSwitchPartialNum);
-  branch_inputs_pos_.resize(kSwitchPartialNum);
-  branch_func_graph_.resize(kSwitchPartialNum);
-  output_branch_arrows_.resize(kSwitchPartialNum);
-  output_branch_result_arrows_.resize(kSwitchPartialNum);
-  output_branch_control_arrows_.resize(kSwitchPartialNum);
+  InitVectorSize(kSwitchPartialNum);

-  input_nodes_.push_back(inputs[kSwitchCondPos]);
+  const auto cond_node = AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchCondPos], 0);
+  input_nodes_.push_back(cond_node);
  input_datas_num_++;
  // Init the two branches of switch node.
  InitPartial(inputs[kSwitchFalseBranchPos], static_cast<size_t>(false));
@@ -123,16 +161,13 @@ void SwitchActor::InitSwitchLayer() {
    MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitchLayer->name() << " is not equal 3";
  }

-  input_nodes_.push_back(inputs[kSwitchLayerCondPos]);
+  const auto cond_node = AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchLayerCondPos], 0);
+  input_nodes_.push_back(cond_node);
  input_datas_num_++;

  // The second input of SwitchLayer is maketuple node, which includes all branches.
  auto branch_nodes = inputs[kSwitchLayerBranchPos]->cast<CNodePtr>()->inputs();
-  branch_total_inputs_.resize(branch_nodes.size() - 1);
-  branch_inputs_pos_.resize(branch_nodes.size() - 1);
-  branch_func_graph_.resize(branch_nodes.size() - 1);
-  output_branch_arrows_.resize(branch_nodes.size() - 1);
-  output_branch_result_arrows_.resize(branch_nodes.size() - 1);
-  output_branch_control_arrows_.resize(branch_nodes.size() - 1);
+  InitVectorSize(branch_nodes.size() - 1);

  // Parse all branches.
  for (size_t i = 1; i < branch_nodes.size(); ++i) {
@@ -151,41 +186,92 @@ void SwitchActor::AddCommonInput(const AnfNodePtr &node) {
  }
}

size_t SwitchActor::FetchDataNodePosition(const AnfNodePtr &data_node) const {
-  const auto &iter = find(input_nodes_.begin(), input_nodes_.end(), data_node);
+  const auto data_node_with_index = AnfAlgo::VisitKernelWithReturnType(data_node, 0);
+  const auto &iter = find(input_nodes_.begin(), input_nodes_.end(), data_node_with_index);
  if (iter == input_nodes_.end()) {
-    MS_LOG(EXCEPTION) << "Data node: " << data_node->fullname_with_scope()
-                      << " is not exist in gather actor:" << GetAID();
+    MS_LOG(EXCEPTION) << "Data node: " << AnfAlgo::GetNodeDebugString(data_node)
+                      << " is not exist in switch actor:" << GetAID();
  }
  return iter - input_nodes_.begin();
}

-void SwitchActor::AddInput(const AnfNodePtr &node, const size_t branch) {
-  branch_total_inputs_[branch].push_back(node);
+void SwitchActor::AddInput(const KernelWithIndex node_with_index, const size_t branch) {
+  const auto &node = node_with_index.first;

-  if (node->isa<ValueNode>() && (!HasAbstractMonad(node))) {
+  // Add weight and value node.
+  if ((AnfAlgo::CheckPrimitiveType(node_, prim::kPrimReturn) && HasAbstractRef(node)) || node->isa<ValueNode>()) {
+    const auto iter = find(input_nodes_.begin(), input_nodes_.end(), node_with_index);
+    if (iter != input_nodes_.end()) {
+      branch_inputs_pos_[branch].push_back(iter - input_nodes_.begin());
+      return;
+    }
    device_tensor_store_keys_.push_back({input_nodes_.size(), node.get()});
    branch_inputs_pos_[branch].push_back(input_nodes_.size());
-    input_nodes_.push_back(node);
+    input_nodes_.push_back(node_with_index);
    return;
  }

-  // Switch actor only receives parameter, updatestate node output is U, need to be skipped.
-  if (IsPersistentDeviceTensor(node) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) {
+  // Output of updatestate node is U, need to be skipped.
+  if (HasAbstractRef(node)) {
    return;
  }

-  auto iter = find(input_nodes_.begin(), input_nodes_.end(), node);
+  // Add parameter.
+  auto iter = find(input_nodes_.begin(), input_nodes_.end(), node_with_index);
  if (iter == input_nodes_.end()) {
    branch_inputs_pos_[branch].push_back(input_nodes_.size());
-    input_nodes_.push_back(node);
+    input_nodes_.push_back(node_with_index);
    ++input_datas_num_;
  } else {
    branch_inputs_pos_[branch].push_back(iter - input_nodes_.begin());
  }
}

-size_t SwitchActor::GetIndex() {
+void SwitchActor::AddInput(const AnfNodePtr &node, const size_t branch) {
+  if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState) || HasAbstractMonad(node)) {
+    return;
+  }
+
+  const auto &real_input = AnfAlgo::VisitKernelWithReturnType(node, 0);
+
+  if (AnfAlgo::CheckPrimitiveType(real_input.first, prim::kPrimMakeTuple)) {
+    const auto &inputs = real_input.first->cast<CNodePtr>()->inputs();
+    for (size_t i = kMakeTupleInputStartPos; i < inputs.size(); ++i) {
+      AddInput(inputs[i], branch);
+    }
+  } else if (IsCallNode(real_input.first)) {
+    std::vector<AnfNodePtr> call_nodes;
+    const auto call_output_num = FetchOutputSizebyCallNode(real_input.first, &call_nodes);
+    if (call_output_num <= 0) {
+      MS_LOG(EXCEPTION) << "Invalid output num for call input:" << AnfAlgo::GetNodeDebugString(real_input.first);
+    }
+    for (size_t i = 0; i < call_output_num; ++i) {
+      AddInput({real_input.first, i}, branch);
+    }
+  } else {
+    AddInput(real_input, branch);
+  }
+}
+
+size_t SwitchActor::GetIndex(OpContext<DeviceTensor> *context) {
+  if (need_branch_id_input_) {
+    if (input_branch_ids_.find(context->sequential_num_) == input_branch_ids_.end() ||
+        input_branch_ids_[context->sequential_num_].empty()) {
+      MS_LOG(EXCEPTION) << "Invalid branch id for actor:" << GetAID();
+    }
+    size_t branch_id = input_branch_ids_[context->sequential_num_].top();
+    input_branch_ids_[context->sequential_num_].pop();
+    if (branch_id_to_index_.find(branch_id) == branch_id_to_index_.end()) {
+      MS_LOG(EXCEPTION) << "Invalid branch id for switch actor:" << GetAID() << " branch id:" << branch_id;
+    }
+    return branch_id_to_index_[branch_id];
+  }

  DeviceTensor *device_tensor = input_device_tensors_[0];
  if (device_tensor == nullptr) {
    MS_LOG(EXCEPTION) << "Index of switch actor is empty:" << GetAID();
  }
  auto inputs = node_->inputs();
  TypeId type_id = AnfAlgo::GetOutputInferDataType(inputs[kSwitchCondPos], 0);
  size_t size = abstract::TypeIdSize(type_id);
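The recursive AddInput overload above flattens structured inputs into scalar input slots: a
make_tuple contributes one slot per element, and a call node contributes one slot per output of
the called funcgraph, which is why input_nodes_ now stores KernelWithIndex rather than bare
nodes. A toy restatement of that flattening (illustrative types, not MindSpore code):

  #include <cstddef>
  #include <utility>
  #include <vector>

  struct Node {
    std::vector<Node *> tuple_elements;  // non-empty for a make_tuple-like node
    std::size_t call_outputs = 0;        // non-zero for a call-like node
  };

  // Collect (node, output-index) slots the way AddInput does: recurse into
  // tuples, expand a call into one slot per output, otherwise take the node.
  void Flatten(Node *n, std::vector<std::pair<Node *, std::size_t>> *slots) {
    if (!n->tuple_elements.empty()) {
      for (auto *element : n->tuple_elements) {
        Flatten(element, slots);
      }
    } else if (n->call_outputs > 0) {
      for (std::size_t i = 0; i < n->call_outputs; ++i) {
        slots->push_back({n, i});
      }
    } else {
      slots->push_back({n, 0});
    }
  }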
@@ -219,28 +305,46 @@ size_t SwitchActor::GetIndex() {

bool SwitchActor::CheckLaunchCondition(OpContext<DeviceTensor> *context) const {
  MS_EXCEPTION_IF_NULL(context);
  if (input_datas_num_ != 0) {
-    auto data_iter = input_op_datas_.find(context->sequential_num_);
-    if (data_iter == input_op_datas_.end()) {
+    auto data_iter = input_data_.find(context->sequential_num_);
+    if (data_iter == input_data_.end()) {
      return false;
    }
    if (data_iter->second.size() != input_datas_num_) {
      return false;
    }
+    if (std::any_of(data_iter->second.begin(), data_iter->second.end(),
+                    [](const auto &input_stack) { return input_stack.second.empty(); })) {
+      return false;
+    }
  }

  if (input_controls_num_ != 0) {
-    auto data_iter = input_op_controls_.find(context->sequential_num_);
-    if (data_iter == input_op_controls_.end()) {
+    auto data_iter = input_controls_.find(context->sequential_num_);
+    if (data_iter == input_controls_.end()) {
      return false;
    }
    if (data_iter->second.size() != input_controls_num_) {
      return false;
    }
+    if (std::any_of(data_iter->second.begin(), data_iter->second.end(),
+                    [](const auto &input_stack) { return input_stack.second == 0; })) {
+      return false;
+    }
  }

  return true;
}

void SwitchActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
  MS_EXCEPTION_IF_NULL(context);
  input_device_tensors_.resize(input_nodes_.size());
-  auto data_iter = input_op_datas_.find(context->sequential_num_);
-  if (data_iter != input_op_datas_.end()) {
+  auto data_iter = input_data_.find(context->sequential_num_);
+  if (data_iter != input_data_.end()) {
    for (auto &input_data : data_iter->second) {
-      MS_EXCEPTION_IF_NULL(input_data);
-      input_device_tensors_[input_data->index_] = input_data->data_;
+      input_device_tensors_[input_data.first] = input_data.second.top();
+      input_data.second.pop();
    }
  }
-  data_iter->second.clear();

  for (const auto &device_tensor_store_key : device_tensor_store_keys_) {
    auto device_tensor =
@@ -253,15 +357,28 @@ void SwitchActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
    }
    input_device_tensors_[device_tensor_store_key.first] = device_tensor;
  }

+  auto control_iter = input_controls_.find(context->sequential_num_);
+  if (control_iter != input_controls_.end()) {
+    for_each(control_iter->second.begin(), control_iter->second.end(),
+             [](auto &input_control) { input_control.second--; });
+  }
}

void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
  MS_EXCEPTION_IF_NULL(context);
-  auto index = GetIndex();
+  auto index = GetIndex(context);
  if (index >= output_branch_arrows_.size()) {
    MS_LOG(EXCEPTION) << "Switch actor invalid index:" << index;
  }

+  if (local_branch_id_ >= 0) {
+    const auto &branch_arrows = output_branch_branch_arrows_[index];
+    for (const auto &branch_arrow : branch_arrows) {
+      Async(branch_arrow, &GatherActor::CollectBranchId, local_branch_id_, context);
+    }
+  }

  auto &output_branch_arrow = output_branch_arrows_[index];
  auto &output_data = output_data_[index];
  for (size_t i = 0; i < output_branch_arrow.size(); ++i) {
@@ -270,6 +387,7 @@ void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
    MS_EXCEPTION_IF_NULL(data_arrow);
    MS_EXCEPTION_IF_NULL(data);
    data->data_ = input_device_tensors_[data_arrow->from_output_index_];

    Async(data_arrow->to_op_id_, &OpActor::RunOpData, data.get(), context);
  }
@@ -279,7 +397,7 @@ void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
    auto &result_arrow = output_branch_result_arrow[i];
    MS_EXCEPTION_IF_NULL(result_arrow);
    size_t from_index = result_arrow->from_output_index_;
-    for (const auto &backend_node : front_to_backend_parameter_[from_index]) {
+    for (const auto &backend_node : backend_parameters_[from_index]) {
      if (AnfAlgo::GetMutableOutputAddr(backend_node.first, backend_node.second).get() ==
          input_device_tensors_[from_index]) {
        Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, backend_node.first, backend_node.second,
@@ -298,8 +416,10 @@ void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {

void SwitchActor::EraseInput(OpContext<DeviceTensor> *context) {
  MS_EXCEPTION_IF_NULL(context);
-  if (input_datas_num_ != 0) {
-    auto ret = input_op_datas_.erase(context->sequential_num_);
+  auto data_iter = input_data_.find(context->sequential_num_);
+  if (data_iter != input_data_.end() && std::all_of(data_iter->second.begin(), data_iter->second.end(),
+                                                    [](const auto &input_data) { return input_data.second.empty(); })) {
+    auto ret = input_data_.erase(context->sequential_num_);
    if (ret == 0) {
      std::string error_info = "Erase input data failed: " + GetAID().Name();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
@@ -307,10 +427,15 @@ void SwitchActor::EraseInput(OpContext<DeviceTensor> *context) {
    }
  }

-  if (input_controls_num_ != 0) {
-    auto ret = input_op_controls_.erase(context->sequential_num_);
-    if (ret == 0) {
-      std::string error_info = "Erase input controls failed: " + GetAID().Name();
-      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
+  auto control_iter = input_controls_.find(context->sequential_num_);
+  if (control_iter != input_controls_.end() &&
+      std::all_of(control_iter->second.begin(), control_iter->second.end(),
+                  [](const auto &input_control) { return input_control.second == 0; })) {
+    auto ret = input_controls_.erase(context->sequential_num_);
+    if (ret == 0) {
+      std::string error_info = "Erase input control failed: " + GetAID().Name();
+      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
+    }
  }
}
@@ -319,37 +444,20 @@ void SwitchActor::SendMemoryFreeReq(OpContext<DeviceTensor> *context) {
  Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &input_device_tensors_, device_context_, context);
}

-void SwitchActor::FetchInputNode(const std::vector<AnfNodePtr> &origin_parameters_order,
-                                 const FrontToBackendNodeWithContext &front_to_backend_parameters,
-                                 const std::unordered_map<AnfNodePtr, AnfNodePtr> &front_to_backend_kernel) {
-  front_to_backend_parameter_.resize(input_nodes_.size());
-
+void SwitchActor::FetchInputNode(const ControlNodeParserPtr &parser) {
  for (size_t i = 0; i < input_nodes_.size(); ++i) {
-    const auto &input_node = input_nodes_[i];
-    if (input_node->isa<ValueNode>()) {
-      front_to_backend_parameter_[i].push_back({input_node, 0});
-    } else if (input_node->isa<Parameter>()) {
-      if (front_to_backend_parameters.find(input_node) != front_to_backend_parameters.end()) {
-        const auto backend_node = front_to_backend_parameters.at(input_node).first;
-        front_to_backend_parameter_[i].push_back({backend_node, 0});
-      }
-    } else if (input_node->isa<CNode>()) {
-      if (IsCallNode(input_node)) {
-        const auto func_graphs = FetchFuncGraphbyCallNode(input_node->cast<CNodePtr>());
-        for (const auto func_graph : func_graphs) {
-          if (func_graph->output()->isa<ValueNode>()) {
-            front_to_backend_parameter_[i].push_back({func_graph->output(), 0});
-          }
-        }
-      } else {
-        const auto &kernel_with_index = AnfAlgo::VisitKernelWithReturnType(input_node, 0);
-        if (front_to_backend_kernel.find(input_node) != front_to_backend_kernel.end()) {
-          front_to_backend_parameter_[i].emplace_back(kernel_with_index);
-        }
-      }
-    }
+    const auto &input_node = input_nodes_[i].first;
+    if (!HasAbstractRef(input_node)) {
+      backend_parameters_[i] = parser->FetchBackendInputNodeByFrontNode(input_node);
+      continue;
+    }
+
+    const auto &backend_weight = parser->FetchBackendNodebyWeightNode(input_node);
+    if (backend_weight == nullptr) {
+      MS_LOG(EXCEPTION) << "Cannot find backend node for weight node:" << AnfAlgo::GetNodeDebugString(input_node);
+    }
+    backend_parameters_[i].push_back({backend_weight, 0});
  }
}

}  // namespace runtime
}  // namespace mindspore
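RunOpControl in this file keys controls by sender and counts arrivals instead of appending to a
flat list: under recursion the same upstream actor can legitimately fire several times in one
step, and a set or vector would conflate those firings. A standalone restatement of the counting
scheme (illustrative types, not framework code):

  #include <unordered_map>

  using Aid = const void *;  // stands in for AID*

  // Per step: control messages each sender has delivered but the switch actor
  // has not yet consumed.
  std::unordered_map<Aid, int> control_counts;

  void OnControl(Aid from) { ++control_counts[from]; }

  // A launch consumes exactly one firing from every sender...
  void ConsumeOneRound() {
    for (auto &entry : control_counts) {
      --entry.second;
    }
  }

  // ...and the step's bookkeeping can be erased only when nothing is left.
  bool Drained() {
    for (const auto &entry : control_counts) {
      if (entry.second != 0) {
        return false;
      }
    }
    return true;
  }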
@@ -58,16 +58,25 @@ constexpr size_t kMakeTupleInputStartPos = 1;
// 5. Free Memory
class SwitchActor : public SwitchActorBase<DeviceTensor> {
 public:
-  SwitchActor(const std::string &name, DeviceContext *device_context, const CNodePtr &node)
-      : SwitchActorBase(name), device_context_(device_context), node_(node) {}
+  SwitchActor(const std::string &name, DeviceContext *device_context, const CNodePtr &node, const int branch_id,
+              const bool need_branch_id_input)
+      : SwitchActorBase(name),
+        device_context_(device_context),
+        node_(node),
+        local_branch_id_(branch_id),
+        need_branch_id_input_(need_branch_id_input) {}
  ~SwitchActor() override = default;

  void Init() override;

  // The switch actor runs when it receives the input data.
  void RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context);
+  // The switch actor runs when it receives the input control.
+  void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context);
+  // The switch actor runs when it receives the input branch id.
+  void CollectBranchId(const int branch_id, OpContext<DeviceTensor> *context);
  // Initialize the input and output information of the switch actor according to node_.
-  void Initialize();
+  void Initialize(const ControlNodeParserPtr &parser);
  // Add input for all branches.
  void AddCommonInput(const AnfNodePtr &node);
  // Fetch the input position of the data node.
@@ -79,11 +88,16 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
  void InitPartial(const AnfNodePtr &node, const size_t branch_id);
  void InitSwitch();
  void InitSwitchLayer();
+  // In control flow, the output of each subgraph is connected to a switch actor, and the switch actor is
+  // initialized with the return node of the subgraph.
+  void InitReturn(const ControlNodeParserPtr &parser);
+  // Initialize the size of the vector members.
+  void InitVectorSize(const size_t num);
  // Get index from DeviceTensor.
-  size_t GetIndex();
+  size_t GetIndex(OpContext<DeviceTensor> *context);
  // Add input for the branch.
  void AddInput(const AnfNodePtr &node, size_t branch);
+  void AddInput(const KernelWithIndex node_with_index, const size_t branch);

  // Check whether the condition to send outputs is satisfied.
  bool CheckLaunchCondition(OpContext<DeviceTensor> *context) const;
@@ -95,12 +109,10 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
  void SendMemoryFreeReq(OpContext<DeviceTensor> *context);

  // Collect all the backend inputs of the switch actor.
-  void FetchInputNode(const std::vector<AnfNodePtr> &origin_parameters_order,
-                      const FrontToBackendNodeWithContext &front_to_backend_parameters,
-                      const std::unordered_map<AnfNodePtr, AnfNodePtr> &front_to_backend_kernel);
-  // All inputs of the switch actor, excluding weight and tensor.
+  void FetchInputNode(const ControlNodeParserPtr &parser);
+  // All inputs of the switch actor, including weights and tensors.
  // Used to receive input data, the first input is the condition of switch.
-  std::vector<AnfNodePtr> input_nodes_;
+  std::vector<KernelWithIndex> input_nodes_;
  // The position of the branch output in the input_nodes_.
  std::vector<std::vector<size_t>> branch_inputs_pos_;
@@ -126,9 +138,9 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {

  // When the output is a value node from switch actor, the actor needs to send the anfnode to the output actor,
  // so all the nodes that may send the device tensor to switch actor are recorded.
-  std::vector<std::vector<KernelWithIndex>> front_to_backend_parameter_;
+  std::vector<std::vector<KernelWithIndex>> backend_parameters_;
  std::vector<std::vector<AnfNodePtr>> branch_total_inputs_;

  std::vector<FuncGraphPtr> branch_func_graph_;

  std::unordered_map<int, size_t> branch_id_to_index_;
@@ -148,8 +160,12 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
  // The dependent input controls number.
  size_t input_controls_num_{0};
  CNodePtr node_;

+  // The branch id corresponding to the funcgraph to which the switch actor belongs.
+  int local_branch_id_;
  size_t input_branch_id_num_;
+  // Whether it needs to accept the branch id. When the switch actor is the output of a subgraph, it needs to
+  // receive the branch id sent by the gather actor of the subgraph, and this is true in that case.
+  bool need_branch_id_input_;

  // The output_data_ corresponds to the output_data_arrows_ one by one.
  std::vector<std::vector<OpDataUniquePtr<DeviceTensor>>> output_data_;
@@ -547,6 +547,8 @@ FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node) {

void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, const std::vector<KernelGraphPtr> &graphs,
                              const std::vector<DeviceContext *> &device_contexts, const FuncGraphPtr &root_graph) {
+  root_func_graph_ = root_graph;
+  root_graph_parameters_ = root_graph->parameters();

  CreateBranchIDForFuncGraph(control_nodes);
@@ -598,6 +600,27 @@ bool ControlNodeParser::IsCallInputKernelGraph(const KernelGraphPtr &graph) {
  return true;
}

+bool ControlNodeParser::IsKernelInRootFuncGraph(const AnfNodePtr &kernel) {
+  if (kernel == nullptr) {
+    return true;
+  }
+
+  const auto &graph = kernel->func_graph();
+  if (graph != nullptr) {
+    const auto &kernel_graph = dynamic_cast<KernelGraph *>(graph.get());
+    if (kernel_graph == nullptr) {
+      return true;
+    }
+
+    const auto func_graph = kernel_graph->GetFuncGraph();
+    if (func_graph != nullptr && func_graph != root_func_graph_) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
size_t ControlNodeParser::GetCallNumByFuncGraph(const FuncGraphPtr &func_graph) {
  if (func_graph_to_call_num_.find(func_graph) == func_graph_to_call_num_.end()) {
    MS_LOG(EXCEPTION) << "Invalid funcgraph:" << func_graph->ToString();
@@ -622,6 +645,21 @@ DeviceContext *ControlNodeParser::GetFrontValueNodeDeviceContext(const AnfNodePt
  return nullptr;
}

+AnfNodePtr ControlNodeParser::FetchBackendNodebyWeightNode(const AnfNodePtr &node) {
+  for (const auto &host_parameter_to_weight : host_parameter_to_weights_) {
+    for (const auto &front_weight : host_parameter_to_weight.second) {
+      if (front_weight == node) {
+        const auto &iter = front_to_backend_parameters_.find(front_weight);
+        if (iter != front_to_backend_parameters_.end()) {
+          return iter->second.first;
+        }
+      }
+    }
+  }
+
+  return nullptr;
+}
+
void ControlNodeParser::FetchValueNodeBySwitchNode(const AnfNodePtr &switch_node,
                                                   std::vector<AnfNodePtr> *value_nodes) {
  const auto &cnode = switch_node->cast<CNodePtr>();
@@ -928,6 +966,40 @@ std::vector<AnfNodePtr> FetchInputParameterbyControlNode(const AnfNodePtr &node,
  return parameters;
}

+std::vector<AnfNodePtr> FetchParameterbyKernelGraph(const KernelGraphPtr &graph) {
+  std::vector<AnfNodePtr> parameters;
+  const auto &graph_parameters = graph->input_nodes();
+
+  for (const auto &graph_parameter : graph_parameters) {
+    const auto &front_node = graph->GetFrontAnfByBackendAnf(graph_parameter);
+    if (front_node != nullptr) {
+      parameters.emplace_back(front_node);
+      continue;
+    }
+    const auto &front_node_with_index = graph->GetFrontNodeByInternalParameter(graph_parameter);
+    if (front_node_with_index.first == nullptr) {
+      MS_LOG(WARNING) << "Invalid parameter of kernel graph, parameter :"
+                      << AnfAlgo::GetNodeDebugString(graph_parameter);
+      continue;
+    }
+
+    if (HasAbstractRef(AnfAlgo::VisitKernelWithReturnType(front_node_with_index.first, 0).first) ||
+        HasAbstractMonad(front_node_with_index.first)) {
+      continue;
+    }
+
+    if (AnfAlgo::CheckPrimitiveType(front_node_with_index.first, prim::kPrimMakeTuple)) {
+      const auto &sub_parameters = FetchInputsByMakeTuple(front_node_with_index.first);
+      parameters.insert(parameters.end(), sub_parameters.begin(), sub_parameters.end());
+      continue;
+    }
+
+    parameters.emplace_back(front_node_with_index.first);
+  }
+
+  return parameters;
+}
+
void ControlNodeParser::FetchFrontToBackendParameter(const std::vector<KernelGraphPtr> &graphs,
                                                     const std::vector<DeviceContext *> &device_contexts,
                                                     const std::vector<AnfNodePtr> &control_nodes) {
@@ -939,7 +1011,7 @@ void ControlNodeParser::FetchFrontToBackendParameter(const std::vector<KernelGra
  for (size_t i = 0; i < graphs.size(); ++i) {
    const auto &graph = graphs[i];
    auto device_context = device_contexts[i];
-    for (const auto &parameter : graph->parameters()) {
+    for (const auto &parameter : graph->input_nodes()) {
      auto front_node = graph->GetFrontAnfByBackendAnf(parameter);

      if (front_node != nullptr && front_node->isa<Parameter>() &&
@@ -71,8 +71,8 @@ FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node);
// Find all funcgraphs that the call node will call.
std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const AnfNodePtr &node);

// Recursive interface to get all inputs of a make tuple node.
std::vector<AnfNodePtr> FetchInputsByMakeTuple(const AnfNodePtr &node);
+// Get the parameters of a kernel graph.
+std::vector<AnfNodePtr> FetchParameterbyKernelGraph(const KernelGraphPtr &graph);

// ControlNodeParser is used to parse control nodes, and get the edges between nodes.
class ControlNodeParser {
@ -107,6 +107,15 @@ class ControlNodeParser {
  // Check whether there is a call node in the front input nodes of the kernel graph.
  bool IsCallInputKernelGraph(const KernelGraphPtr &graph);

  // Check whether the kernel actor belongs to the root graph.
  // In general, all nodes without outputs belong to the root funcgraph, and the corresponding output switch actor
  // should be empty. In control flow, the control arrow of a no-output node in a sub funcgraph should be sent to
  // the output switch actor.
  bool IsKernelInRootFuncGraph(const AnfNodePtr &kernel);

  // Get the backend node corresponding to the weight node in the subgraph.
  AnfNodePtr FetchBackendNodebyWeightNode(const AnfNodePtr &node);

 private:
  friend class GraphScheduler;

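A toy illustration of the routing rule described in that comment. The types are stand-ins, and sending the root graph's no-output control arrow to the loop-count actor is an assumption drawn from the surrounding scheduler code, not a confirmed detail:

struct AID {};

struct Kernel {
  bool in_root_func_graph = true;  // what IsKernelInRootFuncGraph would report
};

// Root-graph kernels report to the loop-count actor (assumed); kernels in a
// sub funcgraph send their control arrow to the output switch actor instead.
AID RouteNoOutputControlArrow(const Kernel &kernel, const AID &loop_count_aid, const AID &switch_aid) {
  return kernel.in_root_func_graph ? loop_count_aid : switch_aid;
}
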
@ -195,10 +204,12 @@ class ControlNodeParser {
  std::vector<AnfNodePtr> control_node_parameters_;
  // The number of calls to each funcgraph.
  std::unordered_map<FuncGraphPtr, size_t> func_graph_to_call_num_;
  // The kernel graph of call exists in the front-end input node.
  std::unordered_map<KernelGraphPtr, DeviceContext *> call_input_kernel_graphs_;
  // Kernel graphs that have a call node in their front input nodes.
  // In the funcgraph recursive-call scene, general inputs and call inputs are passed recursively, so a gather
  // actor is created for every kernel graph that has a call input.
  std::unordered_map<KernelGraphPtr, DeviceContext *> call_input_kernel_graphs_;
  // Root funcgraph and its parameters.
  FuncGraphPtr root_func_graph_;
  std::vector<AnfNodePtr> root_graph_parameters_;
};

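For func_graph_to_call_num_, a hedged sketch of how the counter might be filled: each call node increments the count of every funcgraph it can reach. FuncGraph and the per-call target lists are stand-ins for the real parser structures:

#include <cstddef>
#include <unordered_map>
#include <vector>

struct FuncGraph {};

std::unordered_map<FuncGraph *, size_t> CountCalls(
    const std::vector<std::vector<FuncGraph *>> &call_targets) {
  std::unordered_map<FuncGraph *, size_t> call_num;
  for (const auto &targets : call_targets) {  // one entry per call node
    for (FuncGraph *fg : targets) {           // a call may reach several funcgraphs
      ++call_num[fg];                         // operator[] default-initializes to 0
    }
  }
  return call_num;
}
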
File diff suppressed because it is too large
@ -46,7 +46,7 @@ using mindspore::session::KernelWithIndex;
// Position of a kernel with index. The value pair<branch_id, vector<pos>> means the branch id of the kernel and the
// positions of the kernel. Generally there is only one branch, and the branch id is 0 at this time. In control flow
// there are multi-branch scenarios, and pos represents the position of the kernel in its branch.
using KernelMapPosition = std::map<KernelWithIndex, std::pair<int, std::vector<size_t>>, session::KernelWithIndexCmp>;
using KernelMapPosition = std::map<KernelWithIndex, std::vector<size_t>, session::KernelWithIndexCmp>;
using ActorInfo = std::string;

// The second element of the pair represents the output index of the op actor corresponding to the graph output node.

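The typedef change drops the branch id from the value, leaving only the position list. A small sketch contrasting the two shapes and the populate-or-append pattern used later in ConstructGraphCompilerInfo (keys simplified to int instead of KernelWithIndex):

#include <cstddef>
#include <map>
#include <utility>
#include <vector>

using OldPosition = std::map<int, std::pair<int, std::vector<size_t>>>;  // {branch_id, positions}
using NewPosition = std::map<int, std::vector<size_t>>;                  // positions only

int main() {
  NewPosition outputs_order;
  size_t position = 0;
  // A repeated output key keeps one entry and appends further positions.
  for (int key : {42, 7, 42}) {
    if (outputs_order.count(key) == 0) {
      outputs_order[key] = {position++};
    } else {
      outputs_order[key].emplace_back(position++);
    }
  }
  return 0;  // outputs_order: {7: [1], 42: [0, 2]}
}
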
@ -212,7 +212,8 @@ class GraphScheduler {
                                KernelWithIndex to_kernel_with_input_idx);

  // 2. The processing of linking control arrows.
  void LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set);
  void LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set,
                                         const ControlNodeParserPtr &parser);
  void LinkControlArrowByAutoMonad(KernelActor *to_actor, const AnfNodePtr &from_node);
  // The skipped node doesn't run, so the control arrow needs to be linked between the inputs and users of the
  // skipped node.
  void LinkControlArrowBySkippedNode(KernelActor *to_actor, const AnfNodePtr &skipped_node);

@ -227,20 +228,24 @@ class GraphScheduler {

  // 4. The processing of control flow linking.
  void LinkArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, ActorSet *actor_set);
  void LinkDataArrowForGatherActor(GatherActor *from_actor, KernelActor *to_actor,
                                   KernelWithIndex from_kernel_with_output_idx,
                                   KernelWithIndex to_kernel_with_input_idx);
  void LinkDataArrowForGatherActor(GatherActor *from_actor, const AnfNodePtr &front_node, KernelActor *to_actor,
                                   const size_t to_index);
  void LinkDataArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info, SwitchActor *actor);
  // Connect the input of the actor.
  void LinkDataArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, const AnfNodePtr &input_node,
                                  OpActor<DeviceTensor> *to_actor, const size_t to_index);
  void LinkDataArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, const KernelWithIndex &input_node,
                                  const FuncGraphPtr &from_func_graph, OpActor<DeviceTensor> *to_actor,
                                  const size_t to_index);
  // When the input of the actor is a call node, the output of the funcgraph called by the call node needs to be
  // connected.
  void LinkDataArrowByCallInput(const GraphCompilerInfo &graph_compiler_info, const AnfNodePtr &call_node,
                                OpActor<DeviceTensor> *to_actor, const size_t to_index);
  void LinkDataArrowForSwitchActor(SwitchActor *from_actor, KernelActor *to_actor, const size_t to_index);
  void LinkControlArrowForGatherActor(std::vector<GatherActorPtr> *from_actors, LoopCountActor *to_actor,
                                      const std::vector<KernelGraphPtr> &graphs);
  void LinkDataArrowByCallInput(const KernelWithIndex &call_node_with_index, const ControlNodeParserPtr &parser,
                                const FuncGraphPtr &from_func_graph, OpActor<DeviceTensor> *to_actor,
                                const size_t to_index);
  void LinkDataArrowForSwitchActor(SwitchActor *from_actor, const size_t from_index, OpActor<DeviceTensor> *to_actor,
                                   const size_t to_index, const size_t branch_index = SIZE_MAX);
  void LinkControlArrowForGatherActor(std::vector<GatherActorPtr> *from_actors,
                                      std::vector<KernelActorPtr> *kernel_actors, LoopCountActor *to_actor,
                                      const std::vector<KernelGraphPtr> &graphs, const ControlNodeParserPtr &parser);

  void LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr> *switch_actors, LoopCountActor *to_actor,
                                      const KernelMapPosition &origin_outputs_order);
  // In control flow, there are scenarios where there are multi-branch outputs, and the gather actor needs to

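The LinkDataArrowByCallInput overloads implement the idea that a consumer of a call node's output must be wired to the output of every funcgraph the call might invoke. A hedged stand-in sketch of that fan-in (Actor and FuncGraph are simplified, not the real scheduler types):

#include <cstddef>
#include <vector>

struct Actor {
  struct Arrow { Actor *to; size_t to_index; };
  std::vector<Arrow> output_arrows;
};

struct FuncGraph {
  Actor *output_actor = nullptr;  // actor producing this graph's result
};

void LinkCallInput(const std::vector<FuncGraph *> &callees, Actor *to_actor, size_t to_index) {
  // Every possible callee forwards its result to the same input slot; at
  // run time only the branch actually taken sends data.
  for (FuncGraph *fg : callees) {
    fg->output_actor->output_arrows.push_back({to_actor, to_index});
  }
}
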
@ -712,26 +712,25 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(con
  auto parser = std::make_shared<ControlNodeParser>();
  parser->Parse(control_nodes_, graphs, device_contexts, root_graph);

  // Get all the outputs. In control flow, there may be multiple branch outputs.
  runtime::KernelMapPosition outputs_order;
  size_t outputs_num = 0;
  const auto &all_branch_output = parser->FetchAllBranchOutputs(root_graph);
  for (int j = 0; j < SizeToInt(all_branch_output.size()); ++j) {
    // In general, there is only one output branch, and the branch id is 0 at this time. In control flow,
    // there are multi-branch output scenarios. Different branches may have different weight nodes. When the
    // output actor runs, the corresponding weight node needs to be obtained according to the branch, so the
    // branch of the output nodes needs to be recorded.
    const int branch_id = ((all_branch_output.size() == 1 ? runtime::kMainBranchID : (j + runtime::kSubBranchStartID)));
    const auto &branch_output = all_branch_output[j];
    size_t position = 0;
    auto outputs = AnfAlgo::GetAllOutputWithIndex(branch_output);
    outputs_num = outputs.size();
    for (const auto &output : outputs) {
      if (outputs_order.count(output) == 0) {
        outputs_order[output] = {branch_id, {position++}};
      } else {
        outputs_order[output].second.emplace_back(position++);
      }
  const auto &root_output =
    AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0, false, {prim::kPrimTupleGetItem}).first;
  size_t position = 0;
  auto outputs = AnfAlgo::GetAllOutputWithIndex(root_output);
  if (runtime::IsCallNode(root_output)) {
    std::vector<AnfNodePtr> call_nodes;
    size_t call_output_num = runtime::FetchOutputSizebyCallNode(root_output, &call_nodes);
    for (size_t i = 0; i < call_output_num; ++i) {
      outputs.push_back({root_output, i});
    }
  }
  outputs_num = outputs.size();
  for (const auto &output : outputs) {
    if (outputs_order.count(output) == 0) {
      outputs_order[output] = {position++};
    } else {
      outputs_order[output].emplace_back(position++);
    }
  }

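The new branch in the hunk above handles a root output that is itself a call node: its output count is only known from the called funcgraphs, so each output index is enumerated explicitly. A standalone sketch of that expansion with stand-in types:

#include <cstddef>
#include <utility>
#include <vector>

struct Node {
  bool is_call = false;
  size_t output_num = 1;  // stand-in for FetchOutputSizebyCallNode's result
};

using NodeWithIndex = std::pair<Node *, size_t>;

std::vector<NodeWithIndex> ExpandRootOutputs(Node *root_output) {
  std::vector<NodeWithIndex> outputs;
  if (root_output->is_call) {
    // A call node has no concrete outputs of its own, so enumerate every
    // output index the called funcgraphs can produce.
    for (size_t i = 0; i < root_output->output_num; ++i) {
      outputs.push_back({root_output, i});
    }
  } else {
    outputs.push_back({root_output, 0});
  }
  return outputs;
}
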
@ -759,9 +758,9 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(
  auto outputs = AnfAlgo::GetAllOutputWithIndex(graph->output());
  for (const auto &output : outputs) {
    if (outputs_order.count(output) == 0) {
      outputs_order[output] = {runtime::kMainBranchID, {position++}};
      outputs_order[output] = {position++};
    } else {
      outputs_order[output].second.emplace_back(position++);
      outputs_order[output].emplace_back(position++);
    }
  }
}