!23970 Add control flow actor.

Merge pull request !23970 from gaoyong10/runtime_second
i-robot 2021-09-26 03:42:23 +00:00 committed by Gitee
commit 3f81a92ed4
16 changed files with 508 additions and 1913 deletions

View File

@@ -56,7 +56,13 @@ enum class KernelTransformType {
kOutputActor,
kDeviceTensorStore,
// Internal parameter is the output of previous kernel graph which is related to the input of next kernel graph.
kInternalParameter
kInternalParameter,
// Control flow actor type.
kSwitchActor,
kGatherActor,
kEntranceActor,
kExitActor,
kStackActor
};
#define SET_OPCONTEXT_FAIL_RET_WITH_ERROR(op_context, message) \
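The new enum values let the scheduler identify control-flow actors by type. Below is a minimal standalone sketch of a dispatch over these types; the abridged enum copy and the ToString helper are illustrative only, not part of this commit:

#include <iostream>
#include <string>

// Abridged copy of the enum from the hunk above, for a self-contained example.
enum class KernelTransformType {
  kOutputActor,
  kDeviceTensorStore,
  kInternalParameter,
  // Control flow actor types added by this commit.
  kSwitchActor,
  kGatherActor,
  kEntranceActor,
  kExitActor,
  kStackActor
};

// Hypothetical helper: map an actor type to a readable name for logging.
std::string ToString(KernelTransformType type) {
  switch (type) {
    case KernelTransformType::kSwitchActor:   return "SwitchActor";
    case KernelTransformType::kGatherActor:   return "GatherActor";
    case KernelTransformType::kEntranceActor: return "EntranceActor";
    case KernelTransformType::kExitActor:     return "ExitActor";
    case KernelTransformType::kStackActor:    return "StackActor";
    default:                                  return "OtherActor";
  }
}

int main() {
  std::cout << ToString(KernelTransformType::kEntranceActor) << std::endl;  // EntranceActor
  return 0;
}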

View File

@@ -0,0 +1,71 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_ENTRANCE_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_ENTRANCE_ACTOR_H_
#include <vector>
#include <string>
#include <memory>
#include <unordered_map>
#include <stack>
#include <queue>
#include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/actor/abstract_actor.h"
namespace mindspore {
namespace runtime {
// The entrance actor is used in control flow to receive a set of result arrows and a branch id, and then send
// the data to the corresponding actors. It is the entry point of subgraph execution.
class EntranceActor : public AbstractActor {
public:
EntranceActor(const std::string &name, const std::vector<AnfNodePtr> &parameters)
: AbstractActor(name, KernelTransformType::kEntranceActor, nullptr), formal_parameters_(parameters) {
device_contexts_.resize(parameters.size());
}
~EntranceActor() override = default;
void Init() override;
// The entrance actor runs when the input control is received.
void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override;
// The entrance actor runs when the real parameter nodes and the branch id are received.
void CollectRealParametersAndBranchId(const std::vector<KernelWithIndex> &real_parameters, int branch_id,
OpContext<DeviceTensor> *const context);
private:
friend class GraphScheduler;
void SendOutput(OpContext<DeviceTensor> *const context) const;
// Formal parameters of the actor, which are the front nodes.
std::vector<KernelWithIndex> formal_parameters_;
// Input data.
std::unordered_map<uuids::uuid *, std::queue<std::vector<KernelWithIndex>>> input_nodes_;
std::unordered_map<uuids::uuid *, std::queue<int>> input_branch_ids_;
std::vector<AID> output_branch_id_arrows_;
// The output_data_ corresponds to the output_data_arrows_ one by one.
std::vector<OpData<DeviceTensor> *> output_data_;
bool is_actor_ready_{true};
};
using EntranceActorPtr = std::shared_ptr<EntranceActor>;
} // namespace runtime
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_ENTRANCE_ACTOR_H_
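To illustrate the queueing behaviour described in the class comment, here is a minimal standalone sketch with simplified stand-in types (int step keys instead of uuids::uuid *, int values instead of device tensors); it is not the actual EntranceActor implementation:

#include <iostream>
#include <queue>
#include <unordered_map>
#include <utility>
#include <vector>

using RealParameters = std::vector<int>;  // Stand-in for std::vector<KernelWithIndex>.

class EntranceActorSketch {
 public:
  // In recursion, several calls can queue up for one step before the actor runs.
  void CollectRealParametersAndBranchId(int step, RealParameters params, int branch_id) {
    input_params_[step].push(std::move(params));
    input_branch_ids_[step].push(branch_id);
  }
  // Runs when the input control arrives: consume one queued call and forward it.
  void RunOpControl(int step) {
    auto &params = input_params_[step];
    auto &ids = input_branch_ids_[step];
    if (params.empty() || ids.empty()) {
      return;  // Not ready yet.
    }
    std::cout << "step " << step << ": forward " << params.front().size()
              << " real parameters, branch id " << ids.front() << "\n";
    params.pop();
    ids.pop();
  }

 private:
  std::unordered_map<int, std::queue<RealParameters>> input_params_;
  std::unordered_map<int, std::queue<int>> input_branch_ids_;
};

int main() {
  EntranceActorSketch actor;
  actor.CollectRealParametersAndBranchId(0, {1, 2, 3}, 7);
  actor.RunOpControl(0);  // step 0: forward 3 real parameters, branch id 7
  return 0;
}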

View File

@@ -0,0 +1,71 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_EXIT_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_EXIT_ACTOR_H_
#include <vector>
#include <string>
#include <memory>
#include <unordered_map>
#include <stack>
#include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/actor/abstract_actor.h"
namespace mindspore {
namespace runtime {
// The exit actor is used in control flow to receive a set of result arrows and a branch id, and then send the
// nodes in the result to the corresponding actors. It is the exit point at the end of subgraph execution.
class ExitActor : public AbstractActor {
public:
ExitActor(const std::string &name, const std::vector<AnfNodePtr> &parameters)
: AbstractActor(name, KernelTransformType::kExitActor, nullptr), formal_parameters_(parameters) {}
~ExitActor() override = default;
// The exit actor runs when an anf node is received.
void CollectRealParameter(const AnfNodePtr &output_node, size_t output_index, size_t output_position,
OpContext<DeviceTensor> *const context);
// The exit actor runs when the input branch id is received.
void CollectBranchId(int branch_id, OpContext<DeviceTensor> *const context);
private:
friend class GraphScheduler;
void SendOutput(OpContext<DeviceTensor> *const context) const;
// Formal parameters of the actor, which are the front nodes.
std::vector<KernelWithIndex> formal_parameters_;
// Input data.
std::unordered_map<uuids::uuid *, std::unordered_map<size_t, KernelWithIndex>> input_nodes_;
// Branch ids record the id corresponding to each output branch.
// In control flow, a sub funcgraph may be called from multiple places, and the output must be returned to
// different places. Therefore, the output of each subgraph is connected to an exit actor, and the caller sends
// its branch id to the entrance actor of the subgraph. The branch id is then sent by the entrance actor to the
// exit actor connected to the output.
// In a recursive scenario, the exit actor receives the branch ids sent by the callers in sequence, so it needs
// to store them in a stack and pop them in turn when returning.
std::unordered_map<uuids::uuid *, std::stack<int>> input_branch_ids_;
// Output arrows.
std::unordered_map<int, std::vector<DataArrowPtr>> output_branch_result_arrows_;
};
using ExitActorPtr = std::shared_ptr<ExitActor>;
} // namespace runtime
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_EXIT_ACTOR_H_
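The branch-id stack described in the comment above is the key to recursion: callers push their ids, and returns pop the innermost caller first. A minimal standalone sketch, with the per-step key simplified to int (the actor uses the context's sequential number):

#include <iostream>
#include <stack>
#include <unordered_map>

class ExitActorSketch {
 public:
  // Each caller announces the branch that the subgraph output must be returned to.
  void CollectBranchId(int step, int branch_id) { input_branch_ids_[step].push(branch_id); }

  // On return, pop the most recent caller first (the innermost recursive call).
  void SendOutput(int step) {
    auto &ids = input_branch_ids_[step];
    if (ids.empty()) {
      return;
    }
    std::cout << "return output to branch " << ids.top() << "\n";
    ids.pop();
  }

 private:
  std::unordered_map<int, std::stack<int>> input_branch_ids_;
};

int main() {
  ExitActorSketch actor;
  actor.CollectBranchId(0, 1);  // Outer call.
  actor.CollectBranchId(0, 2);  // Inner (recursive) call.
  actor.SendOutput(0);          // return output to branch 2
  actor.SendOutput(0);          // return output to branch 1
  return 0;
}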

View File

@@ -0,0 +1,21 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/framework/actor/control_flow/gather_actor.h"
namespace mindspore {
namespace runtime {} // namespace runtime
} // namespace mindspore

View File

@@ -0,0 +1,68 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_GATHER_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_GATHER_ACTOR_H_
#include <vector>
#include <string>
#include <memory>
#include <unordered_map>
#include <stack>
#include <utility>
#include <algorithm>
#include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/actor/abstract_actor.h"
namespace mindspore {
namespace runtime {
// The gather actor is used in control flow. When a subgraph is called, the real parameters need to be gathered
// together and sent to the subgraph.
class GatherActor : public AbstractActor {
public:
GatherActor(const std::string &name, const std::vector<KernelWithIndex> &parameters)
: AbstractActor(name, KernelTransformType::kGatherActor, nullptr), formal_parameters_(parameters) {}
~GatherActor() override = default;
// The gather actor collects a single node when the result of a kernel actor is received.
void CollectRealParameter(const AnfNodePtr &node, size_t index, size_t position,
OpContext<DeviceTensor> *const context);
// The gather actor collects all real parameters when the output of a switch actor is received.
void CollectRealParameters(const std::vector<KernelWithIndex> &real_parameters, size_t position,
OpContext<DeviceTensor> *const context);
private:
friend class GraphScheduler;
void SendOutput(OpContext<DeviceTensor> *const context) const;
// Formal parameters of the actor, which are the front nodes.
std::vector<KernelWithIndex> formal_parameters_;
// Input data.
std::unordered_map<uuids::uuid *, std::unordered_map<size_t, std::vector<KernelWithIndex>>> input_nodes_;
// The store nodes record the value node inputs of the gather actor.
std::vector<std::pair<size_t, KernelWithIndex>> store_nodes_;
// Output arrows.
std::unordered_map<AnfNodePtr, std::pair<AID, size_t>> output_branch_arrows_;
};
using GatherActorPtr = std::shared_ptr<GatherActor>;
} // namespace runtime
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_GATHER_ACTOR_H_
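A minimal standalone sketch of the gathering behaviour: real parameters arrive one position at a time and are sent together once every slot is filled. All types are simplified stand-ins (int parameters, int step keys), not the actual GatherActor API:

#include <cstddef>
#include <iostream>
#include <unordered_map>

class GatherActorSketch {
 public:
  explicit GatherActorSketch(std::size_t param_num) : param_num_(param_num) {}

  // Collect a single real parameter; launch once all positions are present.
  void CollectRealParameter(int step, std::size_t position, int value) {
    input_params_[step][position] = value;
    if (input_params_[step].size() == param_num_) {
      SendOutput(step);
    }
  }

 private:
  void SendOutput(int step) {
    std::cout << "step " << step << ": send gathered parameters:";
    for (std::size_t pos = 0; pos < param_num_; ++pos) {
      std::cout << ' ' << input_params_[step][pos];
    }
    std::cout << '\n';
    input_params_.erase(step);  // Erase the input after the launch.
  }

  std::size_t param_num_;
  std::unordered_map<int, std::unordered_map<std::size_t, int>> input_params_;
};

int main() {
  GatherActorSketch actor(/*param_num=*/2);
  actor.CollectRealParameter(0, 1, 42);
  actor.CollectRealParameter(0, 0, 7);  // step 0: send gathered parameters: 7 42
  return 0;
}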

View File

@@ -0,0 +1,70 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_STACK_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_STACK_ACTOR_H_
#include <vector>
#include <string>
#include <memory>
#include <unordered_map>
#include <stack>
#include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/actor/abstract_actor.h"
namespace mindspore {
namespace runtime {
// The stack actor is used to record the device tensors that need additional storage in recursive scenarios.
class StackActor : public AbstractActor {
public:
StackActor(const std::string &name, const std::vector<KernelWithIndex> &parameters)
: AbstractActor(name, KernelTransformType::kStackActor, nullptr), formal_parameters_(parameters) {
device_contexts_.resize(parameters.size());
}
~StackActor() override = default;
void Init() override;
// The stack actor runs when the real parameter nodes are received.
void CollectRealParameter(const AnfNodePtr &node, size_t index, size_t position,
OpContext<DeviceTensor> *const context);
private:
friend class GraphScheduler;
void SendOutput(OpContext<DeviceTensor> *const context) const;
// Formal parameters record the input front nodes; these nodes may be parameters, kernels, or call nodes.
std::vector<KernelWithIndex> formal_parameters_;
// The backend parameters are used to save the backend nodes corresponding to the device tensors in the stack.
// When these device tensors are used as outputs, they need to be placed in the nodes of the result arrows,
// so these nodes need to be saved.
std::vector<KernelWithIndex> backend_parameters_;
// Input data.
std::unordered_map<uuids::uuid *, std::unordered_map<size_t, KernelWithIndex>> input_nodes_;
// The input data records the device tensors that the stack actor copies from the input nodes and stores in the
// stack. These device tensors do not belong to any node, and they are cleaned up directly after the stack is
// popped.
std::unordered_map<uuids::uuid *, std::unordered_map<size_t, std::stack<DeviceTensor *>>> input_data_;
};
using StackActorPtr = std::shared_ptr<StackActor>;
} // namespace runtime
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_STACK_ACTOR_H_
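A minimal standalone sketch of the per-position stacks this class maintains: recursive calls shadow earlier device tensors and restore them as the recursion unwinds. The device tensor is faked as an int; this is an illustration, not the actual StackActor:

#include <cstddef>
#include <iostream>
#include <stack>
#include <unordered_map>

class StackActorSketch {
 public:
  // Push a copied device tensor for one input position.
  void Push(std::size_t position, int device_tensor) { input_data_[position].push(device_tensor); }
  int Top(std::size_t position) { return input_data_[position].top(); }
  // Popped copies do not belong to any node, so they can be cleaned up directly.
  void Pop(std::size_t position) { input_data_[position].pop(); }

 private:
  std::unordered_map<std::size_t, std::stack<int>> input_data_;
};

int main() {
  StackActorSketch actor;
  actor.Push(0, 100);                 // The outer call's tensor.
  actor.Push(0, 200);                 // An inner recursive call shadows it.
  std::cout << actor.Top(0) << "\n";  // 200
  actor.Pop(0);                       // The recursion unwinds.
  std::cout << actor.Top(0) << "\n";  // 100
  return 0;
}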

View File

@@ -0,0 +1,102 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/framework/actor/control_flow/switch_actor.h"
#include "runtime/framework/actor/control_flow/gather_actor.h"
#include "runtime/framework/actor/output_actor.h"
#include "runtime/framework/actor/memory_manager_actor.h"
#include "mindrt/include/async/async.h"
#include "abstract/utils.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace runtime {
void SwitchActor::Init() {
// Init output data.
output_data_.resize(output_branch_data_arrows_.size());
for (size_t i = 0; i < output_branch_data_arrows_.size(); ++i) {
auto &output_branch_data_arrows = output_branch_data_arrows_[i];
auto &output_data = output_data_[i];
for (auto &data_arrow : output_branch_data_arrows) {
MS_EXCEPTION_IF_NULL(data_arrow);
auto data = std::make_unique<OpData<DeviceTensor>>(data_arrow->to_op_id_, nullptr, data_arrow->to_input_index_);
(void)output_data.emplace_back(std::move(data));
}
}
}
size_t SwitchActor::GetIndex(const OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
const auto &nodes_iter = input_nodes_.find(context->sequential_num_);
if (nodes_iter == input_nodes_.end()) {
MS_LOG(ERROR) << "Cannot find input node for switch actor:" << GetAID();
return 0;
}
const auto &index_iter = nodes_iter->second.find(0);
if (index_iter == nodes_iter->second.end() || index_iter->second.empty()) {
MS_LOG(ERROR) << "Cannot find index input node for switch actor:" << GetAID();
return 0;
}
const auto &index_node_with_index = index_iter->second[0];
const auto &index_node = index_node_with_index.first;
MS_EXCEPTION_IF_NULL(index_node);
MS_EXCEPTION_IF_NULL(index_node->kernel_info());
if (!AnfAlgo::OutputAddrExist(index_node, index_node_with_index.second, false)) {
MS_LOG(ERROR) << "Invalid output index:" << index_node_with_index.second
<< " for node:" << index_node->DebugString();
return 0;
}
DeviceTensor *device_tensor = AnfAlgo::GetMutableOutputAddr(index_node, index_node_with_index.second, false).get();
MS_EXCEPTION_IF_NULL(device_tensor);
TypeId type_id = AnfAlgo::GetOutputInferDataType(index_node, index_node_with_index.second);
size_t size = abstract::TypeIdSize(type_id);
if (size > sizeof(int64_t)) {
MS_LOG(ERROR) << "Index must be Int type.";
return 0;
}
int64_t index = 0;
char buf[kMaxSwitchCondSize] = {0};
ShapeVector host_shape;
if (!device_tensor->SyncDeviceToHost(host_shape, size, type_id, static_cast<void *>(buf))) {
MS_LOG(ERROR) << GetAID().Name() << " get index from device address failed, type id:" << std::to_string(type_id)
<< ", device type:" << std::to_string(static_cast<int>(device_contexts_[0]->GetDeviceAddressType()));
return 0;
}
if (type_id == TypeId::kNumberTypeInt32) {
index = static_cast<int64_t>((static_cast<int32_t *>(static_cast<void *>(buf)))[0]);
} else if (type_id == TypeId::kNumberTypeInt64) {
index = (static_cast<int64_t *>(static_cast<void *>(buf)))[0];
} else if (type_id == TypeId::kNumberTypeBool) {
bool cond = (static_cast<bool *>(static_cast<void *>(buf)))[0];
index = static_cast<int64_t>(cond ? 1 : 0);
} else {
MS_LOG(ERROR) << "Index must be Int type.";
return 0;
}
// The SwitchLayer node supports a negative index in the range [-size, -1].
if (index < 0) {
index += SizeToInt(input_result_num_ - 1);
}
return static_cast<size_t>(index);
}
} // namespace runtime
} // namespace mindspore
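The heart of GetIndex is decoding the condition from a raw host buffer. A standalone sketch of that decode, assuming the same type handling as above (int32/int64/bool, with bool mapping false to 0 and true to 1, and a negative SwitchLayer index wrapping around the branch count):

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>

constexpr std::size_t kMaxCondSize = 8;  // Mirrors kMaxSwitchCondSize.

enum class CondType { kInt32, kInt64, kBool };

std::size_t DecodeSwitchIndex(const char (&buf)[kMaxCondSize], CondType type, std::size_t branch_num) {
  int64_t index = 0;
  if (type == CondType::kInt32) {
    int32_t value = 0;
    std::memcpy(&value, buf, sizeof(value));
    index = static_cast<int64_t>(value);
  } else if (type == CondType::kInt64) {
    std::memcpy(&index, buf, sizeof(index));
  } else {  // CondType::kBool
    bool cond = false;
    std::memcpy(&cond, buf, sizeof(cond));
    index = cond ? 1 : 0;
  }
  // SwitchLayer supports a negative index in the range [-branch_num, -1].
  if (index < 0) {
    index += static_cast<int64_t>(branch_num);
  }
  return static_cast<std::size_t>(index);
}

int main() {
  char buf[kMaxCondSize] = {0};
  int32_t raw = -1;
  std::memcpy(buf, &raw, sizeof(raw));
  std::cout << DecodeSwitchIndex(buf, CondType::kInt32, 3) << "\n";  // 2 (last branch)
  return 0;
}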

View File

@@ -0,0 +1,80 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_SWITCH_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_SWITCH_ACTOR_H_
#include <vector>
#include <string>
#include <unordered_map>
#include <memory>
#include <utility>
#include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/actor/abstract_actor.h"
namespace mindspore {
namespace runtime {
using mindspore::device::DeviceContext;
using mindspore::session::KernelWithIndex;
constexpr size_t kSwitchCondPos = 1;
constexpr size_t kMaxSwitchCondSize = 8;
// The switch actor is used to execute a branch according to the input condition.
// Switch and SwitchLayer nodes are converted to switch actors.
class SwitchActor : public AbstractActor {
public:
SwitchActor(const std::string &name, const std::vector<KernelWithIndex> &parameters)
: AbstractActor(name, KernelTransformType::kSwitchActor, nullptr), formal_parameters_(parameters) {
input_result_num_ = formal_parameters_.size();
}
~SwitchActor() override = default;
void Init() override;
// The switch actor collects a single node when the result of a kernel actor is received.
void CollectRealParameter(const AnfNodePtr &node, size_t index, size_t position,
OpContext<DeviceTensor> *const context);
// The switch actor collects all real parameters when the output of a gather actor is received.
void CollectRealParameters(const std::vector<KernelWithIndex> &real_parameters, size_t position,
OpContext<DeviceTensor> *const context);
private:
friend class GraphScheduler;
size_t GetIndex(const OpContext<DeviceTensor> *const context);
// Formal parameters of the actor, which are the front nodes.
std::vector<KernelWithIndex> formal_parameters_;
// Input data.
std::unordered_map<uuids::uuid *, std::unordered_map<size_t, std::vector<KernelWithIndex>>> input_nodes_;
// The store nodes record the value node inputs of the switch actor.
std::vector<std::pair<size_t, AnfNodePtr>> store_nodes_;
// Output arrows.
std::vector<std::vector<DataArrowPtr>> output_branch_data_arrows_;
std::vector<std::vector<DataArrowPtr>> output_branch_result_arrows_;
std::vector<AID> output_branch_real_parameter_arrows_;
// The output_data_ corresponds to the output_branch_data_arrows_ one by one.
std::vector<std::vector<OpDataUniquePtr<DeviceTensor>>> output_data_;
size_t input_result_num_;
};
using SwitchActorPtr = std::shared_ptr<SwitchActor>;
} // namespace runtime
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_SWITCH_ACTOR_H_
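The per-branch arrow lists in output_branch_data_arrows_ mean only the selected branch's consumers receive data. A minimal standalone sketch of that fan-out, with arrows faked as strings (illustration only, not the actual SwitchActor):

#include <cstddef>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

class SwitchFanOutSketch {
 public:
  explicit SwitchFanOutSketch(std::vector<std::vector<std::string>> arrows)
      : output_branch_data_arrows_(std::move(arrows)) {}

  // Send data only along the arrows of the chosen branch.
  void SendOutput(std::size_t index, int data) {
    if (index >= output_branch_data_arrows_.size()) {
      std::cerr << "invalid branch index: " << index << "\n";
      return;
    }
    for (const auto &arrow : output_branch_data_arrows_[index]) {
      std::cout << "send " << data << " to " << arrow << "\n";
    }
  }

 private:
  std::vector<std::vector<std::string>> output_branch_data_arrows_;
};

int main() {
  SwitchFanOutSketch actor({{"kernel_a", "kernel_b"}, {"kernel_c"}});
  actor.SendOutput(1, 42);  // send 42 to kernel_c
  return 0;
}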

View File

@@ -1,250 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/framework/actor/gather_actor.h"
#include "runtime/framework/actor/output_actor.h"
#include "runtime/framework/actor/switch_actor.h"
#include "runtime/framework/actor/memory_manager_actor.h"
#include "runtime/framework/actor/loop_count_actor.h"
#include "mindrt/include/async/async.h"
#include "abstract/utils.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace runtime {
void GatherActor::Init() {
input_datas_num_ = data_nodes_.size();
input_device_tensors_.resize(input_datas_num_);
output_data_by_output_index_.resize(input_datas_num_);
for (auto &data_arrow : output_data_arrows_) {
MS_EXCEPTION_IF_NULL(data_arrow);
if (IntToSize(data_arrow->from_output_index_) >= input_datas_num_) {
MS_LOG(EXCEPTION) << "The output index is out of range: " << GetAID().Name();
}
auto data = std::make_unique<OpData<DeviceTensor>>(data_arrow->to_op_id_, nullptr, data_arrow->to_input_index_);
(void)output_data_.emplace_back(data.get());
(void)output_data_by_output_index_[IntToSize(data_arrow->from_output_index_)].emplace_back(std::move(data));
}
}
size_t GatherActor::FetchDataNodePosition(const KernelWithIndex &data_node) const {
const auto &iter = find(data_nodes_.begin(), data_nodes_.end(), data_node);
if (iter == data_nodes_.end()) {
MS_LOG(EXCEPTION) << "Data node: " << AnfAlgo::GetNodeDebugString(data_node.first) << " index:" << data_node.second
<< " is not exist in gather actor:" << GetAID();
}
return iter - data_nodes_.begin();
}
void GatherActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
auto sequential_num = context->sequential_num_;
input_data_[sequential_num][input_data->index_].push(input_data->data_);
if (CheckLaunchCondition(context)) {
FetchInputDeviceTensor(context);
EraseInput(context);
SendOutput(context);
}
}
void GatherActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
auto &sequential_num = context->sequential_num_;
(void)input_op_controls_[sequential_num].emplace_back(input_control);
if (CheckLaunchCondition(context)) {
FetchInputDeviceTensor(context);
EraseInput(context);
SendOutput(context);
}
}
void GatherActor::CollectBranchId(const int branch_id, OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
auto &sequential_num = context->sequential_num_;
input_branch_ids_[sequential_num] = branch_id;
if (CheckLaunchCondition(context)) {
FetchInputDeviceTensor(context);
EraseInput(context);
SendOutput(context);
}
}
void GatherActor::FetchBackendInputNode(const FuncGraphPtr &func_graph, const ControlNodeParserPtr &parser) {
for (const auto &input : func_graph->get_inputs()) {
// Monad inputs are not sent to the gather actor.
if (HasAbstractMonad(input) ||
(input->isa<Parameter>() && AnfAlgo::IsParameterWeight(input->cast<ParameterPtr>()))) {
continue;
}
front_to_backend_parameter_[input] = parser->GetBackendInputByParameter(input);
}
}
void GatherActor::SendOutput(OpContext<DeviceTensor> *const context) const {
MS_EXCEPTION_IF_NULL(context);
// The execution order must be: send branch id --> send result --> send data --> send control, to avoid illegal
// timing problems.
// 1.Send output branch id.
if (find(output_branch_arrows_.begin(), output_branch_arrows_.end(), switch_aid_) != output_branch_arrows_.end()) {
int branch_id = input_branch_id_;
Async(switch_aid_, &SwitchActor::CollectBranchId, branch_id, context);
}
if (find(output_branch_arrows_.begin(), output_branch_arrows_.end(), gather_aid_) != output_branch_arrows_.end()) {
Async(gather_aid_, &GatherActor::CollectBranchId, local_branch_id_, context);
}
// 2.Send output result.
for (const auto &result_arrow : output_result_arrows_) {
MS_EXCEPTION_IF_NULL(result_arrow);
size_t from_index = IntToSize(result_arrow->from_output_index_);
const auto &front_node = data_nodes_[from_index].first;
for (const auto &backend_node : front_to_backend_parameter_.at(front_node)) {
if (AnfAlgo::GetMutableOutputAddr(backend_node.first, backend_node.second, false).get() ==
input_device_tensors_[from_index]) {
Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, backend_node.first, backend_node.second,
result_arrow->to_input_index_, context);
break;
}
}
}
// 3.Send output data.
for (auto &output_data : output_data_) {
MS_EXCEPTION_IF_NULL(output_data);
Async(output_data->op_id_, &OpActor::RunOpData, output_data, context);
}
// 4.Send output control.
auto source_aid = const_cast<AID *>(&GetAID());
for (auto &output_control : output_control_arrows_) {
Async(output_control, &OpActor::RunOpControl, source_aid, context);
}
}
void GatherActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
auto data_iter = input_data_.find(context->sequential_num_);
if (data_iter != input_data_.end()) {
for (auto &input_data : data_iter->second) {
input_device_tensors_[input_data.first] = input_data.second.top();
input_data.second.pop();
}
}
for (const auto &device_tensor_store_key : device_tensor_store_keys_) {
const auto &device_context = device_contexts_[device_tensor_store_key.first];
MS_EXCEPTION_IF_NULL(device_context);
auto device_tensor =
DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second, device_context->GetDeviceAddressType());
if (device_tensor == nullptr) {
std::string error_info =
GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->fullname_with_scope() +
", device type:" + std::to_string(static_cast<int>(device_context->GetDeviceAddressType()));
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
input_device_tensors_[device_tensor_store_key.first] = device_tensor;
}
for (size_t i = 0; i < output_data_by_output_index_.size(); ++i) {
const auto &data = input_device_tensors_[i];
for (auto &output_data : output_data_by_output_index_[i]) {
MS_EXCEPTION_IF_NULL(output_data);
output_data->data_ = data;
}
}
if (need_branch_id_input_) {
input_branch_id_ = input_branch_ids_[context->sequential_num_];
}
}
bool GatherActor::CheckLaunchCondition(OpContext<DeviceTensor> *const context) const {
MS_EXCEPTION_IF_NULL(context);
// Fetch input data.
if (input_datas_num_ != 0) {
auto data_iter = input_data_.find(context->sequential_num_);
if (data_iter == input_data_.end()) {
return false;
}
if (data_iter->second.size() + device_tensor_store_keys_.size() != input_datas_num_) {
return false;
}
if (std::any_of(data_iter->second.begin(), data_iter->second.end(),
[](const auto &input_stack) { return input_stack.second.empty(); })) {
return false;
}
}
// Fetch input control.
if (input_controls_num_ != 0) {
auto control_iter = input_op_controls_.find(context->sequential_num_);
if (control_iter == input_op_controls_.end()) {
return false;
}
if (control_iter->second.size() != input_controls_num_) {
return false;
}
}
// Fetch input branch id.
if (need_branch_id_input_) {
auto branch_id_iter = input_branch_ids_.find(context->sequential_num_);
if (branch_id_iter == input_branch_ids_.end()) {
return false;
}
}
return true;
}
void GatherActor::EraseInput(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
// Erase input data.
auto data_iter = input_data_.find(context->sequential_num_);
if (data_iter != input_data_.end() && std::all_of(data_iter->second.begin(), data_iter->second.end(),
[](const auto &input_data) { return input_data.second.empty(); })) {
auto ret = input_data_.erase(context->sequential_num_);
if (ret == 0) {
std::string error_info = "Erase input data failed: " + GetAID().Name();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
}
// Erase input control.
if (input_controls_num_ != 0) {
auto ret = input_op_controls_.erase(context->sequential_num_);
if (ret == 0) {
std::string error_info = "Erase input controls failed: " + GetAID().Name();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
}
// Erase input branch id.
if (need_branch_id_input_) {
auto ret = input_branch_ids_.erase(context->sequential_num_);
if (ret == 0) {
std::string error_info = "Erase input branch id failed: " + GetAID().Name();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
}
}
} // namespace runtime
} // namespace mindspore
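The CheckLaunchCondition pattern above (run only when all expected data, all expected controls, and the branch id have arrived for the current step) recurs across these actors. A minimal standalone sketch of the same gating logic, with step keys and payloads simplified to int:

#include <cstddef>
#include <unordered_map>
#include <vector>

struct LaunchStateSketch {
  std::size_t input_datas_num = 0;
  std::size_t input_controls_num = 0;
  bool need_branch_id_input = false;
  std::unordered_map<int, std::vector<int>> input_data;      // step -> received data
  std::unordered_map<int, std::vector<int>> input_controls;  // step -> received controls
  std::unordered_map<int, int> input_branch_ids;             // step -> received branch id

  bool CheckLaunchCondition(int step) const {
    if (input_datas_num != 0) {
      auto iter = input_data.find(step);
      if (iter == input_data.end() || iter->second.size() != input_datas_num) {
        return false;
      }
    }
    if (input_controls_num != 0) {
      auto iter = input_controls.find(step);
      if (iter == input_controls.end() || iter->second.size() != input_controls_num) {
        return false;
      }
    }
    if (need_branch_id_input && input_branch_ids.count(step) == 0) {
      return false;
    }
    return true;
  }
};

int main() {
  LaunchStateSketch state;
  state.input_datas_num = 1;
  state.input_data[0].push_back(42);
  return state.CheckLaunchCondition(0) ? 0 : 1;  // 0: all inputs have arrived.
}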

View File

@@ -1,143 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_GATHER_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_GATHER_ACTOR_H_
#include <vector>
#include <string>
#include <memory>
#include <unordered_map>
#include <stack>
#include <utility>
#include <algorithm>
#include "runtime/framework/device_tensor_store.h"
#include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/control_node_parser.h"
#include "runtime/hardware/device_context.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_graph.h"
#include "ir/tensor.h"
namespace mindspore {
namespace runtime {
constexpr size_t kReturnInputPos = 1;
// The gather actor is used in three places:
// 1. The entrance of a sub funcgraph.
// 2. A call node whose input 0 is a funcgraph.
// 3. Call nodes in the inputs of a kernel graph.
// The gather actor is used in control flow. When a subgraph is called, the real parameters need to be gathered
// together and sent to the subgraph. At the same time, the entry of the subgraph needs to accept input data.
// Especially in recursion, the general inputs and call inputs of the kernel graph are used in stack mode, so they
// need to be collected at the entrance of the kernel graph.
class GatherActor : public OpActor<DeviceTensor> {
public:
GatherActor(const std::string &name, const std::vector<KernelWithIndex> &parameters, const bool need_branch_id_input,
const AID switch_aid, const AID gather_aid, const int branch_id)
: OpActor(name),
data_nodes_(parameters),
need_branch_id_input_(need_branch_id_input),
switch_aid_(switch_aid),
gather_aid_(gather_aid),
local_branch_id_(branch_id) {
device_contexts_.resize(parameters.size());
}
~GatherActor() override = default;
// Get the index of the parameter; the data_node needs to be a front node.
size_t FetchDataNodePosition(const KernelWithIndex &data_node) const;
// The gather actor runs when the input data is received.
void RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) override;
// The gather actor runs when the input control is received.
void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) override;
// The gather actor runs when the input branch id is received.
void CollectBranchId(const int branch_id, OpContext<DeviceTensor> *const context);
void Init() override;
private:
friend class GraphScheduler;
// Collect the inputs of the gather actor.
void FetchBackendInputNode(const FuncGraphPtr &func_graph, const ControlNodeParserPtr &parser);
void FetchInputDeviceTensor(OpContext<DeviceTensor> *const context);
// Check whether the launch condition is satisfied.
bool CheckLaunchCondition(OpContext<DeviceTensor> *const context) const;
void SendOutput(OpContext<DeviceTensor> *const context) const;
// Erase the input data and input controls when the gather launch finishes.
void EraseInput(OpContext<DeviceTensor> *const context);
// The device tensors for launch.
std::vector<DeviceTensor *> input_device_tensors_;
// The branch id for the current step.
int input_branch_id_{kInvalidBranchID};
// Input data.
std::unordered_map<uuids::uuid *, std::unordered_map<size_t, std::stack<DeviceTensor *>>> input_data_;
// Input branch ids record the branch id received from the caller's gather actor.
// In control flow, a sub funcgraph may be called from multiple places, and the output must be returned to
// different places. Therefore, the output of each subgraph is connected to a switch actor, and the caller sends
// its branch id to the gather actor of the subgraph. The branch id is then sent by the gather actor to the
// switch actor connected to the output.
std::unordered_map<uuids::uuid *, int> input_branch_ids_;
// Output data.
// Cache unique output data by output index to modify the output data effectively.
std::vector<std::vector<OpDataUniquePtr<DeviceTensor>>> output_data_by_output_index_;
// The output_data_ corresponds to the output_data_arrows_ one by one.
std::vector<OpData<DeviceTensor> *> output_data_;
// Output arrows.
std::vector<DataArrowPtr> output_result_arrows_;
std::vector<AID> output_branch_arrows_;
// Parameters of the sub funcgraph, which are front nodes.
std::vector<KernelWithIndex> data_nodes_;
std::vector<DeviceContext *> device_contexts_;
// Pair<index, anfNode> points to the dependent device tensor store; anfNode is the key of the device tensor store.
std::vector<std::pair<size_t, AnfNode *>> device_tensor_store_keys_;
// When the output is a parameter of the subgraph, the gather actor needs to send the anfnode to the output actor,
// so all the nodes that may send the device tensor to the gather actor are recorded. When the anfnode needs to be
// sent to the output actor, the corresponding backend node is found from this map.
std::unordered_map<AnfNodePtr, std::vector<KernelWithIndex>> front_to_backend_parameter_;
// The dependent input data number.
size_t input_datas_num_{0};
// The dependent input controls number.
size_t input_controls_num_{0};
// Whether the actor needs to accept a branch id. When the gather actor is the input of a subgraph, it needs to
// receive the branch id sent by the subgraph caller, and this flag is true in that case.
bool need_branch_id_input_;
// The actor ids that the branch id needs to be sent to.
// When the actor corresponds to a call node, the branch id needs to be sent to the input gather actor and the
// output switch actor of the called funcgraph. When the actor is the entrance of the funcgraph, the gather actor
// id is empty, and the branch id only needs to be sent to its output switch actor.
const AID switch_aid_;
const AID gather_aid_;
// The branch id corresponding to the funcgraph to which the gather actor belongs.
int local_branch_id_;
};
using GatherActorPtr = std::shared_ptr<GatherActor>;
} // namespace runtime
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_GATHER_ACTOR_H_
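FetchDataNodePosition, declared above, is a simple positional lookup. A standalone sketch under the assumption that a KernelWithIndex behaves like a comparable pair (here pair<string, size_t> as a stand-in):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

using NodeWithIndex = std::pair<std::string, std::size_t>;  // Stand-in for KernelWithIndex.

// The position of a parameter is its index in the data nodes; a missing node is an error.
std::size_t FetchDataNodePosition(const std::vector<NodeWithIndex> &data_nodes,
                                  const NodeWithIndex &data_node) {
  auto iter = std::find(data_nodes.begin(), data_nodes.end(), data_node);
  if (iter == data_nodes.end()) {
    throw std::runtime_error("data node does not exist in gather actor");
  }
  return static_cast<std::size_t>(iter - data_nodes.begin());
}

int main() {
  std::vector<NodeWithIndex> data_nodes = {{"x", 0}, {"y", 0}};
  std::cout << FetchDataNodePosition(data_nodes, {"y", 0}) << "\n";  // 1
  return 0;
}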

View File

@@ -1,501 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/framework/actor/switch_actor.h"
#include "runtime/framework/actor/output_actor.h"
#include "runtime/framework/actor/gather_actor.h"
#include "runtime/framework/actor/memory_manager_actor.h"
#include "mindrt/include/async/async.h"
#include "abstract/utils.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace runtime {
void SwitchActor::Init() {
// Init output data.
output_data_.resize(output_branch_arrows_.size());
for (size_t i = 0; i < output_branch_arrows_.size(); ++i) {
auto &output_branch_arrow = output_branch_arrows_[i];
auto &output_data = output_data_[i];
for (auto &data_arrow : output_branch_arrow) {
MS_EXCEPTION_IF_NULL(data_arrow);
auto data = std::make_unique<OpData<DeviceTensor>>(data_arrow->to_op_id_, nullptr, data_arrow->to_input_index_);
(void)output_data.emplace_back(std::move(data));
}
}
}
void SwitchActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
const auto &sequential_num = context->sequential_num_;
auto &input_datas = input_data_[sequential_num];
input_datas[input_data->index_].push(input_data->data_);
if (CheckLaunchCondition(context)) {
FetchInputDeviceTensor(context);
EraseInput(context);
SendOutput(context);
}
}
void SwitchActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
auto &sequential_num = context->sequential_num_;
if (input_controls_[sequential_num].find(input_control) == input_controls_[sequential_num].end()) {
input_controls_[sequential_num][input_control] = 0;
}
input_controls_[sequential_num][input_control]++;
if (CheckLaunchCondition(context)) {
FetchInputDeviceTensor(context);
EraseInput(context);
SendOutput(context);
}
}
void SwitchActor::CollectBranchId(const int branch_id, OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
auto &sequential_num = context->sequential_num_;
input_branch_ids_[sequential_num].push(branch_id);
}
void SwitchActor::ParseInput(const ControlNodeParserPtr &parser) {
std::vector<AnfNodePtr> inputs = node_->inputs();
if (IsPrimitive(inputs[0], prim::kPrimSwitch)) {
ParseSwitchInput();
} else if (IsPrimitive(inputs[0], prim::kPrimReturn)) {
ParseReturnInput(parser);
} else {
ParseSwitchLayerInput();
}
backend_parameters_.resize(input_nodes_.size());
}
void SwitchActor::ParsePartialInput(const AnfNodePtr &node, const size_t branch_id) {
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
CNodePtr cnode = node->cast<CNodePtr>();
// The inputs of the Partial node are:
// [0] ValueNode<Primitive> kPartial.
// [1] ValueNode<FuncGraphPtr>.
// [2..] Inputs.
auto partial_inputs = cnode->inputs();
if (partial_inputs.size() <= kPartialFuncGraphPos) {
MS_LOG(EXCEPTION) << "Invalid Partial node:" << AnfAlgo::GetNodeDebugString(cnode);
}
auto func_graph = GetValueNode<FuncGraphPtr>(partial_inputs[kPartialFuncGraphPos]);
branch_func_graph_[branch_id] = func_graph;
for (size_t j = kPartialInputStartPos; j < partial_inputs.size(); ++j) {
AddInput(partial_inputs[j], branch_id);
}
} else if (IsValueNode<FuncGraph>(node)) {
const auto func_graph = GetValueNode<FuncGraphPtr>(node);
branch_func_graph_[branch_id] = func_graph;
} else {
AddInput(node, branch_id);
}
}
void SwitchActor::InitVectorSize(const size_t num) {
branch_inputs_pos_.resize(num);
branch_func_graph_.resize(num);
output_branch_arrows_.resize(num);
output_branch_result_arrows_.resize(num);
output_branch_control_arrows_.resize(num);
output_branch_branch_arrows_.resize(num);
}
void SwitchActor::ParseReturnInput(const ControlNodeParserPtr &parser) {
const auto &func_graph = node_->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
const auto &call_num = parser->GetCallNumByFuncGraph(func_graph);
InitVectorSize(call_num);
AddCommonInput(func_graph->output());
}
void SwitchActor::ParseSwitchInput() {
// The inputs of the switch node:
// [0] ValueNode<Primitive> kSwitch.
// [1] switch condition.
// [2] Partial node: true branch.
// [3] Partial node: false branch.
std::vector<AnfNodePtr> inputs = node_->inputs();
if (inputs.size() != kSwitchInputNum) {
MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitch->name() << " is not equal 4";
}
InitVectorSize(kSwitchPartialNum);
const auto cond_node = AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchCondPos], 0);
input_nodes_.push_back(cond_node);
input_datas_num_++;
// Init the two branches of switch node.
ParsePartialInput(inputs[kSwitchFalseBranchPos], static_cast<size_t>(false));
ParsePartialInput(inputs[kSwitchTrueBranchPos], static_cast<size_t>(true));
}
void SwitchActor::ParseSwitchLayerInput() {
// The inputs of the switch node:
// [0] ValueNode<Primitive> kSwitchLayer.
// [1] switchLayer index.
// [2] MakeTuple node: tuple of branches.
std::vector<AnfNodePtr> inputs = node_->inputs();
if (inputs.size() != kSwitchLayerInputNum) {
MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitchLayer->name() << " is not equal 3";
}
const auto cond_node = AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchLayerCondPos], 0);
input_nodes_.push_back(cond_node);
input_datas_num_++;
// The second input of SwitchLayer is a MakeTuple node, which includes all branches.
auto branch_nodes = inputs[kSwitchLayerBranchPos]->cast<CNodePtr>()->inputs();
InitVectorSize(branch_nodes.size() - 1);
// Parse all branches.
for (size_t i = kMakeTupleInputStartPos; i < branch_nodes.size(); ++i) {
if (AnfAlgo::CheckPrimitiveType(branch_nodes[i], prim::kPrimPartial)) {
ParsePartialInput(branch_nodes[i], i - kMakeTupleInputStartPos);
} else if (branch_nodes[i]->isa<ValueNode>()) {
branch_func_graph_[i - 1] = GetValueNode<FuncGraphPtr>(branch_nodes[i]);
}
}
}
void SwitchActor::AddCommonInput(const AnfNodePtr &node) {
for (size_t i = 0; i < branch_inputs_pos_.size(); ++i) {
AddInput(node, i);
}
}
size_t SwitchActor::FetchDataNodePosition(const AnfNodePtr &data_node) const {
const auto data_node_with_index = AnfAlgo::VisitKernelWithReturnType(data_node, 0);
const auto &iter = find(input_nodes_.begin(), input_nodes_.end(), data_node_with_index);
if (iter == input_nodes_.end()) {
MS_LOG(EXCEPTION) << "Data node: " << AnfAlgo::GetNodeDebugString(data_node)
<< " is not exist in switch actor:" << GetAID();
}
return iter - input_nodes_.begin();
}
void SwitchActor::AddInput(const KernelWithIndex node_with_index, const size_t branch) {
const auto &node = node_with_index.first;
// The value node and weight node need to be placed in the device tensor store. This happens in two cases:
// 1) The input of the switch is a value node.
// 2) There is a weight node or value node in the return of the sub funcgraph.
if ((AnfAlgo::CheckPrimitiveType(node_, prim::kPrimReturn) && node->isa<Parameter>() && HasAbstractRef(node)) ||
node->isa<ValueNode>()) {
const auto iter = find(input_nodes_.begin(), input_nodes_.end(), node_with_index);
if (iter != input_nodes_.end()) {
branch_inputs_pos_[branch].push_back(iter - input_nodes_.begin());
return;
}
(void)device_tensor_store_keys_.emplace_back(input_nodes_.size(), node.get());
branch_inputs_pos_[branch].push_back(input_nodes_.size());
input_nodes_.push_back(node_with_index);
return;
}
// The output of an UpdateState node is U and needs to be skipped.
if (node->isa<Parameter>() && HasAbstractRef(node)) {
return;
}
// Add parameter.
auto iter = find(input_nodes_.begin(), input_nodes_.end(), node_with_index);
if (iter == input_nodes_.end()) {
branch_inputs_pos_[branch].push_back(input_nodes_.size());
input_nodes_.push_back(node_with_index);
++input_datas_num_;
} else {
branch_inputs_pos_[branch].push_back(iter - input_nodes_.begin());
}
}
void SwitchActor::AddInput(const AnfNodePtr &node, const size_t branch) {
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState) || HasAbstractMonad(node)) {
return;
}
const auto &real_input = AnfAlgo::VisitKernelWithReturnType(node, 0);
if (AnfAlgo::CheckPrimitiveType(real_input.first, prim::kPrimMakeTuple)) {
const auto &inputs = real_input.first->cast<CNodePtr>()->inputs();
for (size_t i = kMakeTupleInputStartPos; i < inputs.size(); ++i) {
AddInput(inputs[i], branch);
}
} else if (IsCallNode(real_input.first)) {
std::vector<AnfNodePtr> call_nodes;
const auto call_output_num = FetchOutputSizebyCallNode(real_input.first, &call_nodes);
if (call_output_num == 0) {
MS_LOG(EXCEPTION) << "Invalid output num for call input:" << AnfAlgo::GetNodeDebugString(real_input.first);
}
for (size_t i = 0; i < call_output_num; ++i) {
AddInput({real_input.first, i}, branch);
}
} else if (real_input.first->isa<ValueNode>() && real_input.first->cast<ValueNodePtr>()->value()->isa<ValueTuple>()) {
const auto &value = real_input.first->cast<ValueNodePtr>()->value();
const auto &tuple_value = value->cast<ValueTuplePtr>();
for (size_t i = 0; i < tuple_value->value().size(); ++i) {
AddInput({real_input.first, i}, branch);
}
} else {
AddInput(real_input, branch);
}
}
size_t SwitchActor::GetIndex(const OpContext<DeviceTensor> *const context) {
if (need_branch_id_input_) {
if (input_branch_ids_.find(context->sequential_num_) == input_branch_ids_.end() ||
input_branch_ids_[context->sequential_num_].empty()) {
MS_LOG(ERROR) << "Invalid branch id for actor:" + GetAID().Name();
}
auto branch_id = input_branch_ids_[context->sequential_num_].top();
input_branch_ids_[context->sequential_num_].pop();
if (branch_id_to_index_.find(branch_id) == branch_id_to_index_.end()) {
MS_LOG(ERROR) << "Invalid branch id for switch actor:" + GetAID().Name() +
" branch id:" + std::to_string(branch_id);
}
return branch_id_to_index_[branch_id];
}
DeviceTensor *device_tensor = input_device_tensors_[0];
MS_EXCEPTION_IF_NULL(device_tensor);
auto inputs = node_->inputs();
TypeId type_id = AnfAlgo::GetOutputInferDataType(inputs[kSwitchCondPos], 0);
size_t size = abstract::TypeIdSize(type_id);
if (size > sizeof(int64_t)) {
MS_LOG(ERROR) << "Index must be Int type.";
}
int64_t index = 0;
char buf[kMaxSwitchCondSize] = {0};
ShapeVector host_shape;
if (!device_tensor->SyncDeviceToHost(host_shape, size, type_id, static_cast<void *>(buf))) {
MS_LOG(ERROR) << GetAID().Name() << " get index from device address failed, type id:" << std::to_string(type_id)
<< ", device type:" << std::to_string(static_cast<int>(device_context_->GetDeviceAddressType()));
}
if (type_id == TypeId::kNumberTypeInt32) {
index = static_cast<int64_t>((static_cast<int32_t *>(static_cast<void *>(buf)))[0]);
} else if (type_id == TypeId::kNumberTypeInt64) {
index = (static_cast<int64_t *>(static_cast<void *>(buf)))[0];
} else if (type_id == TypeId::kNumberTypeBool) {
bool cond = (static_cast<bool *>(static_cast<void *>(buf)))[0];
index = static_cast<int64_t>(cond ? 1 : 0);
} else {
MS_LOG(ERROR) << "Index must be Int type.";
}
// The SwitchLayer node supports a negative index in the range [-size, -1].
if (index < 0) {
index += SizeToInt(branch_func_graph_.size());
}
return static_cast<size_t>(index);
}
bool SwitchActor::CheckLaunchCondition(OpContext<DeviceTensor> *const context) const {
MS_EXCEPTION_IF_NULL(context);
if (input_datas_num_ != 0) {
auto data_iter = input_data_.find(context->sequential_num_);
if (data_iter == input_data_.end()) {
return false;
}
if (data_iter->second.size() != input_datas_num_) {
return false;
}
if (std::any_of(data_iter->second.begin(), data_iter->second.end(),
[](const auto &input_stack) { return input_stack.second.empty(); })) {
return false;
}
}
if (input_controls_num_ != 0) {
auto data_iter = input_controls_.find(context->sequential_num_);
if (data_iter == input_controls_.end()) {
return false;
}
if (data_iter->second.size() != input_controls_num_) {
return false;
}
if (std::any_of(data_iter->second.begin(), data_iter->second.end(),
[](const auto &input_stack) { return input_stack.second == 0; })) {
return false;
}
}
return true;
}
void SwitchActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
input_device_tensors_.resize(input_nodes_.size());
auto data_iter = input_data_.find(context->sequential_num_);
if (data_iter != input_data_.end()) {
for (auto &input_data : data_iter->second) {
input_device_tensors_[input_data.first] = input_data.second.top();
input_data.second.pop();
}
}
for (const auto &device_tensor_store_key : device_tensor_store_keys_) {
auto device_tensor =
DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second, device_context_->GetDeviceAddressType());
if (device_tensor == nullptr) {
std::string error_info =
GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->DebugString() +
", device type:" + std::to_string(static_cast<int>(device_context_->GetDeviceAddressType()));
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
input_device_tensors_[device_tensor_store_key.first] = device_tensor;
}
auto control_iter = input_controls_.find(context->sequential_num_);
if (control_iter != input_controls_.end()) {
(void)for_each(control_iter->second.begin(), control_iter->second.end(),
[](auto &input_control) { input_control.second--; });
}
}
void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
auto index = GetIndex(context);
if (index >= output_branch_arrows_.size()) {
std::string error_info = "Switch actor:" + GetAID().Name() + " invalid index:" + std::to_string(index);
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
// The execution order must be: send branch id --> send result --> send data --> send control, to avoid illegal
// timing problems.
// 1.Send branch id.
if (local_branch_id_ >= 0) {
const auto &branch_arrows = output_branch_branch_arrows_[index];
for (const auto &branch_arrow : branch_arrows) {
Async(branch_arrow, &GatherActor::CollectBranchId, local_branch_id_, context);
}
}
// 2.Send result.
auto &output_branch_result_arrow = output_branch_result_arrows_[index];
for (size_t i = 0; i < output_branch_result_arrow.size(); ++i) {
auto &result_arrow = output_branch_result_arrow[i];
MS_EXCEPTION_IF_NULL(result_arrow);
if (result_arrow->from_output_index_ >= SizeToInt(branch_inputs_pos_[index].size())) {
std::string error_info =
"Invalid from index in switch actor, from index:" + std::to_string(result_arrow->from_output_index_) +
" total:" + std::to_string(branch_inputs_pos_[index].size()) + " actor:" + GetAID().Name();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
size_t from_index = branch_inputs_pos_[index][IntToSize(result_arrow->from_output_index_)];
MS_LOG(DEBUG) << "Switch actor:" << GetAID() << " send result addr:" << input_device_tensors_[from_index];
bool is_send = false;
for (const auto &backend_node : backend_parameters_[from_index]) {
for (size_t j = 0; j < AnfAlgo::GetOutputTensorNum(backend_node.first); ++j) {
if (backend_node.first->kernel_info() != nullptr && AnfAlgo::OutputAddrExist(backend_node.first, j, false) &&
AnfAlgo::GetMutableOutputAddr(backend_node.first, j, false).get() == input_device_tensors_[from_index]) {
auto output_index = j;
Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, backend_node.first, output_index,
result_arrow->to_input_index_, context);
is_send = true;
MS_LOG(DEBUG) << "Switch actor:" << GetAID() << " send result addr:" << input_device_tensors_[from_index]
<< " succeed";
break;
}
}
}
if (!is_send) {
std::string error_info = "Failed to get backend node of switch actor output, actor:" + GetAID().Name() +
" branch:" + std::to_string(index) +
" index:" + std::to_string(result_arrow->from_output_index_) + " output pos" +
std::to_string(branch_inputs_pos_[index][IntToSize(result_arrow->from_output_index_)]) +
" output index" + std::to_string(result_arrow->to_input_index_);
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
}
// 3.Send Data.
auto &output_branch_arrow = output_branch_arrows_[index];
auto &output_data = output_data_[index];
for (size_t i = 0; i < output_branch_arrow.size(); ++i) {
auto &data_arrow = output_branch_arrow[i];
auto &data = output_data[i];
MS_EXCEPTION_IF_NULL(data_arrow);
MS_EXCEPTION_IF_NULL(data);
data->data_ = input_device_tensors_[IntToSize(data_arrow->from_output_index_)];
Async(data_arrow->to_op_id_, &OpActor::RunOpData, data.get(), context);
}
// 4.Send output control.
auto source_aid = const_cast<AID *>(&GetAID());
for (auto &output_control : output_branch_control_arrows_[index]) {
Async(output_control, &OpActor::RunOpControl, source_aid, context);
}
}
void SwitchActor::EraseInput(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
auto data_iter = input_data_.find(context->sequential_num_);
if (data_iter != input_data_.end() && std::all_of(data_iter->second.begin(), data_iter->second.end(),
[](const auto &input_data) { return input_data.second.empty(); })) {
auto ret = input_data_.erase(context->sequential_num_);
if (ret == 0) {
std::string error_info = "Erase input data failed: " + GetAID().Name();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
}
if (input_controls_num_ != 0) {
auto control_iter = input_controls_.find(context->sequential_num_);
if (control_iter != input_controls_.end() &&
std::all_of(control_iter->second.begin(), control_iter->second.end(),
[](const auto &input_control) { return input_control.second == 0; })) {
auto ret = input_controls_.erase(context->sequential_num_);
if (ret == 0) {
std::string error_info = "Erase input control failed: " + GetAID().Name();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
}
}
}
void SwitchActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &input_device_tensors_, device_context_, context);
}
void SwitchActor::FetchInputNode(const ControlNodeParserPtr &parser) {
for (size_t i = 0; i < input_nodes_.size(); ++i) {
const auto &input_node = input_nodes_[i].first;
if (!(input_node->isa<Parameter>() && HasAbstractRef(input_node))) {
backend_parameters_[i] = parser->FetchBackendInputNodeByFrontNode(input_node);
continue;
}
const auto &backend_weight = parser->FetchBackendNodebyWeightNode(input_node);
if (backend_weight == nullptr) {
MS_LOG(EXCEPTION) << "Cannot find backend node for weight node:" << AnfAlgo::GetNodeDebugString(input_node);
}
(void)backend_parameters_[i].emplace(backend_weight, 0);
}
}
} // namespace runtime
} // namespace mindspore
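The branch-id path of the removed GetIndex (when need_branch_id_input_ is set) picks the output branch not from a condition value but from the caller's branch id, popped from a per-step stack and mapped through branch_id_to_index_. A minimal standalone sketch with int step keys:

#include <cstddef>
#include <iostream>
#include <stack>
#include <unordered_map>

std::size_t GetIndexByBranchId(std::unordered_map<int, std::stack<int>> *input_branch_ids,
                               const std::unordered_map<int, std::size_t> &branch_id_to_index,
                               int step) {
  auto &ids = (*input_branch_ids)[step];
  if (ids.empty()) {
    return 0;  // Invalid: no caller recorded for this step.
  }
  int branch_id = ids.top();
  ids.pop();  // In recursion, the innermost caller returns first.
  auto iter = branch_id_to_index.find(branch_id);
  return iter == branch_id_to_index.end() ? 0 : iter->second;
}

int main() {
  std::unordered_map<int, std::stack<int>> input_branch_ids;
  input_branch_ids[0].push(5);
  const std::unordered_map<int, std::size_t> branch_id_to_index = {{5, 1}};
  std::cout << GetIndexByBranchId(&input_branch_ids, branch_id_to_index, 0) << "\n";  // 1
  return 0;
}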

View File

@@ -1,180 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_SWITCH_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_SWITCH_ACTOR_H_
#include <vector>
#include <string>
#include <set>
#include <memory>
#include <utility>
#include <stack>
#include <unordered_map>
#include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/device_tensor_store.h"
#include "runtime/framework/control_node_parser.h"
#include "mindrt/include/actor/switch_actor.h"
#include "runtime/hardware/device_context.h"
namespace mindspore {
namespace runtime {
using mindspore::device::DeviceContext;
using mindspore::session::KernelWithIndex;
constexpr size_t kSwitchInputNum = 4;
constexpr size_t kSwitchCondPos = 1;
constexpr size_t kSwitchPartialNum = 2;
constexpr size_t kSwitchLayerCondPos = 1;
constexpr size_t kSwitchLayerBranchPos = 2;
constexpr size_t kSwitchLayerInputNum = 3;
constexpr size_t kMaxSwitchCondSize = 8;
constexpr size_t kSwitchTrueBranchPos = 2;
constexpr size_t kSwitchFalseBranchPos = 3;
constexpr size_t kPartialFuncGraphPos = 1;
constexpr size_t kPartialInputStartPos = 2;
constexpr size_t kCallInputStartPos = 1;
constexpr size_t kMakeTupleInputStartPos = 1;
// The switch actor is used to execute a branch according to the input condition.
// Switch and SwitchLayer nodes are converted to switch actors.
// The execution process is divided into:
// 1. Put the inputs into the vector.
// 2. Check whether the input condition has been received.
// 3. Check whether all inputs of the branch corresponding to the index have been received.
// 4. Send the data to the corresponding branch.
// 5. Free the memory.
class SwitchActor : public SwitchActorBase<DeviceTensor> {
public:
SwitchActor(const std::string &name, DeviceContext *device_context, const CNodePtr &node, const int branch_id,
const bool need_branch_id_input)
: SwitchActorBase(name),
device_context_(device_context),
node_(node),
local_branch_id_(branch_id),
need_branch_id_input_(need_branch_id_input) {}
~SwitchActor() override = default;
void Init() override;
// The switch actor runs when the input data is received.
void RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *const context);
// The switch actor runs when the input control is received.
void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context);
// The switch actor runs when the input branch id is received.
void CollectBranchId(const int branch_id, OpContext<DeviceTensor> *const context);
// Parse the input node information of the switch actor according to node_.
void ParseInput(const ControlNodeParserPtr &parser);
// Add input for all branches.
void AddCommonInput(const AnfNodePtr &node);
void AddSingleInput(const AnfNodePtr &node, size_t branch) { AddInput(node, branch); }
// Fetch the input position of the data node.
size_t FetchDataNodePosition(const AnfNodePtr &data_node) const;
private:
friend class GraphScheduler;
void ParsePartialInput(const AnfNodePtr &node, const size_t branch_id);
void ParseSwitchInput();
void ParseSwitchLayerInput();
// In control flow, the output of each subgraph is connected to a switch actor, and the switch actor is
// initialized with the return node of the subgraph.
void ParseReturnInput(const ControlNodeParserPtr &parser);
// Initialize the size of the vector members.
void InitVectorSize(const size_t num);
// Get index from DeviceTensor.
size_t GetIndex(const OpContext<DeviceTensor> *const context);
// Add input for the branch.
void AddInput(const AnfNodePtr &node, size_t branch);
void AddInput(const KernelWithIndex node_with_index, const size_t branch);
// Check whether the condition for sending outputs is satisfied.
bool CheckLaunchCondition(OpContext<DeviceTensor> *const context) const;
// Fetch the arguments of the switch branch.
void FetchInputDeviceTensor(OpContext<DeviceTensor> *const context);
void SendOutput(OpContext<DeviceTensor> *const context);
// Erase the input data and input controls when the switch launch finishes.
void EraseInput(OpContext<DeviceTensor> *const context);
void SendMemoryFreeReq(OpContext<DeviceTensor> *const context);
// Collect all the backend inputs of switch actor.
void FetchInputNode(const ControlNodeParserPtr &parser);
// All inputs of the switch actor, including weights and tensors.
// Used to receive input data; the first input is the condition of the switch.
std::vector<KernelWithIndex> input_nodes_;
// The positions of each branch's inputs in input_nodes_.
std::vector<std::vector<size_t>> branch_inputs_pos_;
std::unordered_map<uuids::uuid *, std::unordered_map<size_t, std::stack<DeviceTensor *>>> input_data_;
std::unordered_map<uuids::uuid *, std::unordered_map<AID *, size_t>> input_controls_;
// Branch ids are used to record the id corresponding to each switch output branch.
// In control flow, a sub funcgraph may be called in multiple places, and its output must be returned to different
// places. Therefore, the output of each subgraph is connected to a switch actor, and the caller sends its branch
// id to the gather actor of the subgraph. The branch id is then forwarded by the gather actor to the switch
// actor connected to the output.
// In a recursive scenario, the switch actor receives the branch ids sent by the callers in sequence; it stores
// them in a stack and pops them in turn when returning.
std::unordered_map<uuids::uuid *, std::stack<int>> input_branch_ids_;
// Control arrows of different branches.
std::vector<std::vector<AID>> output_branch_control_arrows_;
// Branch id arrows of different branches.
std::vector<std::vector<AID>> output_branch_branch_arrows_;
// Result arrows of different branches.
std::vector<std::vector<DataArrowPtr>> output_branch_result_arrows_;
// When the output of the switch actor is a value node, the actor needs to send the anfnode to the output actor,
// so all the nodes that may send a device tensor to the switch actor are recorded.
std::vector<std::set<KernelWithIndex>> backend_parameters_;
std::vector<std::vector<AnfNodePtr>> branch_total_inputs_;
std::vector<FuncGraphPtr> branch_func_graph_;
std::unordered_map<int, size_t> branch_id_to_index_;
// Pair<index, anfNode> points to the dependent device tensor store; anfNode is the key of the device tensor store.
std::vector<std::pair<size_t, AnfNode *>> device_tensor_store_keys_;
std::vector<DeviceTensor *> input_device_tensors_;
// Save the DeviceContext of input_nodes_, which is used to release the DeviceTensor.
const DeviceContext *device_context_;
// The id of memory manager actor. Send message to it for alloc and free memory.
const AID memory_manager_aid_;
// The dependent input data number.
size_t input_datas_num_{0};
// The dependent input controls number.
size_t input_controls_num_{0};
CNodePtr node_;
// The branch id corresponding to the funcgraph to which the switch actor belongs.
int local_branch_id_;
// Whether the actor needs to receive a branch id. When the switch actor is the output of a subgraph, it needs to
// receive the branch id sent by the gather actor of the subgraph, and this flag is true in that case.
bool need_branch_id_input_;
// The output_data_ corresponds to the output_data_arrows_ one by one.
std::vector<std::vector<OpDataUniquePtr<DeviceTensor>>> output_data_;
};
using SwitchActorPtr = std::shared_ptr<SwitchActor>;
} // namespace runtime
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_SWITCH_ACTOR_H_
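To make the five-step flow documented at the top of this removed header concrete, here is a minimal standalone sketch. Every name in it (ToySwitch, ReceiveInput, TryLaunch, the int-valued inputs) is a simplified stand-in invented for illustration, not a MindSpore type; the real actor works on DeviceTensor inputs and asynchronous messages.

#include <cstddef>
#include <functional>
#include <iostream>
#include <optional>
#include <vector>

struct ToySwitch {
  // One slot per input position; slot 0 is the switch condition.
  std::vector<std::optional<int>> inputs;
  // For each branch, the positions in `inputs` that the branch consumes.
  std::vector<std::vector<size_t>> branch_input_pos;
  // Stands in for "send the data to the corresponding branch" (step 4).
  std::function<void(size_t branch, const std::vector<int> &args)> send;

  // Step 1: put the input into the vector.
  void ReceiveInput(size_t pos, int value) {
    inputs[pos] = value;
    TryLaunch();
  }

  void TryLaunch() {
    // Step 2: check whether the condition has been received.
    if (!inputs[0].has_value()) {
      return;
    }
    size_t branch = static_cast<size_t>(*inputs[0]);
    // Step 3: check whether all inputs of the selected branch have arrived.
    for (size_t pos : branch_input_pos[branch]) {
      if (!inputs[pos].has_value()) {
        return;
      }
    }
    // Step 4: send the data to the corresponding branch.
    std::vector<int> args;
    for (size_t pos : branch_input_pos[branch]) {
      args.push_back(*inputs[pos]);
    }
    send(branch, args);
    // Step 5: free (erase) the consumed inputs for the next round.
    for (auto &slot : inputs) {
      slot.reset();
    }
  }
};

int main() {
  ToySwitch sw;
  sw.inputs.resize(3);
  sw.branch_input_pos = {{1}, {2}};  // branch 0 reads slot 1, branch 1 reads slot 2
  sw.send = [](size_t branch, const std::vector<int> &args) {
    std::cout << "branch " << branch << " got " << args[0] << "\n";
  };
  sw.ReceiveInput(1, 42);  // data for branch 0; condition not here yet
  sw.ReceiveInput(0, 0);   // condition selects branch 0 -> fires now
  return 0;
}

The essential design point survives the simplification: the condition is just another input slot, and nothing fires until both the condition and every input of the selected branch have arrived.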

View File

@ -15,8 +15,8 @@
*/
#include "runtime/framework/control_node_parser.h"
#include "runtime/framework/actor/switch_actor.h"
#include "runtime/framework/actor/gather_actor.h"
#include "runtime/framework/actor/control_flow/switch_actor.h"
#include "runtime/framework/actor/control_flow/gather_actor.h"
#include "abstract/utils.h"
#include "ir/tensor.h"

View File

@ -37,6 +37,17 @@ using mindspore::session::KernelWithIndex;
constexpr int kInvalidBranchID = -1;
constexpr int kMainBranchID = 0;
constexpr int kSubBranchStartID = 1;
constexpr size_t kSwitchInputNum = 4;
constexpr size_t kSwitchPartialNum = 2;
constexpr size_t kSwitchLayerCondPos = 1;
constexpr size_t kSwitchLayerBranchPos = 2;
constexpr size_t kSwitchLayerInputNum = 3;
constexpr size_t kSwitchTrueBranchPos = 2;
constexpr size_t kSwitchFalseBranchPos = 3;
constexpr size_t kPartialFuncGraphPos = 1;
constexpr size_t kPartialInputStartPos = 2;
constexpr size_t kCallInputStartPos = 1;
constexpr size_t kMakeTupleInputStartPos = 1;
using FrontToBackendNodeWithContext = std::unordered_map<AnfNodePtr, std::pair<AnfNodePtr, DeviceContext *>>;
using FrontToBackendKernelWithContext = std::map<KernelWithIndex, std::pair<KernelWithIndex, DeviceContext *>>;
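As a reading aid for the position constants just added: they index into the flattened input list of a cnode, where inputs[0] is always the primitive itself. A toy illustration of that layout (Node here is a made-up stand-in, not the real AnfNode):

#include <cassert>
#include <string>
#include <vector>

// Node is a stand-in for illustration, not the real AnfNode.
struct Node { std::string name; };

int main() {
  // A Switch cnode's inputs: {Switch primitive, condition, true branch, false branch}.
  std::vector<Node> switch_inputs = {{"Switch"}, {"cond"}, {"true_fg"}, {"false_fg"}};
  assert(switch_inputs.size() == 4);            // kSwitchInputNum
  assert(switch_inputs[2].name == "true_fg");   // kSwitchTrueBranchPos == 2
  assert(switch_inputs[3].name == "false_fg");  // kSwitchFalseBranchPos == 3

  // A Partial cnode's inputs: {Partial primitive, funcgraph, arg0, arg1, ...}.
  std::vector<Node> partial_inputs = {{"Partial"}, {"fg"}, {"arg0"}};
  assert(partial_inputs[1].name == "fg");       // kPartialFuncGraphPos == 1
  assert(partial_inputs[2].name == "arg0");     // kPartialInputStartPos == 2
  return 0;
}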

View File

@ -368,8 +368,6 @@ ActorSetPtr GraphScheduler::Build(const GraphCompilerInfo &graph_compiler_info)
actor_set->output_actor_ = BuildOutputActor(graph_compiler_info);
actor_set->data_prepare_actor_ =
BuildDataPrepareActor(graph_compiler_info, actor_set->data_source_actors_, host_queue);
actor_set->switch_actors_ = BuildSwitchActor(graph_compiler_info);
actor_set->gather_actors_ = BuildGatherActor(graph_compiler_info);
return actor_set;
}
@ -732,155 +730,6 @@ std::vector<KernelActorPtr> GraphScheduler::BuildNoInputKernelActor(const ActorS
return no_input_kernel_actors;
}
std::vector<SwitchActorPtr> GraphScheduler::BuildSwitchActor(const GraphCompilerInfo &graph_compiler_info) {
std::vector<SwitchActorPtr> switch_actors;
std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_backend_kernel;
for (const auto &pair : front_node_to_actor_) {
front_to_backend_kernel[pair.first] = pair.second->kernel_;
}
// Build switch actors from Switch and SwitchLayer nodes.
for (const auto &control_node : graph_compiler_info.control_nodes_) {
if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch) ||
AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
const auto func_graph = control_node->func_graph();
const auto branch_id = graph_compiler_info.control_node_parser_->GetBranchIDByFuncGraph(func_graph);
const auto &actor_name = control_node->DebugString();
auto switch_actor = std::make_shared<SwitchActor>(actor_name, graph_compiler_info.device_contexts_[0],
control_node->cast<CNodePtr>(), branch_id, false);
switch_actor->ParseInput(graph_compiler_info.control_node_parser_);
// Fetch all the input nodes of switch actor.
switch_actor->FetchInputNode(graph_compiler_info.control_node_parser_);
InsertActor(switch_actor.get());
(void)switch_actors.emplace_back(switch_actor);
}
}
// Build switch actors from return nodes.
const auto func_graphs_to_call_num = graph_compiler_info.control_node_parser_->func_graph_to_call_num_;
for (const auto &func_graph_to_call_num : func_graphs_to_call_num) {
const auto &return_node = func_graph_to_call_num.first->get_return();
MS_EXCEPTION_IF_NULL(return_node);
const auto &actor_name = return_node->DebugString();
auto switch_actor = std::make_shared<SwitchActor>(actor_name, graph_compiler_info.device_contexts_[0],
return_node->cast<CNodePtr>(), kInvalidBranchID, true);
switch_actor->ParseInput(graph_compiler_info.control_node_parser_);
// Fetch all the input nodes of switch actor.
switch_actor->FetchInputNode(graph_compiler_info.control_node_parser_);
InsertActor(switch_actor.get());
(void)switch_actors.emplace_back(switch_actor);
}
return switch_actors;
}
std::vector<GatherActorPtr> GraphScheduler::BuildGatherActor(const GraphCompilerInfo &graph_compiler_info) {
std::vector<GatherActorPtr> gather_actors;
const auto &loop_count_actor_name = graph_compiler_info.name_ + "_LoopCountActor";
const auto &loop_count_actor = FetchActor(loop_count_actor_name);
if (loop_count_actor == nullptr) {
return gather_actors;
}
const auto &output_actor_name = graph_compiler_info.name_ + "_" + "OutputActor";
const auto &output_actor = FetchActor(output_actor_name);
MS_EXCEPTION_IF_NULL(output_actor);
const auto parser = graph_compiler_info.control_node_parser_;
bool is_main_return = true;
// Each funcgraph has a return node; get the funcgraph from the return node and create a gather actor.
std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_backend_kernel;
for (const auto &pair : front_node_to_actor_) {
front_to_backend_kernel[pair.first] = pair.second->kernel_;
}
for (const auto &control_node : graph_compiler_info.control_nodes_) {
const auto &func_graph = control_node->func_graph();
const auto &cnode = control_node->cast<CNodePtr>();
const auto &inputs = cnode->inputs();
const auto &return_node = func_graph->get_return();
if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
// Root funcgraph does not need to create a gather actor.
if (is_main_return) {
is_main_return = false;
continue;
}
if (AnfAlgo::CheckPrimitiveType(inputs[kReturnInputPos], prim::kPrimPartial)) {
continue;
}
auto actor_name = func_graph->ToString();
std::vector<KernelWithIndex> parameters;
for (const auto &parameter : func_graph->get_inputs()) {
if (HasAbstractMonad(parameter) || HasAbstractRef(parameter)) {
continue;
}
(void)parameters.emplace_back(parameter, 0);
}
const auto branch_id = parser->GetBranchIDByFuncGraph(func_graph);
const auto &output_switch_actor = FetchActor(return_node->DebugString());
MS_EXCEPTION_IF_NULL(output_switch_actor);
const auto &output_switch_aid = output_switch_actor->GetAID();
auto gather_actor =
std::make_shared<GatherActor>(actor_name, parameters, true, output_switch_aid, AID(), branch_id);
gather_actor->FetchBackendInputNode(func_graph, graph_compiler_info.control_node_parser_);
InsertActor(gather_actor.get());
(void)gather_actors.emplace_back(gather_actor);
}
}
// Create a gather actor for each call node whose first input (input0) is a funcgraph.
for (const auto &control_node : graph_compiler_info.control_nodes_) {
const auto &cnode = control_node->cast<CNodePtr>();
const auto &inputs = cnode->inputs();
if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) {
// Collect the parameters.
std::vector<KernelWithIndex> parameters;
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
if (HasAbstractMonad(inputs[i]) || (inputs[i]->isa<Parameter>() && HasAbstractRef(inputs[i]))) {
continue;
}
(void)parameters.emplace_back(inputs[i], 0);
}
auto func_graph = control_node->func_graph();
auto actor_name = control_node->DebugString();
const auto branch_id = parser->GetBranchIDByFuncGraph(func_graph);
const auto &to_func_graph = GetValueNode<FuncGraphPtr>(inputs[0]);
const auto &to_actor = FetchActor(to_func_graph->ToString());
auto gather_actor =
std::make_shared<GatherActor>(actor_name, parameters, false, AID(), to_actor->GetAID(), branch_id);
gather_actor->FetchBackendInputNode(func_graph, graph_compiler_info.control_node_parser_);
InsertActor(gather_actor.get());
(void)gather_actors.emplace_back(gather_actor);
}
}
// Create a gather actor for each kernel graph that has a call input.
const auto &graph_with_device_contexts = graph_compiler_info.control_node_parser_->call_input_kernel_graphs_;
for (const auto &graph_with_device_context : graph_with_device_contexts) {
const auto &graph = graph_with_device_context.first;
const auto &parameters = FetchParameterbyKernelGraph(graph);
auto actor_name = graph->ToString();
auto gather_actor = std::make_shared<GatherActor>(actor_name, parameters, false, AID(), AID(), kInvalidBranchID);
InsertActor(gather_actor.get());
(void)gather_actors.emplace_back(gather_actor);
}
return gather_actors;
}
void GraphScheduler::LinkDataArrow(KernelActor *const to_actor, const GraphCompilerInfo &graph_compiler_info,
const KernelGraphPtr &graph, const KernelWithIndex &from_kernel_with_output_idx,
const KernelWithIndex &to_kernel_with_input_idx) {
@ -1547,81 +1396,6 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
}
}
void GraphScheduler::LinkOutputResultArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info,
const ActorSet *actor_set) {
const auto &to_actor = actor_set->output_actor_;
const auto &loop_count_actor = actor_set->loop_count_actor_;
if (to_actor == nullptr || loop_count_actor == nullptr) {
return;
}
// When there is a call node in the output, the output will be sent to the output actor by the switch actor of
// the funcgraph called by the call node.
const auto &outputs = graph_compiler_info.origin_outputs_order_;
for (const auto &output : outputs) {
const auto &output_node = output.first.first;
const auto &output_index = output.first.second;
const auto output_poses = output.second;
if (IsCallNode(output_node)) {
const auto &func_graphs = FetchFuncGraphbyCallNode(output_node);
for (const auto func_graph : func_graphs) {
const auto &actor_name = func_graph->get_return()->DebugString();
auto actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
MS_EXCEPTION_IF_NULL(switch_actor);
// Set branch index into switch actor.
size_t branch_index = switch_actor->branch_id_to_index_.size();
if (switch_actor->branch_id_to_index_.find(kMainBranchID) != switch_actor->branch_id_to_index_.end()) {
branch_index = switch_actor->branch_id_to_index_[kMainBranchID];
} else {
switch_actor->branch_id_to_index_[kMainBranchID] = branch_index;
}
// Link output result arrow.
for (const auto output_pos : output_poses) {
auto op_arrow = std::make_shared<DataArrow>(output_index, to_actor->GetAID(), output_pos);
to_actor->device_contexts_[output_pos] = switch_actor->device_context_;
(void)switch_actor->output_branch_result_arrows_[branch_index].emplace_back(op_arrow);
}
}
}
}
const auto &switch_actors = actor_set->switch_actors_;
for (const auto &from_actor : switch_actors) {
MS_EXCEPTION_IF_NULL(from_actor);
auto origin_output_with_index = KernelWithIndex(from_actor->node_, 0);
const auto &iter = graph_compiler_info.origin_outputs_order_.find(origin_output_with_index);
if (iter == graph_compiler_info.origin_outputs_order_.end()) {
continue;
}
// If the switch actor is in the output list, the output of the switch actor should be sent to the output actor,
// and a control arrow needs to be linked to the loop count actor.
for (const auto pos : iter->second) {
to_actor->device_contexts_[pos] = from_actor->device_context_;
}
for (size_t i = 0; i < from_actor->branch_inputs_pos_.size(); ++i) {
const auto &input_pos = from_actor->branch_inputs_pos_[i];
if (input_pos.empty()) {
MS_LOG(EXCEPTION) << "Invalid input num in switch actor:" << from_actor->GetAID();
}
for (const auto pos : iter->second) {
auto op_arrow = std::make_shared<DataArrow>(0, to_actor->GetAID(), pos);
(void)from_actor->output_branch_result_arrows_[i].emplace_back(op_arrow);
}
(void)from_actor->output_branch_control_arrows_[i].emplace_back(loop_count_actor->GetAID());
}
loop_count_actor->input_controls_num_++;
}
}
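The branch bookkeeping above relies on a small find-or-insert idiom: each caller's branch id is mapped to a dense branch index in first-seen order, so arrows for the same caller always land in the same slot. A distilled standalone sketch of just that idiom (toy names, not the real GraphScheduler API):

#include <cstddef>
#include <iostream>
#include <unordered_map>

// Assigns each caller branch id a dense branch index in first-seen order.
size_t BranchIndexFor(int branch_id, std::unordered_map<int, size_t> *branch_id_to_index) {
  auto iter = branch_id_to_index->find(branch_id);
  if (iter != branch_id_to_index->end()) {
    return iter->second;  // already linked: reuse the existing slot
  }
  size_t index = branch_id_to_index->size();  // next free slot
  (*branch_id_to_index)[branch_id] = index;
  return index;
}

int main() {
  std::unordered_map<int, size_t> branch_id_to_index;
  std::cout << BranchIndexFor(0, &branch_id_to_index) << "\n";  // 0 (e.g. the main branch)
  std::cout << BranchIndexFor(7, &branch_id_to_index) << "\n";  // 1 (a new caller)
  std::cout << BranchIndexFor(0, &branch_id_to_index) << "\n";  // 0 again (reused)
  return 0;
}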
void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<KernelActor *> &auto_monad_actors) {
const size_t kNeedUpdateDeviceTensorStoreNum = 2;
for (auto &kernel_actor : auto_monad_actors) {
@ -1673,515 +1447,15 @@ void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<Ke
}
}
void GraphScheduler::PrepareInputNodeForSwitchActor(const std::vector<AnfNodePtr> &control_nodes) {
for (const auto &node : control_nodes) {
CNodePtr cnode = node->cast<CNodePtr>();
auto inputs = cnode->inputs();
// Before linking data arrows, the parameters of the call node in a switch-call need to be added to the switch actor.
if (inputs[0]->isa<CNode>()) {
auto actor = FetchActor(inputs[0]->DebugString());
MS_EXCEPTION_IF_NULL(actor);
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
MS_EXCEPTION_IF_NULL(switch_actor);
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
if (HasAbstractMonad(inputs[i])) {
continue;
}
switch_actor->AddCommonInput(inputs[i]);
}
}
}
}
void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, ActorSet *const actor_set) {
PrepareInputNodeForSwitchActor(graph_compiler_info.control_nodes_);
for (const auto &node : graph_compiler_info.control_nodes_) {
CNodePtr cnode = node->cast<CNodePtr>();
const auto &from_func_graph = node->func_graph();
auto inputs = cnode->inputs();
// Link data arrow for switch node.
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitchLayer)) {
auto actor = actor_name_to_actor_[node->DebugString()];
MS_EXCEPTION_IF_NULL(actor);
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
MS_EXCEPTION_IF_NULL(switch_actor);
LinkDataArrowForSwitchActor(graph_compiler_info, switch_actor);
} else if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) {
// Link the data arrow for the input of the call node.
const auto &actor_name = node->DebugString();
auto actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto gather_actor = dynamic_cast<GatherActor *>(actor);
MS_EXCEPTION_IF_NULL(gather_actor);
const auto &func_graph = GetValueNode<FuncGraphPtr>(inputs[0]);
MS_EXCEPTION_IF_NULL(func_graph);
const auto &to_actor = FetchActor(func_graph->ToString());
MS_EXCEPTION_IF_NULL(to_actor);
size_t persist_input_num = 0;
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
MS_EXCEPTION_IF_NULL(actor);
if (inputs[i]->isa<ValueNode>()) {
const auto &node_value = inputs[i]->cast<ValueNodePtr>()->value();
if (!node_value->isa<tensor::Tensor>()) {
persist_input_num++;
continue;
}
(void)gather_actor->device_tensor_store_keys_.emplace_back(i - kCallInputStartPos - persist_input_num,
inputs[i].get());
gather_actor->device_contexts_[i - kCallInputStartPos - persist_input_num] =
graph_compiler_info.control_node_parser_->GetFrontValueNodeDeviceContext(inputs[i]);
} else if ((inputs[i]->isa<Parameter>() && HasAbstractRef(inputs[i]->cast<ParameterPtr>())) ||
AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimUpdateState) || HasAbstractMonad(inputs[i])) {
persist_input_num++;
continue;
} else {
const auto &input_with_index = AnfAlgo::VisitKernelWithReturnType(inputs[i], 0);
LinkDataArrowByControlNode(graph_compiler_info, input_with_index, from_func_graph, actor,
i - kCallInputStartPos - persist_input_num);
}
auto op_arrow = std::make_shared<DataArrow>(i - kCallInputStartPos - persist_input_num, to_actor->GetAID(),
i - kCallInputStartPos - persist_input_num);
(void)gather_actor->output_data_arrows_.emplace_back(op_arrow);
}
}
}
// Link arrows for the switch actors of subgraph outputs.
for (const auto &func_graph_with_call_num : graph_compiler_info.control_node_parser_->func_graph_to_call_num_) {
const auto &func_graph = func_graph_with_call_num.first;
MS_EXCEPTION_IF_NULL(func_graph);
auto actor = FetchActor(func_graph->get_return()->DebugString());
MS_EXCEPTION_IF_NULL(actor);
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
MS_EXCEPTION_IF_NULL(switch_actor);
LinkDataArrowForSwitchActor(graph_compiler_info, switch_actor);
}
// Link arrows for the gather actors of call-input kernel graphs.
for (const auto &call_input_kernel_graph : graph_compiler_info.control_node_parser_->call_input_kernel_graphs_) {
const auto &kernel_graph = call_input_kernel_graph.first;
MS_EXCEPTION_IF_NULL(kernel_graph);
auto actor = FetchActor(kernel_graph->ToString());
MS_EXCEPTION_IF_NULL(actor);
auto gather_actor = dynamic_cast<GatherActor *>(actor);
MS_EXCEPTION_IF_NULL(gather_actor);
for (size_t i = 0; i < gather_actor->data_nodes_.size(); ++i) {
const auto &input_with_index = gather_actor->data_nodes_[i];
const auto &from_func_graph = kernel_graph->GetFuncGraph();
LinkDataArrowByControlNode(graph_compiler_info, input_with_index, from_func_graph, gather_actor, i);
}
}
LinkBranchArrowForSwitchActor(graph_compiler_info);
LinkBranchArrowForGatherActor(graph_compiler_info);
LinkControlArrowForGatherActor(&(actor_set->kernel_actors_), graph_compiler_info.graphs_,
graph_compiler_info.control_node_parser_);
LinkControlArrowForSwitchActor(&(actor_set->switch_actors_), actor_set->loop_count_actor_.get(),
graph_compiler_info.origin_outputs_order_);
LinkOutputResultArrowForSwitchActor(graph_compiler_info, actor_set);
}
void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, ActorSet *const actor_set) {}
void GraphScheduler::LinkDataArrowForGatherActor(GatherActor *const from_actor, KernelActor *const to_actor,
const KernelWithIndex &front_node_with_index,
const KernelWithIndex &to_node_with_index) {
MS_EXCEPTION_IF_NULL(from_actor);
MS_EXCEPTION_IF_NULL(to_actor);
MS_EXCEPTION_IF_NULL(front_node_with_index.first);
auto position = from_actor->FetchDataNodePosition(front_node_with_index);
auto op_arrow = std::make_shared<DataArrow>(position, to_actor->GetAID(), to_node_with_index.second);
(void)from_actor->output_data_arrows_.emplace_back(op_arrow);
to_actor->input_datas_num_++;
}
void GraphScheduler::LinkDataArrowByCallInput(const KernelWithIndex &call_node_with_index,
const ControlNodeParserPtr &parser, const FuncGraphPtr &from_func_graph,
OpActor<DeviceTensor> *const to_actor, const size_t to_index) {
// Fetch all the funcgraphs that the call node may call.
const auto cnode = call_node_with_index.first->cast<CNodePtr>();
std::vector<FuncGraphPtr> func_graphs = FetchFuncGraphbyCallNode(cnode);
// Collect the output of each funcgraph.
for (const auto &func_graph : func_graphs) {
const auto actor_name = func_graph->get_return()->DebugString();
auto actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
MS_EXCEPTION_IF_NULL(switch_actor);
const size_t branch_index = switch_actor->branch_id_to_index_.size();
const auto &func_graph_to_branch_id = parser->func_graph_to_branch_id_;
const auto &iter = func_graph_to_branch_id.find(from_func_graph);
int branch_id = kMainBranchID;
if (iter != func_graph_to_branch_id.end()) {
branch_id = iter->second;
}
if (switch_actor->branch_id_to_index_.find(branch_id) != switch_actor->branch_id_to_index_.end()) {
LinkDataArrowForSwitchActor(switch_actor, call_node_with_index.second, to_actor, to_index,
switch_actor->branch_id_to_index_[branch_id]);
continue;
}
LinkDataArrowForSwitchActor(switch_actor, call_node_with_index.second, to_actor, to_index, branch_index);
switch_actor->branch_id_to_index_[branch_id] = branch_index;
}
}
const KernelWithIndex &to_node_with_index) {}
void GraphScheduler::LinkDataArrowForSwitchActor(SwitchActor *from_actor, const size_t from_index,
OpActor<DeviceTensor> *to_actor, const size_t to_index,
const size_t branch_index) {
MS_EXCEPTION_IF_NULL(from_actor);
MS_EXCEPTION_IF_NULL(to_actor);
size_t start_branch = 0;
size_t max_branch = from_actor->output_branch_arrows_.size();
if (branch_index != SIZE_MAX) {
start_branch = branch_index;
max_branch = branch_index + 1;
}
for (size_t i = start_branch; i < max_branch; ++i) {
if (from_actor->branch_inputs_pos_[i].size() <= from_index) {
MS_LOG(EXCEPTION) << "No input for switch actor:" << from_actor->GetAID() << " branch:" << i
<< " from index:" << from_index << " output size:" << from_actor->branch_inputs_pos_[i].size()
<< " to actor:" << to_actor->GetAID() << " to index:" << to_index;
}
auto op_arrow =
std::make_shared<DataArrow>(from_actor->branch_inputs_pos_[i][from_index], to_actor->GetAID(), to_index);
(void)from_actor->output_branch_arrows_[i].emplace_back(op_arrow);
}
}
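In the loop above, branch_index == SIZE_MAX is a sentinel meaning "link this input to every branch"; otherwise only the named branch is linked. Distilled into a standalone sketch (LinkBranches is a made-up name for illustration):

#include <cstddef>
#include <cstdint>
#include <iostream>

// branch_index == SIZE_MAX means "apply to every branch"; otherwise only one.
void LinkBranches(size_t branch_count, size_t branch_index) {
  size_t start = 0;
  size_t end = branch_count;
  if (branch_index != SIZE_MAX) {  // a specific branch was requested
    start = branch_index;
    end = branch_index + 1;
  }
  for (size_t i = start; i < end; ++i) {
    std::cout << "link branch " << i << "\n";
  }
}

int main() {
  LinkBranches(3, SIZE_MAX);  // links branches 0, 1, 2
  LinkBranches(3, 1);         // links only branch 1
  return 0;
}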
void GraphScheduler::LinkDataArrowByControlNode(const GraphCompilerInfo &graph_compiler_info,
const KernelWithIndex &input_with_index,
const FuncGraphPtr &from_func_graph,
OpActor<DeviceTensor> *const to_actor, const size_t to_index) {
const auto &parameters = graph_compiler_info.origin_parameters_order_;
const auto &front_to_backend_parameter = graph_compiler_info.control_node_parser_->front_to_backend_parameters_;
const auto &input_node = input_with_index.first;
if (IsCallNode(input_node)) {
// The actor input is a call node.
LinkDataArrowByCallInput(input_with_index, graph_compiler_info.control_node_parser_, from_func_graph, to_actor,
to_index);
} else if (IsGatherActor(input_node, actor_name_to_actor_)) {
// The actor input is a parameter in gather actor.
auto from_actor = dynamic_cast<GatherActor *>(actor_name_to_actor_[input_node->func_graph()->ToString()]);
auto position = from_actor->FetchDataNodePosition({input_node, 0});
auto op_arrow = std::make_shared<DataArrow>(position, to_actor->GetAID(), to_index);
(void)from_actor->output_data_arrows_.emplace_back(op_arrow);
} else if (IsSwitchActor(input_node)) {
const auto &actor_name = input_node->DebugString();
auto actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
LinkDataArrowForSwitchActor(dynamic_cast<SwitchActor *>(actor), 0, to_actor, to_index);
} else if (IsKernelActor(input_node, graph_compiler_info.strategy_)) {
// The actor input is a cnode.
if (front_node_to_actor_.find(input_node) == front_node_to_actor_.end()) {
const auto &kernel_with_index = AnfAlgo::VisitKernelWithReturnType(input_node, 0);
const auto &backend_node =
graph_compiler_info.control_node_parser_->GetBackendKernelByFrontKernel(kernel_with_index);
if (backend_node.first == nullptr) {
MS_LOG(EXCEPTION) << "Cannot find actor:" << to_actor->GetAID()
<< " input_node:" << AnfAlgo::GetNodeDebugString(input_node) << " addr:" << input_node;
}
const auto &actor_name = backend_node.first->fullname_with_scope();
const auto &actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto from_actor = dynamic_cast<KernelActor *>(actor);
MS_EXCEPTION_IF_NULL(from_actor);
auto op_arrow = std::make_shared<DataArrow>(backend_node.second, to_actor->GetAID(), to_index);
(void)from_actor->output_data_arrows_.emplace_back(op_arrow);
auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->kernel_, backend_node.second, false);
UpdateRefCount(device_tensor.get(), true);
return;
}
auto op_arrow = std::make_shared<DataArrow>(input_with_index.second, to_actor->GetAID(), to_index);
auto from_actor = front_node_to_actor_[input_node];
(void)from_actor->output_data_arrows_.emplace_back(op_arrow);
auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->kernel_, input_with_index.second, false);
UpdateRefCount(device_tensor.get(), true);
} else if (find(parameters.begin(), parameters.end(), input_node) != parameters.end()) {
// The actor input is a parameter in host data source actor.
std::string actor_name = graph_compiler_info.name_ + "_HostDSActor";
auto actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto from_actor = dynamic_cast<HostQueueDataSourceActor *>(actor);
MS_EXCEPTION_IF_NULL(from_actor);
auto backend_iter = front_to_backend_parameter.find(input_node);
if (backend_iter == front_to_backend_parameter.end()) {
MS_LOG(EXCEPTION) << "Cannot find backend node for front node:" << AnfAlgo::GetNodeDebugString(input_node);
}
const auto &backend_node = backend_iter->second.first;
auto iter = from_actor->data_node_position_map_.find(input_node);
if (iter == from_actor->data_node_position_map_.end()) {
MS_LOG(EXCEPTION) << "Cannot find data node in data source actor, backend node:"
<< AnfAlgo::GetNodeDebugString(backend_node)
<< " front node:" << AnfAlgo::GetNodeDebugString(input_node);
}
auto op_arrow = std::make_shared<DataArrow>(iter->second, to_actor->GetAID(), to_index);
(void)from_actor->output_data_arrows_.emplace_back(op_arrow);
auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->data_nodes_[iter->second], 0, false);
UpdateRefCount(device_tensor.get(), true);
} else {
MS_LOG(EXCEPTION) << "Cannot find actor of switch input_node:" << AnfAlgo::GetNodeDebugString(input_node)
<< " to actor:" << to_actor->GetAID();
}
}
void GraphScheduler::LinkDataArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info,
SwitchActor *const actor) {
// Link switch input.
const auto &inputs = actor->input_nodes_;
for (size_t i = 0; i < inputs.size(); ++i) {
auto input = inputs[i];
if (input.first->isa<ValueNode>() || (input.first->isa<Parameter>() && HasAbstractRef(input.first))) {
continue;
}
const FuncGraphPtr from_func_graph = actor->node_->func_graph();
LinkDataArrowByControlNode(graph_compiler_info, input, from_func_graph, actor, i);
}
// Link switch output.
for (size_t i = 0; i < actor->branch_func_graph_.size(); ++i) {
auto func_graph = actor->branch_func_graph_[i];
if (func_graph == nullptr) {
continue;
}
auto gather_name = func_graph->ToString();
if (actor_name_to_actor_.find(gather_name) == actor_name_to_actor_.end()) {
MS_LOG(EXCEPTION) << "Cannot find gather actor for funcgraph:" << gather_name
<< ",switch input size:" << actor->input_nodes_.size();
}
auto to_actor = dynamic_cast<GatherActor *>(actor_name_to_actor_[gather_name]);
for (size_t j = 0; j < actor->branch_inputs_pos_[i].size(); ++j) {
auto pos = actor->branch_inputs_pos_[i][j];
auto to_actor_index = j;
auto op_arrow = std::make_shared<DataArrow>(pos, to_actor->GetAID(), to_actor_index);
(void)actor->output_branch_arrows_[i].emplace_back(op_arrow);
}
}
}
void GraphScheduler::LinkControlArrowForGatherActor(std::vector<KernelActorPtr> *const kernel_actors,
const std::vector<KernelGraphPtr> &graphs,
const ControlNodeParserPtr &parser) {
// Link control arrow to kernel actor.
for (size_t i = 0; i < graphs.size(); ++i) {
const auto &kernel_graph = graphs[i];
MS_EXCEPTION_IF_NULL(kernel_graph);
const auto &func_graph = kernel_graph->GetFuncGraph();
if (func_graph == nullptr) {
continue;
}
const auto &actor = FetchActor(func_graph->ToString());
if (actor == nullptr) {
continue;
}
const auto &gather_actor = dynamic_cast<GatherActor *>(actor);
MS_EXCEPTION_IF_NULL(gather_actor);
// When the gather actor exists, the control arrows of kernel actors that have no inputs need to be sent by the
// gather actor.
for (const auto &kernel : kernel_graph->execution_order()) {
if (IsKernelActor(kernel) && (!IsSkippedKernelActor(kernel))) {
const auto &kernel_actor = dynamic_cast<KernelActor *>(FetchActor(kernel->fullname_with_scope()));
MS_EXCEPTION_IF_NULL(kernel_actor);
if ((kernel_actor->input_datas_num_ == 0) && (kernel_actor->input_controls_num_ == 0)) {
(void)gather_actor->output_control_arrows_.emplace_back(kernel_actor->GetAID());
kernel_actor->input_controls_num_ = 1;
}
}
}
}
for (auto &kernel_actor : *kernel_actors) {
MS_EXCEPTION_IF_NULL(kernel_actor);
if ((kernel_actor->output_data_arrows_.size() == 0) && (kernel_actor->output_control_arrows_.size() == 0) &&
!parser->IsKernelInRootFuncGraph(kernel_actor->kernel_)) {
// Check whether the kernel actor belongs to the root graph.
// In general, all nodes without outputs belong to the root funcgraph, and the corresponding output switch actor
// should be empty. In control flow, the control arrow of a node without outputs in a sub funcgraph should be
// sent to the output switch actor.
const auto &graph = kernel_actor->kernel_->func_graph();
OpActor<DeviceTensor> *actor = nullptr;
if (graph != nullptr) {
const auto &kernel_graph = dynamic_cast<KernelGraph *>(graph.get());
const auto func_graph = kernel_graph->GetFuncGraph();
if (func_graph != nullptr) {
actor = FetchActor(func_graph->get_return()->DebugString());
if (actor != nullptr) {
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
MS_EXCEPTION_IF_NULL(switch_actor);
(void)kernel_actor->output_control_arrows_.emplace_back(switch_actor->GetAID());
switch_actor->input_controls_num_++;
}
}
}
}
}
// Link input auto monad control arrow from kernel actor to gather actor.
const auto &monad_nodes = parser->kernel_to_call_nodes_;
for (const auto node_pair : monad_nodes) {
const auto &kernel_actor_name = node_pair.first->fullname_with_scope();
const auto &gather_actor_name = node_pair.second->DebugString();
auto kernel_op_actor = FetchActor(kernel_actor_name);
auto gather_op_actor = FetchActor(gather_actor_name);
if (kernel_op_actor == nullptr || gather_op_actor == nullptr) {
continue;
}
auto kernel_actor = dynamic_cast<KernelActor *>(kernel_op_actor);
auto gather_actor = dynamic_cast<GatherActor *>(gather_op_actor);
(void)kernel_actor->output_control_arrows_.emplace_back(gather_actor->GetAID());
gather_actor->input_controls_num_++;
}
}
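A pattern worth noting in the linker code above: whenever a control arrow is added, the receiver's input_controls_num_ is incremented, because at run time an actor launches only after seeing that many control messages in a step. A standalone toy model of that counting handshake (ToyActor is invented for illustration):

#include <cstddef>
#include <iostream>

struct ToyActor {
  size_t input_datas_num_{0};     // expected data messages per step (set at link time)
  size_t input_controls_num_{0};  // expected control messages per step (set at link time)
  size_t datas_seen_{0};
  size_t controls_seen_{0};

  void OnData() { ++datas_seen_; TryLaunch(); }
  void OnControl() { ++controls_seen_; TryLaunch(); }

  void TryLaunch() {
    if (datas_seen_ >= input_datas_num_ && controls_seen_ >= input_controls_num_) {
      std::cout << "launch\n";
      datas_seen_ = 0;  // reset for the next step
      controls_seen_ = 0;
    }
  }
};

int main() {
  ToyActor a;
  a.input_controls_num_ = 2;  // link phase added two control arrows pointing at `a`
  a.OnControl();              // 1 of 2: not ready yet
  a.OnControl();              // 2 of 2: prints "launch"
  return 0;
}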
void GraphScheduler::LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr> *const switch_actors,
LoopCountActor *const to_actor,
const KernelMapPosition &origin_outputs_order) {
if (to_actor == nullptr || (*switch_actors).empty()) {
return;
}
// If a switch actor branch has no output arrows, the corresponding subgraph has no input,
// and a control arrow needs to be connected to the corresponding gather actor.
for (auto &switch_actor : (*switch_actors)) {
if (AnfAlgo::CheckPrimitiveType(switch_actor->node_, prim::kPrimReturn)) {
const auto &func_graph = switch_actor->node_->func_graph();
if (func_graph->output()->isa<ValueNode>()) {
const auto &actor_name = func_graph->ToString();
auto actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto gather_actor = dynamic_cast<GatherActor *>(actor);
MS_EXCEPTION_IF_NULL(gather_actor);
(void)gather_actor->output_control_arrows_.emplace_back(switch_actor->GetAID());
switch_actor->input_controls_num_++;
}
}
for (size_t i = 0; i < switch_actor->output_branch_arrows_.size(); ++i) {
const auto &arrows = switch_actor->output_branch_arrows_[i];
if (arrows.empty() && switch_actor->branch_func_graph_[i] != nullptr) {
const auto &actor_name = switch_actor->branch_func_graph_[i]->ToString();
const auto &actor = FetchActor(actor_name);
if (actor != nullptr) {
const auto &gather_actor = dynamic_cast<GatherActor *>(actor);
MS_EXCEPTION_IF_NULL(gather_actor);
(void)switch_actor->output_branch_control_arrows_[i].emplace_back(gather_actor->GetAID());
gather_actor->input_controls_num_++;
}
}
}
}
// Collect all the call nodes in the outputs.
std::set<AnfNodePtr> call_nodes;
for (const auto &output : origin_outputs_order) {
if (IsCallNode(output.first.first)) {
(void)call_nodes.insert(output.first.first);
}
}
to_actor->input_controls_num_ += call_nodes.size();
// Link the output switch actor of the subgraph to the loop count actor.
for (const auto &call_node : call_nodes) {
const auto &func_graphs = FetchFuncGraphbyCallNode(call_node);
for (const auto func_graph : func_graphs) {
MS_EXCEPTION_IF_NULL(func_graph);
const auto &actor_name = func_graph->get_return()->DebugString();
auto actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
MS_EXCEPTION_IF_NULL(switch_actor);
size_t branch_index = switch_actor->branch_id_to_index_.size();
if (switch_actor->branch_id_to_index_.find(kMainBranchID) != switch_actor->branch_id_to_index_.end()) {
branch_index = switch_actor->branch_id_to_index_[kMainBranchID];
} else {
switch_actor->branch_id_to_index_[kMainBranchID] = branch_index;
}
(void)switch_actor->output_branch_control_arrows_[branch_index].emplace_back(to_actor->GetAID());
}
}
}
void GraphScheduler::LinkBranchArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info) {
for (const auto &control_node : graph_compiler_info.control_nodes_) {
if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch) ||
AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
const auto &actor_name = control_node->DebugString();
auto actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
MS_EXCEPTION_IF_NULL(switch_actor);
for (size_t i = 0; i < switch_actor->branch_func_graph_.size(); ++i) {
const auto &func_graph = switch_actor->branch_func_graph_[i];
if (func_graph == nullptr) {
continue;
}
const auto &gather_actor = FetchActor(func_graph->ToString());
MS_EXCEPTION_IF_NULL(gather_actor);
(void)switch_actor->output_branch_branch_arrows_[i].emplace_back(gather_actor->GetAID());
}
}
}
}
void GraphScheduler::LinkBranchArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info) {
if (graph_compiler_info.control_nodes_.empty()) {
return;
}
// Link branch arrow from gather actor to gather actor.
for (const auto &control_node : graph_compiler_info.control_nodes_) {
const auto &cnode = control_node->cast<CNodePtr>();
const auto &inputs = cnode->inputs();
if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) {
const auto &actor_name = control_node->DebugString();
auto actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto gather_actor = dynamic_cast<GatherActor *>(actor);
MS_EXCEPTION_IF_NULL(gather_actor);
(void)gather_actor->output_branch_arrows_.emplace_back(gather_actor->gather_aid_);
}
}
// Link branch arrow from gather actor to switch actor.
for (const auto &func_graph_with_call_num : graph_compiler_info.control_node_parser_->func_graph_to_call_num_) {
const auto &actor_name = func_graph_with_call_num.first->ToString();
auto actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto gather_actor = dynamic_cast<GatherActor *>(actor);
MS_EXCEPTION_IF_NULL(gather_actor);
(void)gather_actor->output_branch_arrows_.emplace_back(gather_actor->switch_aid_);
}
}
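The branch arrows linked above carry the branch id from gather actors to the switch actor at the subgraph output; as the removed switch_actor.h comment explains, recursive calls stack these ids and pop them on return. A toy model of just that stack discipline (ToyOutputSwitch is invented for illustration):

#include <cassert>
#include <stack>

// Models only the stack discipline: callers push their branch id on call,
// and the subgraph's output switch pops the top id on each return.
struct ToyOutputSwitch {
  std::stack<int> branch_ids;

  void OnCall(int caller_branch_id) { branch_ids.push(caller_branch_id); }

  int OnReturn() {
    int id = branch_ids.top();  // route to the most recent (innermost) caller
    branch_ids.pop();
    return id;
  }
};

int main() {
  ToyOutputSwitch sw;
  sw.OnCall(1);                // outer call from branch 1
  sw.OnCall(2);                // recursive call from branch 2
  assert(sw.OnReturn() == 2);  // innermost return goes back to branch 2
  assert(sw.OnReturn() == 1);  // then the outer return goes back to branch 1
  return 0;
}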
const size_t branch_index) {}
bool GraphScheduler::CheckActorValid(const ActorSet *actor_set, GraphExecutionStrategy strategy) const {
MS_EXCEPTION_IF_NULL(actor_set);
@ -2681,90 +1955,11 @@ void GraphScheduler::DumpDeviceTensorStore(const GraphCompilerInfo &graph_compil
void GraphScheduler::DumpGatherActor(const GatherActor *actor, std::ofstream &ofs) const {
MS_EXCEPTION_IF_NULL(actor);
ofs << "\tactor_name:" << actor->GetAID().Name() << '\n';
ofs << "\t\tactor input num:" << actor->data_nodes_.size() << "\n";
for (const auto &node : actor->data_nodes_) {
ofs << "\t\t\t" << AnfAlgo::GetNodeDebugString(node.first) << "\tindex:" << node.second << '\n';
}
ofs << "\t\tactor front to backend node:\n";
for (const auto &front_to_backend_parameter : actor->front_to_backend_parameter_) {
ofs << "\t\t\tfront node:" << AnfAlgo::GetNodeDebugString(front_to_backend_parameter.first) << '\n';
for (const auto node_with_index : front_to_backend_parameter.second) {
ofs << "\t\t\t\tbackend node:" << AnfAlgo::GetNodeDebugString(node_with_index.first)
<< "\tindex:" << node_with_index.second << '\n';
}
}
ofs << "\t\tactor output data arrow:\n";
for (const auto &data_arrow : actor->output_data_arrows_) {
MS_EXCEPTION_IF_NULL(data_arrow);
ofs << "\t\t\tfrom_output_index:" << data_arrow->from_output_index_
<< "\tto_actor_name:" << data_arrow->to_op_id_.Name() << "\tto_input_index:" << data_arrow->to_input_index_
<< "\n";
}
ofs << "\t\tactor output result arrow:\n";
for (const auto &result_arrow : actor->output_result_arrows_) {
MS_EXCEPTION_IF_NULL(result_arrow);
ofs << "\t\t\tfrom_output_index:" << result_arrow->from_output_index_
<< "\tto_actor_name:" << result_arrow->to_op_id_.Name() << "\tto_input_index:" << result_arrow->to_input_index_
<< "\n";
}
ofs << "\t\tactor output control arrow:\n";
for (const auto &control_arrow : actor->output_control_arrows_) {
ofs << "\t\t\tto_actor_name:" << control_arrow;
}
ofs << "\n";
}
void GraphScheduler::DumpSwitchActor(const SwitchActor *actor, std::ofstream &ofs) const {
MS_EXCEPTION_IF_NULL(actor);
ofs << "\tactor_name:" << actor->GetAID().Name() << '\n';
ofs << "\t\tactor input num:" << actor->input_nodes_.size() << "\n";
for (const auto &node : actor->input_nodes_) {
ofs << "\t\t\t" << AnfAlgo::GetNodeDebugString(node.first) << '\t' << node.second << '\n';
}
ofs << "\t\tactor input pos:\n";
for (size_t i = 0; i < actor->branch_inputs_pos_.size(); ++i) {
ofs << "\t\t\tbranch " << i << " input pos:";
for (const auto pos : actor->branch_inputs_pos_[i]) {
ofs << pos << '\t';
}
ofs << '\n';
}
ofs << "\t\tactor output data arrow:\n";
for (size_t i = 0; i < actor->output_branch_arrows_.size(); ++i) {
ofs << "\t\t\tbranch " << i << " output data:\n";
for (const auto arrow : actor->output_branch_arrows_[i]) {
MS_EXCEPTION_IF_NULL(arrow);
ofs << "\t\t\t\t from index:" << arrow->from_output_index_ << "\tto_actor_name:" << arrow->to_op_id_
<< "\tto_input_index:" << arrow->to_input_index_ << '\n';
}
}
ofs << "\t\tactor output result arrow:\n";
for (size_t i = 0; i < actor->output_branch_result_arrows_.size(); ++i) {
ofs << "\t\t\tbranch " << i << " output result:\n";
for (const auto arrow : actor->output_branch_result_arrows_[i]) {
MS_EXCEPTION_IF_NULL(arrow);
ofs << "\t\t\t\t from index:" << arrow->from_output_index_ << "\tto_actor_name:" << arrow->to_op_id_
<< "\tto_input_index:" << arrow->to_input_index_ << '\n';
}
}
ofs << "\t\tactor output control arrow:\n";
for (size_t i = 0; i < actor->output_branch_control_arrows_.size(); ++i) {
ofs << "\t\t\tbranch " << i << " output control:\n";
for (const auto arrow : actor->output_branch_control_arrows_[i]) {
ofs << "\t\t\t\t from index:" << arrow << '\n';
}
}
ofs << "\n";
}
} // namespace runtime
} // namespace mindspore

View File

@ -33,9 +33,9 @@
#include "runtime/framework/actor/loop_count_actor.h"
#include "runtime/framework/actor/kernel_actor.h"
#include "runtime/framework/actor/output_actor.h"
#include "runtime/framework/actor/switch_actor.h"
#include "runtime/framework/actor/gather_actor.h"
#include "runtime/framework/actor/copy_actor.h"
#include "runtime/framework/actor/control_flow/switch_actor.h"
#include "runtime/framework/actor/control_flow/gather_actor.h"
#include "thread/actor_threadpool.h"
namespace mindspore {
@ -131,8 +131,6 @@ class GraphScheduler {
const std::vector<DataSourceActorPtr> &data_source_actors,
const HostTensorQueuePtr &host_queue);
std::vector<KernelActorPtr> BuildNoInputKernelActor(const ActorSet *actor_set, GraphExecutionStrategy strategy);
std::vector<SwitchActorPtr> BuildSwitchActor(const GraphCompilerInfo &graph_compiler_info);
std::vector<GatherActorPtr> BuildGatherActor(const GraphCompilerInfo &graph_compiler_info);
// Cache the information of graph output node to actor between “build” and “link”, for linking between the tail of
// previous graph and the head of next graph.
@ -195,34 +193,10 @@ class GraphScheduler {
void LinkDataArrowForGatherActor(GatherActor *const from_actor, KernelActor *const to_actor,
const KernelWithIndex &front_node_with_index,
const KernelWithIndex &to_node_with_index);
void LinkDataArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info, SwitchActor *const actor);
// Connect the input of the actor.
void LinkDataArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, const KernelWithIndex &input_node,
const FuncGraphPtr &from_func_graph, OpActor<DeviceTensor> *const to_actor,
const size_t to_index);
// When the input of the actor is a call node, the output of the funcgraph called by the call node needs to be
// connected.
void LinkDataArrowByCallInput(const KernelWithIndex &call_node_with_index, const ControlNodeParserPtr &parser,
const FuncGraphPtr &from_func_graph, OpActor<DeviceTensor> *const to_actor,
const size_t to_index);
void LinkDataArrowForSwitchActor(SwitchActor *const from_actor, const size_t from_index,
OpActor<DeviceTensor> *const to_actor, const size_t to_index,
const size_t branch_index = SIZE_MAX);
void LinkControlArrowForGatherActor(std::vector<KernelActorPtr> *const kernel_actors,
const std::vector<KernelGraphPtr> &graphs, const ControlNodeParserPtr &parser);
void LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr> *const switch_actors, LoopCountActor *const to_actor,
const KernelMapPosition &origin_outputs_order);
// In control flow, there are scenarios where there are multi-branch outputs, and the gather actor needs to
// send the branch id to the loop count actor.
void LinkBranchArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info);
void LinkBranchArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info);
void LinkOutputResultArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
// Add inputs for the switch actor. Since part of the funcgraph's inputs are on the call node, these inputs need
// to be added to the switch actor.
void PrepareInputNodeForSwitchActor(const std::vector<AnfNodePtr> &control_nodes);
// Check whether the actor set is valid.
bool CheckActorValid(const ActorSet *actor_set,
GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline) const;