!15795 Add Switch Actor.

From: @gaoyong10
Reviewed-by: 
Signed-off-by:
mindspore-ci-bot 2021-05-15 10:23:41 +08:00 committed by Gitee
commit 7c393c0375
8 changed files with 393 additions and 53 deletions

View File

@@ -19,6 +19,7 @@
#ifdef __WIN32__
#include <windows.h>
#endif
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore {
namespace runtime {
@@ -34,5 +35,44 @@ int64_t GetMaxThreadNum() {
return max_thread_num;
}
bool IsDeviceQueueDSActor(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>() && (AnfAlgo::GetCNodeName(node) == kGetNextOpName)) {
return true;
}
return false;
}
bool IsHostQueueDSActor(const AnfNodePtr &node, const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(graph);
if (node->isa<Parameter>() && (!AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>()))) {
// Judge whether the node is an internal parameter.
if (graph->GetFrontNodeByInternalParameter(node) == nullptr) {
return true;
}
}
return false;
}
bool IsKernelActor(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>() && (AnfAlgo::GetCNodeName(node) != kGetNextOpName)) {
return true;
}
return false;
}
bool IsPersistentDeviceTensor(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<ValueNode>()) {
return true;
}
if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
return true;
}
return false;
}
} // namespace runtime
} // namespace mindspore
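
For reference, the four helpers above partition graph nodes among the runtime's actors. A minimal standalone sketch of the dispatch idea (toy types and names, not the MindSpore API; the real IsHostQueueDSActor additionally excludes internal parameters):

#include <iostream>
#include <string>

// Toy stand-in for an AnfNode; the real predicates query AnfAlgo and the graph.
struct ToyNode {
  bool is_cnode = false;
  bool is_parameter = false;
  bool is_value_node = false;
  bool is_weight = false;
  std::string cnode_name;
};

enum class ActorKind { kDeviceQueueDS, kHostQueueDS, kKernel, kNone };

// A GetNext CNode feeds the device queue, any other CNode becomes a kernel
// actor, and a non-weight Parameter feeds the host queue.
ActorKind Classify(const ToyNode &node) {
  if (node.is_cnode) {
    return node.cnode_name == "GetNext" ? ActorKind::kDeviceQueueDS : ActorKind::kKernel;
  }
  if (node.is_parameter && !node.is_weight) {
    return ActorKind::kHostQueueDS;
  }
  return ActorKind::kNone;
}

// Mirrors IsPersistentDeviceTensor: ValueNodes and weight Parameters keep
// their device tensors alive across steps.
bool IsPersistent(const ToyNode &node) {
  return node.is_value_node || (node.is_parameter && node.is_weight);
}

int main() {
  ToyNode get_next;
  get_next.is_cnode = true;
  get_next.cnode_name = "GetNext";
  ToyNode weight;
  weight.is_parameter = true;
  weight.is_weight = true;
  std::cout << (Classify(get_next) == ActorKind::kDeviceQueueDS) << std::endl;  // 1
  std::cout << IsPersistent(weight) << std::endl;                               // 1
  return 0;
}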

View File

@@ -19,6 +19,7 @@
#include <utility>
#include "mindrt/include/actor/op_actor.h"
#include "backend/session/kernel_graph.h"
#include "utils/log_adapter.h"
namespace mindspore {
@@ -43,6 +44,12 @@ constexpr int kFailure = 1;
// Get the max available thread number of the system.
int64_t GetMaxThreadNum();
bool IsDeviceQueueDSActor(const AnfNodePtr &node);
bool IsHostQueueDSActor(const AnfNodePtr &node, const KernelGraphPtr &graph);
bool IsKernelActor(const AnfNodePtr &node);
// Judge whether the device tensor of the node is persistent or not.
bool IsPersistentDeviceTensor(const AnfNodePtr &node);
} // namespace runtime
} // namespace mindspore

View File

@@ -20,18 +20,6 @@
namespace mindspore {
namespace runtime {
namespace {
// Judge whether the device tensor of the node is persistent or not.
bool IsPersistentDeviceTensor(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<ValueNode>()) {
return true;
}
if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
return true;
}
return false;
}
TensorPtr CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index, size_t output_position) {
MS_EXCEPTION_IF_NULL(output_node);
MS_LOG(INFO) << "Create output tensor, output node: " << output_node->fullname_with_scope()

View File

@@ -0,0 +1,236 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/framework/actor/switch_actor.h"
#include "runtime/framework/actor/memory_manager_actor.h"
#include "mindrt/include/async/async.h"
#include "abstract/utils.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace runtime {
constexpr size_t kSwitchInputNum = 4;
constexpr size_t kSwitchCondPos = 1;
constexpr size_t kSwitchPartialNum = 2;
constexpr size_t kSwitchLayerCondPos = 1;
constexpr size_t kSwitchLayerBranchPos = 2;
constexpr size_t kSwitchLayerInputNum = 3;
constexpr size_t kMaxSwitchCondSize = 8;
constexpr size_t kSwitchTrueBranchPos = 2;
constexpr size_t kSwitchFalseBranchPos = 3;
constexpr size_t kPartialFuncGraphPos = 1;
constexpr size_t kPartialInputStartPos = 2;
void SwitchActor::RunOpData(OpDataPtr<DeviceTensor> input_data, OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
auto sequential_num = context->sequential_num_;
input_op_datas_[sequential_num].emplace_back(input_data);
// When all the inputs are collected, fetch them and send the output to the selected branch.
if (CheckLaunchCondition(context)) {
FetchInputDeviceTensor(context);
SendOutput(context);
}
}
void SwitchActor::Initialize() {
MS_EXCEPTION_IF_NULL(node_);
std::vector<AnfNodePtr> inputs = node_->inputs();
if (IsPrimitive(inputs[0], prim::kPrimSwitch)) {
InitSwitch();
} else {
InitSwitchLayer();
}
input_datas_num_ = input_nodes_.size();
}
void SwitchActor::InitPartial(const AnfNodePtr &node, const size_t branch_id) {
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
CNodePtr cnode = node->cast<CNodePtr>();
// The inputs of the Partial node are:
// [0] ValueNode<Primitive> kPartial.
// [1] ValueNode<FuncGraphPtr>.
// [2..] Inputs.
auto node_inputs = cnode->inputs();
branch_func_graph_[branch_id] = GetValueNode<FuncGraphPtr>(node_inputs[kPartialFuncGraphPos]);
for (size_t j = kPartialInputStartPos; j < node_inputs.size(); ++j) {
AddInput(node_inputs[j], branch_id);
}
} else {
AddInput(node, branch_id);
}
}
void SwitchActor::InitSwitch() {
// The inputs of the switch node:
// [0] ValueNode<Primitive> kSwitch.
// [1] switch condition.
// [2] Partial node: true branch.
// [3] Partial node: false branch.
std::vector<AnfNodePtr> inputs = node_->inputs();
if (inputs.size() != kSwitchInputNum) {
MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitch->name() << " is not equal 4";
}
branch_total_inputs_.resize(kSwitchPartialNum);
branch_inputs_pos_.resize(kSwitchPartialNum);
branch_func_graph_.resize(kSwitchPartialNum);
output_branch_arrows_.resize(kSwitchPartialNum);
input_nodes_.push_back(inputs[kSwitchCondPos]);
// Init the two branches of the switch node.
InitPartial(inputs[kSwitchFalseBranchPos], static_cast<size_t>(false));
InitPartial(inputs[kSwitchTrueBranchPos], static_cast<size_t>(true));
}
void SwitchActor::InitSwitchLayer() {
// The inputs of the SwitchLayer node:
// [0] ValueNode<Primitive> kSwitchLayer.
// [1] SwitchLayer index.
// [2] MakeTuple node: tuple of branches.
std::vector<AnfNodePtr> inputs = node_->inputs();
if (inputs.size() != kSwitchLayerInputNum) {
MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitchLayer->name() << " is not equal 3";
}
input_nodes_.push_back(inputs[kSwitchLayerCondPos]);
// The second input of SwitchLayer is a MakeTuple node, which contains all branches.
auto branch_nodes = inputs[kSwitchLayerBranchPos]->cast<CNodePtr>()->inputs();
branch_total_inputs_.resize(branch_nodes.size() - 1);
branch_inputs_pos_.resize(branch_nodes.size() - 1);
branch_device_tensor_store_keys_.resize(branch_nodes.size() - 1);
branch_func_graph_.resize(branch_nodes.size() - 1);
output_branch_arrows_.resize(branch_nodes.size() - 1);
// Parse all branches.
for (size_t i = 1; i < branch_nodes.size(); ++i) {
if (AnfAlgo::CheckPrimitiveType(branch_nodes[i], prim::kPrimPartial)) {
InitPartial(branch_nodes[i], i - 1);
} else if (branch_nodes[i]->isa<ValueNode>()) {
branch_func_graph_[i - 1] = GetValueNode<FuncGraphPtr>(branch_nodes[i]);
}
}
}
void SwitchActor::AddCommonInput(const AnfNodePtr &node) {
for (size_t i = 0; i < branch_inputs_pos_.size(); ++i) {
AddInput(node, i);
}
}
void SwitchActor::AddInput(const AnfNodePtr &node, const size_t branch) {
branch_total_inputs_[branch].push_back(node);
if (IsPersistentDeviceTensor(node)) {
return;
}
auto iter = find(input_nodes_.begin(), input_nodes_.end(), node);
if (iter == input_nodes_.end()) {
branch_inputs_pos_[branch].push_back(input_nodes_.size());
input_nodes_.push_back(node);
} else {
branch_inputs_pos_[branch].push_back(iter - input_nodes_.begin());
}
}
size_t SwitchActor::GetIndex() {
DeviceTensor *device_tensor = input_device_tensors_[0];
auto inputs = node_->inputs();
TypeId type_id = AnfAlgo::GetOutputInferDataType(inputs[kSwitchCondPos], 0);
size_t size = abstract::TypeIdSize(type_id);
if (size > sizeof(int64_t)) {
MS_LOG(EXCEPTION) << "Index must be Int type.";
}
int64_t index = 0;
char buf[kMaxSwitchCondSize] = {0};
ShapeVector host_shape;
device_tensor->SyncDeviceToHost(host_shape, size, type_id, static_cast<void *>(buf));
if (type_id == TypeId::kNumberTypeInt32) {
index = static_cast<int64_t>((static_cast<int32_t *>(static_cast<void *>(buf)))[0]);
} else if (type_id == TypeId::kNumberTypeInt64) {
index = (static_cast<int64_t *>(static_cast<void *>(buf)))[0];
} else if (type_id == TypeId::kNumberTypeBool) {
bool cond = (static_cast<bool *>(static_cast<void *>(buf)))[0];
index = static_cast<int64_t>(cond ? 1 : 0);
} else {
MS_LOG(EXCEPTION) << "Index must be Int type.";
}
// The SwitchLayer node supports a negative index in the range [-size, -1].
if (index < 0) {
index += static_cast<int64_t>(branch_func_graph_.size());
}
if (index < 0 || index >= static_cast<int64_t>(branch_func_graph_.size())) {
MS_LOG(EXCEPTION) << "Index is out of range:" << index;
}
return static_cast<size_t>(index);
}
bool SwitchActor::CheckLaunchCondition(OpContext<DeviceTensor> *context) const {
MS_EXCEPTION_IF_NULL(context);
if (input_datas_num_ != 0) {
auto data_iter = input_op_datas_.find(context->sequential_num_);
if (data_iter == input_op_datas_.end()) {
return false;
}
if (data_iter->second.size() != input_datas_num_) {
return false;
}
}
return true;
}
void SwitchActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
auto input_size = input_datas_num_ + branch_device_tensor_store_keys_.size();
input_device_tensors_.resize(input_size);
auto data_iter = input_op_datas_.find(context->sequential_num_);
if (data_iter != input_op_datas_.end()) {
for (auto &input_data : data_iter->second) {
MS_EXCEPTION_IF_NULL(input_data);
input_device_tensors_[input_data->index_] = input_data->data_;
}
data_iter->second.clear();
}
for (auto &device_tensor_store_key : branch_device_tensor_store_keys_) {
auto device_tensor = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second);
input_device_tensors_[device_tensor_store_key.first] = device_tensor.get();
}
}
void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
auto index = GetIndex();
if (index >= output_branch_arrows_.size()) {
MS_LOG(EXCEPTION) << "Switch actor invalid index:" << index;
}
for (const auto &arrow : output_branch_arrows_[index]) {
auto device_address = input_device_tensors_[arrow->from_output_index_];
auto data = std::make_shared<OpData<DeviceTensor>>(arrow->to_op_id_, device_address, arrow->to_input_index_);
Async(arrow->to_op_id_, &OpActor<DeviceTensor>::RunOpData, data, context);
}
}
void SwitchActor::FreeMemory(OpContext<DeviceTensor> *context) {
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, input_device_tensors_, device_contexts_, context);
}
} // namespace runtime
} // namespace mindspore
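
The byte-buffer decode in GetIndex can be exercised on its own. A self-contained sketch of the same logic, including the negative-index wraparound used by SwitchLayer (hypothetical helper name, not part of this commit):

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <stdexcept>

// Decode a branch index from a raw host buffer: int32/int64 hold a branch
// number, bool selects branch 0 or 1, and a negative value wraps around the
// branch count (SwitchLayer semantics).
size_t DecodeBranchIndex(const void *buf, size_t type_size, bool is_bool, size_t branch_num) {
  int64_t index = 0;
  if (is_bool) {
    bool cond = false;
    std::memcpy(&cond, buf, sizeof(cond));
    index = cond ? 1 : 0;
  } else if (type_size == sizeof(int32_t)) {
    int32_t value = 0;
    std::memcpy(&value, buf, sizeof(value));
    index = value;
  } else if (type_size == sizeof(int64_t)) {
    std::memcpy(&index, buf, sizeof(index));
  } else {
    throw std::runtime_error("Index must be Int type.");
  }
  if (index < 0) {
    index += static_cast<int64_t>(branch_num);
  }
  if (index < 0 || static_cast<size_t>(index) >= branch_num) {
    throw std::runtime_error("Branch index is out of range.");
  }
  return static_cast<size_t>(index);
}

int main() {
  int32_t cond = -1;  // the last branch of a three-branch SwitchLayer
  return DecodeBranchIndex(&cond, sizeof(cond), false, 3) == 2 ? 0 : 1;
}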

View File

@@ -0,0 +1,99 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_SWITCH_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_SWITCH_ACTOR_H_
#include <vector>
#include <string>
#include <memory>
#include <utility>
#include <unordered_map>
#include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/device_tensor_store.h"
#include "mindrt/include/actor/switch_actor.h"
#include "runtime/hardware/device_context.h"
namespace mindspore {
namespace runtime {
using mindspore::device::DeviceContext;
// The switch actor executes the branch selected by the input condition.
// Switch and SwitchLayer nodes are converted to switch actors.
// The execution process is divided into:
// 1. Put the input into the input vector.
// 2. Check whether the input condition has been received.
// 3. Check whether all inputs of the branch corresponding to the index have been received.
// 4. Send the data to the corresponding branch.
// 5. Free the memory.
class SwitchActor : public SwitchActorBase<DeviceTensor> {
public:
SwitchActor(const std::string &name, const CNodePtr &node) : SwitchActorBase(name), node_(node) {}
~SwitchActor() override = default;
// The switch actor runs when it receives the input data.
void RunOpData(OpDataPtr<DeviceTensor> input_data, OpContext<DeviceTensor> *context);
// Initialize the input and output information of the switch actor according to node_.
void Initialize();
// Add input for all branches.
void AddCommonInput(const AnfNodePtr &node);
private:
friend class GraphScheduler;
void InitPartial(const AnfNodePtr &node, const size_t branch_id);
void InitSwitch();
void InitSwitchLayer();
// Get index from DeviceTensor.
size_t GetIndex();
// Add input for the branch.
void AddInput(const AnfNodePtr &node, size_t branch);
// Check whether the conditions for sending the outputs are satisfied.
bool CheckLaunchCondition(OpContext<DeviceTensor> *context) const;
// Fetch the arguments of the switch branches.
void FetchInputDeviceTensor(OpContext<DeviceTensor> *context);
void SendOutput(OpContext<DeviceTensor> *context);
void FreeMemory(OpContext<DeviceTensor> *context);
// All inputs of the switch actor, excluding weights and constants (persistent device tensors).
// Used to receive input data; the first input is the condition of the switch.
std::vector<AnfNodePtr> input_nodes_;
// The positions of each branch's inputs in input_nodes_.
std::vector<std::vector<size_t>> branch_inputs_pos_;
// Pairs of <index, AnfNode> pointing to the dependent device tensor stores; the AnfNode is the key of the device tensor store.
std::vector<std::pair<size_t, void *>> branch_device_tensor_store_keys_;
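// All inputs of each branch, including persistent device tensors (weights and value nodes).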
std::vector<std::vector<AnfNodePtr>> branch_total_inputs_;
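// The func graph executed by each branch.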
std::vector<FuncGraphPtr> branch_func_graph_;
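// The device tensors fetched in the current step, in the order of input_nodes_.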
std::vector<DeviceTensor *> input_device_tensors_;
// The DeviceContext of input_nodes_, which is used to release the DeviceTensors.
DeviceContext *device_contexts_;
// The id of the memory manager actor. Messages are sent to it to allocate and free memory.
const AID memory_manager_aid_;
// The dependent input data number.
size_t input_datas_num_{0};
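// The Switch or SwitchLayer node to which this actor corresponds.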
CNodePtr node_;
};
using SwitchActorPtr = std::shared_ptr<SwitchActor>;
} // namespace runtime
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_SWITCH_ACTOR_H_
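
The input_nodes_ / branch_inputs_pos_ pairing in this header deduplicates inputs shared between branches: each branch stores only positions into one shared list, as SwitchActor::AddInput does. A standalone sketch of that bookkeeping (toy string nodes, hypothetical names):

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Each branch records positions into one shared input list, so an input used
// by several branches is received and stored only once.
struct ToySwitchInputs {
  std::vector<std::string> input_nodes;         // shared, in order of first use
  std::vector<std::vector<size_t>> branch_pos;  // per-branch positions into input_nodes

  explicit ToySwitchInputs(size_t branch_num) : branch_pos(branch_num) {}

  void AddInput(const std::string &node, size_t branch) {
    auto iter = std::find(input_nodes.begin(), input_nodes.end(), node);
    if (iter == input_nodes.end()) {
      branch_pos[branch].push_back(input_nodes.size());
      input_nodes.push_back(node);
    } else {
      branch_pos[branch].push_back(static_cast<size_t>(iter - input_nodes.begin()));
    }
  }
};

int main() {
  ToySwitchInputs inputs(2);
  inputs.AddInput("x", 0);
  inputs.AddInput("x", 1);  // shared with branch 0: reuses position 0
  inputs.AddInput("y", 1);
  assert(inputs.input_nodes.size() == 2);
  assert(inputs.branch_pos[1][0] == 0 && inputs.branch_pos[1][1] == 1);
  return 0;
}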

View File

@@ -28,47 +28,8 @@
namespace mindspore {
namespace runtime {
namespace {
bool IsDeviceQueueDSActor(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>() && (AnfAlgo::GetCNodeName(node) == kGetNextOpName)) {
return true;
}
return false;
}
bool IsHostQueueDSActor(const AnfNodePtr &node, const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(graph);
if (node->isa<Parameter>() && (!AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>()))) {
// Judge whether the node is an internal parameter.
if (graph->GetFrontNodeByInternalParameter(node) == nullptr) {
return true;
}
}
return false;
}
bool IsKernelActor(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>() && (AnfAlgo::GetCNodeName(node) != kGetNextOpName)) {
return true;
}
return false;
}
// Judge whether the device tensor of the node is persistent or not.
bool IsPersistentDeviceTensor(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<ValueNode>()) {
return true;
}
if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
return true;
}
return false;
}
KernelActor *FindKernelActor(const KernelMapActor &kernel_actors_map, const std::string &name) {
auto iter = kernel_actors_map.find(name);
if (iter != kernel_actors_map.end()) {
@@ -366,6 +327,13 @@ void GraphScheduler::Schedule(const ActorSet *actor_set) {
(void)actorMgr->Spawn(base_actor);
}
// Schedule switch actors.
for (auto &switch_actor : actor_set->switch_actors_) {
MS_EXCEPTION_IF_NULL(switch_actor);
auto base_actor = static_cast<ActorReference>(switch_actor);
(void)actorMgr->Spawn(base_actor);
}
// Schedule loop count actor.
if (actor_set->loop_count_actor_ != nullptr) {
auto base_actor = static_cast<ActorReference>(actor_set->loop_count_actor_);

View File

@@ -29,6 +29,7 @@
#include "runtime/framework/actor/loop_count_actor.h"
#include "runtime/framework/actor/kernel_actor.h"
#include "runtime/framework/actor/output_actor.h"
#include "runtime/framework/actor/switch_actor.h"
#include "runtime/hardware/device_context.h"
#include "backend/session/kernel_graph.h"
@@ -87,6 +88,7 @@ struct ActorSet {
std::vector<KernelActorPtr> kernel_actors_;
// Kernel actors with no input need to be triggered specifically.
std::vector<KernelActorPtr> no_input_kernel_actors_;
std::vector<SwitchActorPtr> switch_actors_;
LoopCountActorPtr loop_count_actor_{nullptr};
OutputActorPtr output_actor_{nullptr};
ActorInfo name_;

View File

@@ -28,7 +28,7 @@ namespace mindspore {
template <typename T>
class SwitchActorBase : public OpActor<T> {
public:
explicit SwitchActorBase(std::string op_name) : ActorBase(op_name) {}
explicit SwitchActorBase(std::string op_name) : OpActor<T>(op_name) {}
virtual ~SwitchActorBase() = default;
// The actor runs when it receives the input data.