commit
7c393c0375
|
@ -19,6 +19,7 @@
|
|||
#ifdef __WIN32__
|
||||
#include <windows.h>
|
||||
#endif
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
|
@ -34,5 +35,44 @@ int64_t GetMaxThreadNum() {
|
|||
return max_thread_num;
|
||||
}
|
||||
|
||||
bool IsDeviceQueueDSActor(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->isa<CNode>() && (AnfAlgo::GetCNodeName(node) == kGetNextOpName)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsHostQueueDSActor(const AnfNodePtr &node, const KernelGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
if (node->isa<Parameter>() && (!AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>()))) {
|
||||
// Judge whether node is internal parameter.
|
||||
if (graph->GetFrontNodeByInternalParameter(node) == nullptr) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsKernelActor(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->isa<CNode>() && (AnfAlgo::GetCNodeName(node) != kGetNextOpName)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsPersistentDeviceTensor(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->isa<ValueNode>()) {
|
||||
return true;
|
||||
}
|
||||
if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <utility>
|
||||
#include "mindrt/include/actor/op_actor.h"
|
||||
#include "backend/session/kernel_graph.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -43,6 +44,12 @@ constexpr int kFailure = 1;
|
|||
// Get the max available thread number of system.
|
||||
int64_t GetMaxThreadNum();
|
||||
|
||||
bool IsDeviceQueueDSActor(const AnfNodePtr &node);
|
||||
bool IsHostQueueDSActor(const AnfNodePtr &node, const KernelGraphPtr &graph);
|
||||
bool IsKernelActor(const AnfNodePtr &node);
|
||||
|
||||
// Judge whether the device tensor of the node is persistent or not.
|
||||
bool IsPersistentDeviceTensor(const AnfNodePtr &node);
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -20,18 +20,6 @@
|
|||
namespace mindspore {
|
||||
namespace runtime {
|
||||
namespace {
|
||||
// Judge whether the device tensor of the node is persistent or not.
|
||||
bool IsPersistentDeviceTensor(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->isa<ValueNode>()) {
|
||||
return true;
|
||||
}
|
||||
if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
TensorPtr CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index, size_t output_position) {
|
||||
MS_EXCEPTION_IF_NULL(output_node);
|
||||
MS_LOG(INFO) << "Create output tensor, output node: " << output_node->fullname_with_scope()
|
||||
|
|
|
@ -0,0 +1,236 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "runtime/framework/actor/switch_actor.h"
|
||||
#include "runtime/framework/actor/memory_manager_actor.h"
|
||||
#include "mindrt/include/async/async.h"
|
||||
#include "abstract/utils.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
constexpr size_t kSwitchInputNum = 4;
|
||||
constexpr size_t kSwitchCondPos = 1;
|
||||
constexpr size_t kSwitchPartialNum = 2;
|
||||
constexpr size_t kSwitchLayerCondPos = 1;
|
||||
constexpr size_t kSwitchLayerBranchPos = 2;
|
||||
constexpr size_t kSwitchLayerInputNum = 3;
|
||||
constexpr size_t kMaxSwitchCondSize = 8;
|
||||
constexpr size_t kSwitchTrueBranchPos = 2;
|
||||
constexpr size_t kSwitchFalseBranchPos = 3;
|
||||
constexpr size_t kPartialFuncGraphPos = 1;
|
||||
constexpr size_t kPartialInputStartPos = 2;
|
||||
|
||||
void SwitchActor::RunOpData(OpDataPtr<DeviceTensor> input_data, OpContext<DeviceTensor> *context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
auto sequential_num = context->sequential_num_;
|
||||
input_op_datas_[sequential_num].emplace_back(input_data);
|
||||
|
||||
// When all the inputs are collected, then allocate memory and callback launch.
|
||||
if (CheckLaunchCondition(context)) {
|
||||
FetchInputDeviceTensor(context);
|
||||
SendOutput(context);
|
||||
}
|
||||
}
|
||||
|
||||
void SwitchActor::Initialize() {
|
||||
std::vector<AnfNodePtr> inputs = node_->inputs();
|
||||
|
||||
if (IsPrimitive(inputs[0], prim::kPrimSwitch)) {
|
||||
InitSwitch();
|
||||
} else {
|
||||
InitSwitchLayer();
|
||||
}
|
||||
input_datas_num_ = input_nodes_.size();
|
||||
}
|
||||
|
||||
void SwitchActor::InitPartial(const AnfNodePtr &node, const size_t branch_id) {
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
|
||||
CNodePtr cnode = node->cast<CNodePtr>();
|
||||
|
||||
// The inputs of the Partial node is:
|
||||
// [0] ValueNode<Primitive> kPartial.
|
||||
// [1] ValueNode<FuncGraphPtr>.
|
||||
// [2..] Inputs.
|
||||
auto node_inputs = cnode->inputs();
|
||||
branch_func_graph_[branch_id] = GetValueNode<FuncGraphPtr>(node_inputs[kPartialFuncGraphPos]);
|
||||
for (size_t j = kPartialInputStartPos; j < node_inputs.size(); ++j) {
|
||||
AddInput(node_inputs[j], branch_id);
|
||||
}
|
||||
} else {
|
||||
AddInput(node, branch_id);
|
||||
}
|
||||
}
|
||||
|
||||
void SwitchActor::InitSwitch() {
|
||||
// The inputs of the switch node:
|
||||
// [0] ValueNode<Primitive> kSwitch.
|
||||
// [1] switch condition.
|
||||
// [2] Partial node: true branch.
|
||||
// [3] Partial node: false branch.
|
||||
std::vector<AnfNodePtr> inputs = node_->inputs();
|
||||
if (inputs.size() != kSwitchInputNum) {
|
||||
MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitch->name() << " is not equal 4";
|
||||
}
|
||||
|
||||
branch_total_inputs_.resize(kSwitchPartialNum);
|
||||
branch_inputs_pos_.resize(kSwitchPartialNum);
|
||||
branch_func_graph_.resize(kSwitchPartialNum);
|
||||
output_branch_arrows_.resize(kSwitchPartialNum);
|
||||
input_nodes_.push_back(inputs[kSwitchCondPos]);
|
||||
|
||||
// Init the two branches of switch node.
|
||||
InitPartial(inputs[kSwitchFalseBranchPos], static_cast<size_t>(false));
|
||||
InitPartial(inputs[kSwitchTrueBranchPos], static_cast<size_t>(true));
|
||||
}
|
||||
|
||||
void SwitchActor::InitSwitchLayer() {
|
||||
// The inputs of the switch node:
|
||||
// [0] ValueNode<Primitive> kSwitchLayer.
|
||||
// [1] switchLayer index.
|
||||
// [2] MakeTuple node: tuple of branches.
|
||||
std::vector<AnfNodePtr> inputs = node_->inputs();
|
||||
if (inputs.size() != kSwitchLayerInputNum) {
|
||||
MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitchLayer->name() << " is not equal 3";
|
||||
}
|
||||
|
||||
input_nodes_.push_back(inputs[kSwitchLayerCondPos]);
|
||||
|
||||
// The second input of SwitchLayer is maketuple node, which includes all branches.
|
||||
auto branch_nodes = inputs[kSwitchLayerBranchPos]->cast<CNodePtr>()->inputs();
|
||||
branch_total_inputs_.resize(branch_nodes.size() - 1);
|
||||
branch_inputs_pos_.resize(branch_nodes.size() - 1);
|
||||
branch_device_tensor_store_keys_.resize(branch_nodes.size() - 1);
|
||||
branch_func_graph_.resize(branch_nodes.size() - 1);
|
||||
output_branch_arrows_.resize(branch_nodes.size() - 1);
|
||||
|
||||
// Parse all branches.
|
||||
for (size_t i = 1; i < branch_nodes.size(); ++i) {
|
||||
if (AnfAlgo::CheckPrimitiveType(branch_nodes[i], prim::kPrimPartial)) {
|
||||
InitPartial(branch_nodes[i], i - 1);
|
||||
} else if (branch_nodes[i]->isa<ValueNode>()) {
|
||||
branch_func_graph_[i - 1] = GetValueNode<FuncGraphPtr>(branch_nodes[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SwitchActor::AddCommonInput(const AnfNodePtr &node) {
|
||||
for (size_t i = 0; i < branch_inputs_pos_.size(); ++i) {
|
||||
AddInput(node, i);
|
||||
}
|
||||
}
|
||||
|
||||
void SwitchActor::AddInput(const AnfNodePtr &node, const size_t branch) {
|
||||
branch_total_inputs_[branch].push_back(node);
|
||||
if (IsPersistentDeviceTensor(node)) {
|
||||
return;
|
||||
}
|
||||
auto iter = find(input_nodes_.begin(), input_nodes_.end(), node);
|
||||
if (iter == input_nodes_.end()) {
|
||||
branch_inputs_pos_[branch].push_back(input_nodes_.size());
|
||||
input_nodes_.push_back(node);
|
||||
} else {
|
||||
branch_inputs_pos_[branch].push_back(iter - input_nodes_.begin());
|
||||
}
|
||||
}
|
||||
|
||||
size_t SwitchActor::GetIndex() {
|
||||
DeviceTensor *device_tensor = input_device_tensors_[0];
|
||||
auto inputs = node_->inputs();
|
||||
TypeId type_id = AnfAlgo::GetOutputInferDataType(inputs[kSwitchCondPos], 0);
|
||||
size_t size = abstract::TypeIdSize(type_id);
|
||||
if (size > sizeof(int64_t)) {
|
||||
MS_LOG(EXCEPTION) << "Index must be Int type.";
|
||||
}
|
||||
|
||||
int64_t index = 0;
|
||||
char buf[kMaxSwitchCondSize] = {0};
|
||||
ShapeVector host_shape;
|
||||
device_tensor->SyncDeviceToHost(host_shape, size, type_id, static_cast<void *>(buf));
|
||||
|
||||
if (type_id == TypeId::kNumberTypeInt32) {
|
||||
index = static_cast<int64_t>((static_cast<int32_t *>(static_cast<void *>(buf)))[0]);
|
||||
} else if (type_id == TypeId::kNumberTypeInt64) {
|
||||
index = (static_cast<int64_t *>(static_cast<void *>(buf)))[0];
|
||||
} else if (type_id == TypeId::kNumberTypeBool) {
|
||||
bool cond = (static_cast<bool *>(static_cast<void *>(buf)))[0];
|
||||
index = static_cast<int64_t>(cond ? 1 : 0);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Index must be Int type.";
|
||||
}
|
||||
|
||||
// SwitchLayer node support negative index range [-size, -1].
|
||||
if (index < 0) {
|
||||
index += branch_func_graph_.size();
|
||||
}
|
||||
if (index > static_cast<int64_t>(SIZE_MAX)) {
|
||||
MS_LOG(EXCEPTION) << "Index is too large:" << index;
|
||||
}
|
||||
return static_cast<size_t>(index);
|
||||
}
|
||||
|
||||
bool SwitchActor::CheckLaunchCondition(OpContext<DeviceTensor> *context) const {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
if (input_datas_num_ != 0) {
|
||||
auto data_iter = input_op_datas_.find(context->sequential_num_);
|
||||
if (data_iter == input_op_datas_.end()) {
|
||||
return false;
|
||||
}
|
||||
if (data_iter->second.size() != input_datas_num_) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void SwitchActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
auto input_size = input_datas_num_ + branch_device_tensor_store_keys_.size();
|
||||
input_device_tensors_.resize(input_size);
|
||||
auto data_iter = input_op_datas_.find(context->sequential_num_);
|
||||
if (data_iter != input_op_datas_.end()) {
|
||||
for (auto &input_data : data_iter->second) {
|
||||
MS_EXCEPTION_IF_NULL(input_data);
|
||||
input_device_tensors_[input_data->index_] = input_data->data_;
|
||||
}
|
||||
}
|
||||
data_iter->second.clear();
|
||||
|
||||
for (auto &device_tensor_store_key : branch_device_tensor_store_keys_) {
|
||||
auto device_tensor = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second);
|
||||
input_device_tensors_[device_tensor_store_key.first] = device_tensor.get();
|
||||
}
|
||||
}
|
||||
|
||||
void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
auto index = GetIndex();
|
||||
if (index >= output_branch_arrows_.size()) {
|
||||
MS_LOG(EXCEPTION) << "Switch actor invalid index:" << index;
|
||||
}
|
||||
for (const auto &arrow : output_branch_arrows_[index]) {
|
||||
auto device_address = input_device_tensors_[arrow->from_output_index_];
|
||||
auto data = std::make_shared<OpData<DeviceTensor>>(arrow->to_op_id_, device_address, arrow->to_input_index_);
|
||||
Async(arrow->to_op_id_, &OpActor::RunOpData, data, context);
|
||||
}
|
||||
}
|
||||
|
||||
void SwitchActor::FreeMemory(OpContext<DeviceTensor> *context) {
|
||||
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, input_device_tensors_, device_contexts_, context);
|
||||
}
|
||||
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,99 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_SWITCH_ACTOR_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_SWITCH_ACTOR_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <unordered_map>
|
||||
#include "runtime/framework/actor/actor_common.h"
|
||||
#include "runtime/framework/device_tensor_store.h"
|
||||
#include "mindrt/include/actor/switch_actor.h"
|
||||
#include "runtime/hardware/device_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
using mindspore::device::DeviceContext;
|
||||
|
||||
// Switch actor is used to execute the branch according to the input condition.
|
||||
// Switch and SwitchLayer node will be converted to switch actor.
|
||||
// The execution process is divided into:
|
||||
// 1. Put input into the vector.
|
||||
// 2. Check whether the input condition has been received.
|
||||
// 3. Check whether all input from the branch corresponding to the index has been received.
|
||||
// 4. Send the data to the corresponding branch.
|
||||
// 5. Free Memory
|
||||
class SwitchActor : public SwitchActorBase<DeviceTensor> {
|
||||
public:
|
||||
SwitchActor(const std::string &name, const CNodePtr &node) : SwitchActorBase(name), node_(node) {}
|
||||
~SwitchActor() override = default;
|
||||
|
||||
// The switch actor run when receive the input data.
|
||||
void RunOpData(OpDataPtr<DeviceTensor> input_data, OpContext<DeviceTensor> *context);
|
||||
// Initialize the input and output information of the switch actor According to node_.
|
||||
void Initialize();
|
||||
// Add input for all branches.
|
||||
void AddCommonInput(const AnfNodePtr &node);
|
||||
|
||||
private:
|
||||
friend class GraphScheduler;
|
||||
|
||||
void InitPartial(const AnfNodePtr &node, const size_t branch_id);
|
||||
void InitSwitch();
|
||||
void InitSwitchLayer();
|
||||
|
||||
// Get index from DeviceTensor.
|
||||
size_t GetIndex();
|
||||
// Add input for the branch.
|
||||
void AddInput(const AnfNodePtr &node, size_t branch);
|
||||
|
||||
// Check whether satisfy the condition for send outputs.
|
||||
bool CheckLaunchCondition(OpContext<DeviceTensor> *context) const;
|
||||
// Fetch the args of switch branch.
|
||||
void FetchInputDeviceTensor(OpContext<DeviceTensor> *context);
|
||||
void SendOutput(OpContext<DeviceTensor> *context);
|
||||
void FreeMemory(OpContext<DeviceTensor> *context);
|
||||
|
||||
// All inputs of the switch actor, excluding weight and tensor.
|
||||
// Used to receive input data, the first input is the condition of switch.
|
||||
std::vector<AnfNodePtr> input_nodes_;
|
||||
// The position of the branch output in the input_nodes_.
|
||||
std::vector<std::vector<size_t>> branch_inputs_pos_;
|
||||
// Pair<index, anfNode> points to the dependent device tensor store, anfNode is the key of the device tensor store.
|
||||
std::vector<std::pair<size_t, void *>> branch_device_tensor_store_keys_;
|
||||
std::vector<std::vector<AnfNodePtr>> branch_total_inputs_;
|
||||
std::vector<FuncGraphPtr> branch_func_graph_;
|
||||
|
||||
std::vector<DeviceTensor *> input_device_tensors_;
|
||||
|
||||
// Save the DeviceContext of input_nodes_, which is used to release the DeviceTensor.
|
||||
DeviceContext *device_contexts_;
|
||||
|
||||
// The id of memory manager actor. Send message to it for alloc and free memory.
|
||||
const AID memory_manager_aid_;
|
||||
// The dependent input data number.
|
||||
size_t input_datas_num_{0};
|
||||
CNodePtr node_;
|
||||
};
|
||||
|
||||
using SwitchActorPtr = std::shared_ptr<SwitchActor>;
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_SWITCH_ACTOR_H_
|
|
@ -28,47 +28,8 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
|
||||
namespace {
|
||||
bool IsDeviceQueueDSActor(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->isa<CNode>() && (AnfAlgo::GetCNodeName(node) == kGetNextOpName)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsHostQueueDSActor(const AnfNodePtr &node, const KernelGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
if (node->isa<Parameter>() && (!AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>()))) {
|
||||
// Judge whether node is internal parameter.
|
||||
if (graph->GetFrontNodeByInternalParameter(node) == nullptr) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsKernelActor(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->isa<CNode>() && (AnfAlgo::GetCNodeName(node) != kGetNextOpName)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Judge whether the device tensor of the node is persistent or not.
|
||||
bool IsPersistentDeviceTensor(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->isa<ValueNode>()) {
|
||||
return true;
|
||||
}
|
||||
if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
KernelActor *FindKernelActor(const KernelMapActor &kernel_actors_map, const std::string &name) {
|
||||
auto iter = kernel_actors_map.find(name);
|
||||
if (iter != kernel_actors_map.end()) {
|
||||
|
@ -366,6 +327,13 @@ void GraphScheduler::Schedule(const ActorSet *actor_set) {
|
|||
(void)actorMgr->Spawn(base_actor);
|
||||
}
|
||||
|
||||
// Schedule switch actors.
|
||||
for (auto &switch_actor : actor_set->switch_actors_) {
|
||||
MS_EXCEPTION_IF_NULL(switch_actor);
|
||||
auto base_actor = static_cast<ActorReference>(switch_actor);
|
||||
(void)actorMgr->Spawn(base_actor);
|
||||
}
|
||||
|
||||
// Schedule loop count actor.
|
||||
if (actor_set->loop_count_actor_ != nullptr) {
|
||||
auto base_actor = static_cast<ActorReference>(actor_set->loop_count_actor_);
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
#include "runtime/framework/actor/loop_count_actor.h"
|
||||
#include "runtime/framework/actor/kernel_actor.h"
|
||||
#include "runtime/framework/actor/output_actor.h"
|
||||
#include "runtime/framework/actor/switch_actor.h"
|
||||
#include "runtime/hardware/device_context.h"
|
||||
#include "backend/session/kernel_graph.h"
|
||||
|
||||
|
@ -87,6 +88,7 @@ struct ActorSet {
|
|||
std::vector<KernelActorPtr> kernel_actors_;
|
||||
// No input kernel actors need be triggered specifically.
|
||||
std::vector<KernelActorPtr> no_input_kernel_actors_;
|
||||
std::vector<SwitchActorPtr> switch_actors_;
|
||||
LoopCountActorPtr loop_count_actor_{nullptr};
|
||||
OutputActorPtr output_actor_{nullptr};
|
||||
ActorInfo name_;
|
||||
|
|
|
@ -28,7 +28,7 @@ namespace mindspore {
|
|||
template <typename T>
|
||||
class SwitchActorBase : public OpActor<T> {
|
||||
public:
|
||||
explicit SwitchActorBase(std::string op_name) : ActorBase(op_name) {}
|
||||
explicit SwitchActorBase(std::string op_name) : OpActor<T>(op_name) {}
|
||||
virtual ~SwitchActorBase() = default;
|
||||
|
||||
// The actor run when receive the input data.
|
||||
|
|
Loading…
Reference in New Issue