!23970 Add control flow actor.

Merge pull request !23970 from gaoyong10/runtime_second

Commit 3f81a92ed4
@@ -56,7 +56,13 @@ enum class KernelTransformType {
   kOutputActor,
   kDeviceTensorStore,
   // Internal parameter is the output of previous kernel graph which is related to the input of next kernel graph.
-  kInternalParameter
+  kInternalParameter,
+  // Control flow actor type.
+  kSwitchActor,
+  kGatherActor,
+  kEntranceActor,
+  kExitActor,
+  kStackActor
 };

 #define SET_OPCONTEXT_FAIL_RET_WITH_ERROR(op_context, message) \
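The five new enum entries let the scheduler tell control flow actors apart from ordinary kernel-side actors. A minimal, standalone sketch of how such a type tag might be queried; the enum mirror and the IsControlFlowActor helper below are illustrative, not part of the commit:

    #include <cstdio>

    // Standalone mirror of the extended enum, for illustration only; the real
    // definition lives alongside the existing entries in the runtime framework.
    enum class KernelTransformType {
      kOutputActor,
      kDeviceTensorStore,
      kInternalParameter,
      // Control flow actor types added by this commit.
      kSwitchActor,
      kGatherActor,
      kEntranceActor,
      kExitActor,
      kStackActor,
    };

    // Hypothetical helper: true for the five control flow actor kinds.
    bool IsControlFlowActor(KernelTransformType type) {
      switch (type) {
        case KernelTransformType::kSwitchActor:
        case KernelTransformType::kGatherActor:
        case KernelTransformType::kEntranceActor:
        case KernelTransformType::kExitActor:
        case KernelTransformType::kStackActor:
          return true;
        default:
          return false;
      }
    }

    int main() {
      std::printf("%d\n", IsControlFlowActor(KernelTransformType::kGatherActor));  // 1
      std::printf("%d\n", IsControlFlowActor(KernelTransformType::kOutputActor));  // 0
      return 0;
    }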
mindspore/ccsrc/runtime/framework/actor/control_flow/entrance_actor.h
@@ -0,0 +1,71 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_ENTRANCE_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_ENTRANCE_ACTOR_H_

#include <vector>
#include <string>
#include <memory>
#include <unordered_map>
#include <stack>
#include <queue>
#include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/actor/abstract_actor.h"

namespace mindspore {
namespace runtime {
// The entrance actor is used in control flow to receive a set of result arrows and a branch id, and then send
// the data to the corresponding actors. It is the entry point of subgraph execution.
class EntranceActor : public AbstractActor {
 public:
  EntranceActor(const std::string &name, const std::vector<KernelWithIndex> &parameters)
      : AbstractActor(name, KernelTransformType::kEntranceActor, nullptr), formal_parameters_(parameters) {
    device_contexts_.resize(parameters.size());
  }
  ~EntranceActor() override = default;

  void Init() override;

  // The entrance actor runs when it receives the input control.
  void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override;
  // The entrance actor runs when it receives the real parameter nodes and the branch id.
  void CollectRealParametersAndBranchId(const std::vector<KernelWithIndex> &real_parameters, int branch_id,
                                        OpContext<DeviceTensor> *const context);

 private:
  friend class GraphScheduler;

  void SendOutput(OpContext<DeviceTensor> *const context) const;

  // Formal parameters of the actor, which are the front nodes.
  std::vector<KernelWithIndex> formal_parameters_;

  // Input data.
  std::unordered_map<uuids::uuid *, std::queue<std::vector<KernelWithIndex>>> input_nodes_;
  std::unordered_map<uuids::uuid *, std::queue<int>> input_branch_ids_;

  std::vector<AID> output_branch_id_arrows_;
  // The output_data_ corresponds to the output_data_arrows_ one by one.
  std::vector<OpData<DeviceTensor> *> output_data_;
  bool is_actor_ready_{true};
};

using EntranceActorPtr = std::shared_ptr<EntranceActor>;
}  // namespace runtime
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_ENTRANCE_ACTOR_H_
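The entrance actor keys its pending inputs by the step's sequential_num_ and buffers them in queues, because with recursion the same subgraph entrance can receive several calls for one step before any of them is consumed. A standalone sketch of that bookkeeping, where StepId, RealParameter, and EntranceBuffer are invented stand-ins for uuids::uuid *, KernelWithIndex, and the actor's maps:

    #include <cstdio>
    #include <queue>
    #include <unordered_map>
    #include <utility>
    #include <vector>

    using StepId = int;                                // stands in for uuids::uuid *
    struct RealParameter { int node_id; int index; };  // stands in for KernelWithIndex

    class EntranceBuffer {
     public:
      void Push(StepId step, std::vector<RealParameter> params, int branch_id) {
        inputs_[step].push(std::move(params));
        branch_ids_[step].push(branch_id);
      }
      // Pops one complete call (parameters + branch id) for the step in FIFO order,
      // so recursive re-entries of the same step are served in arrival order.
      bool Pop(StepId step, std::vector<RealParameter> *params, int *branch_id) {
        auto it = inputs_.find(step);
        if (it == inputs_.end() || it->second.empty()) return false;
        *params = std::move(it->second.front());
        it->second.pop();
        *branch_id = branch_ids_[step].front();
        branch_ids_[step].pop();
        return true;
      }
     private:
      std::unordered_map<StepId, std::queue<std::vector<RealParameter>>> inputs_;
      std::unordered_map<StepId, std::queue<int>> branch_ids_;
    };

    int main() {
      EntranceBuffer buf;
      buf.Push(0, {{1, 0}}, 7);  // first (outer) call in step 0, branch id 7
      buf.Push(0, {{2, 0}}, 8);  // recursive call in the same step, branch id 8
      std::vector<RealParameter> p; int b = 0;
      while (buf.Pop(0, &p, &b)) std::printf("node %d, branch %d\n", p[0].node_id, b);
      return 0;
    }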
mindspore/ccsrc/runtime/framework/actor/control_flow/exit_actor.h
@@ -0,0 +1,71 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_EXIT_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_EXIT_ACTOR_H_

#include <vector>
#include <string>
#include <memory>
#include <unordered_map>
#include <stack>
#include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/actor/abstract_actor.h"

namespace mindspore {
namespace runtime {
// The exit actor is used in control flow to receive a set of result arrows and a branch id, and then send the
// nodes in the result to the corresponding actors. It is the exit point at the end of subgraph execution.
class ExitActor : public AbstractActor {
 public:
  ExitActor(const std::string &name, const std::vector<KernelWithIndex> &parameters)
      : AbstractActor(name, KernelTransformType::kExitActor, nullptr), formal_parameters_(parameters) {}
  ~ExitActor() override = default;

  // The exit actor runs when it receives the anf node.
  void CollectRealParameter(const AnfNodePtr &output_node, size_t output_index, size_t output_position,
                            OpContext<DeviceTensor> *const context);
  // The exit actor runs when it receives the input branch id.
  void CollectBranchId(int branch_id, OpContext<DeviceTensor> *const context);

 private:
  friend class GraphScheduler;

  void SendOutput(OpContext<DeviceTensor> *const context) const;

  // Formal parameters of the actor, which are the front nodes.
  std::vector<KernelWithIndex> formal_parameters_;

  // Input data.
  std::unordered_map<uuids::uuid *, std::unordered_map<size_t, KernelWithIndex>> input_nodes_;
  // Branch ids are used to record the id corresponding to the output branch.
  // In control flow, a sub funcgraph may be called in multiple places, and its output must be returned to
  // different places. Therefore, the output of each subgraph is connected to an exit actor, and the caller
  // sends its branch id to the entrance actor of the subgraph. The branch id is then sent by the entrance
  // actor to the exit actor connected to the output.
  // In a recursive scenario, the exit actor receives the branch ids sent by the callers in sequence; it needs
  // to store them in a stack and pop them in turn when returning.
  std::unordered_map<uuids::uuid *, std::stack<int>> input_branch_ids_;

  // Output arrows.
  std::unordered_map<int, std::vector<DataArrowPtr>> output_branch_result_arrows_;
};

using ExitActorPtr = std::shared_ptr<ExitActor>;
}  // namespace runtime
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_EXIT_ACTOR_H_
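The comment above explains why input_branch_ids_ holds a stack per step: under recursion the innermost call must return first, so branch ids pop in LIFO order. A minimal standalone model of that behavior (the map and int step key are invented for illustration):

    #include <cstdio>
    #include <stack>
    #include <unordered_map>

    // Minimal model of the exit actor's branch id bookkeeping: step -> branch ids.
    std::unordered_map<int, std::stack<int>> input_branch_ids;

    int main() {
      constexpr int step = 0;
      // The entrance actor forwards one branch id per call; a recursive call
      // pushes on top of its caller's id.
      input_branch_ids[step].push(3);  // outer caller
      input_branch_ids[step].push(5);  // recursive caller
      // On return, the innermost call exits first, so ids pop in reverse order.
      while (!input_branch_ids[step].empty()) {
        std::printf("return to branch %d\n", input_branch_ids[step].top());
        input_branch_ids[step].pop();
      }
      return 0;  // prints 5 then 3
    }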
mindspore/ccsrc/runtime/framework/actor/control_flow/gather_actor.cc
@@ -0,0 +1,21 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/framework/actor/control_flow/gather_actor.h"

namespace mindspore {
namespace runtime {}  // namespace runtime
}  // namespace mindspore
mindspore/ccsrc/runtime/framework/actor/control_flow/gather_actor.h
@@ -0,0 +1,68 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_GATHER_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_GATHER_ACTOR_H_

#include <vector>
#include <string>
#include <memory>
#include <unordered_map>
#include <stack>
#include <utility>
#include <algorithm>
#include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/actor/abstract_actor.h"

namespace mindspore {
namespace runtime {

// The gather actor is used in control flow. When a subgraph is called, the real parameters need to be put
// together and sent to the subgraph.
class GatherActor : public AbstractActor {
 public:
  GatherActor(const std::string &name, const std::vector<KernelWithIndex> &parameters)
      : AbstractActor(name, KernelTransformType::kGatherActor, nullptr), formal_parameters_(parameters) {}
  ~GatherActor() override = default;

  // The gather actor collects a single node when it receives the result of a kernel actor.
  void CollectRealParameter(const AnfNodePtr &node, size_t index, size_t position,
                            OpContext<DeviceTensor> *const context);
  // The gather actor collects all real parameters when it receives the output of a switch actor.
  void CollectRealParameters(const std::vector<KernelWithIndex> &real_parameters, size_t position,
                             OpContext<DeviceTensor> *const context);

 private:
  friend class GraphScheduler;
  void SendOutput(OpContext<DeviceTensor> *const context) const;

  // Formal parameters of the actor, which are the front nodes.
  std::vector<KernelWithIndex> formal_parameters_;

  // Input data.
  std::unordered_map<uuids::uuid *, std::unordered_map<size_t, std::vector<KernelWithIndex>>> input_nodes_;
  // The store nodes record the value node inputs of the gather actor.
  std::vector<std::pair<size_t, KernelWithIndex>> store_nodes_;

  // Output arrows.
  std::unordered_map<AnfNodePtr, std::pair<AID, size_t>> output_branch_arrows_;
};

using GatherActorPtr = std::shared_ptr<GatherActor>;
}  // namespace runtime
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_GATHER_ACTOR_H_
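The gather actor's input_nodes_ map collects real parameters position by position and should only fire once every formal parameter slot for the step is filled. A self-contained sketch of that collect-then-send pattern; GatherBuffer, the int step key, and the int payloads are invented for illustration:

    #include <cstdio>
    #include <unordered_map>
    #include <vector>

    // Standalone sketch of "collect by position, send when complete".
    class GatherBuffer {
     public:
      explicit GatherBuffer(size_t formal_parameter_num) : total_(formal_parameter_num) {}
      // Returns true when all positions for this step have arrived.
      bool Collect(int step, size_t position, int value) {
        auto &slots = pending_[step];
        slots[position] = value;
        return slots.size() == total_;
      }
      // Assembles the collected values in position order and erases the step.
      std::vector<int> Take(int step) {
        std::vector<int> out(total_);
        for (auto &kv : pending_[step]) out[kv.first] = kv.second;
        pending_.erase(step);
        return out;
      }
     private:
      size_t total_;
      std::unordered_map<int, std::unordered_map<size_t, int>> pending_;
    };

    int main() {
      GatherBuffer buf(2);
      buf.Collect(0, 1, 20);  // positions may arrive out of order
      if (buf.Collect(0, 0, 10)) {
        for (int v : buf.Take(0)) std::printf("%d ", v);  // 10 20
      }
      return 0;
    }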
mindspore/ccsrc/runtime/framework/actor/control_flow/stack_actor.h
@@ -0,0 +1,70 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_STACK_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_STACK_ACTOR_H_

#include <vector>
#include <string>
#include <memory>
#include <unordered_map>
#include <stack>
#include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/actor/abstract_actor.h"

namespace mindspore {
namespace runtime {
// The stack actor is used to record the device tensors that need additional storage in recursive scenarios.
class StackActor : public AbstractActor {
 public:
  StackActor(const std::string &name, const std::vector<KernelWithIndex> &parameters)
      : AbstractActor(name, KernelTransformType::kStackActor, nullptr), formal_parameters_(parameters) {
    device_contexts_.resize(parameters.size());
  }
  ~StackActor() override = default;

  void Init() override;

  // The stack actor runs when it receives the real parameter nodes.
  void CollectRealParameter(const AnfNodePtr &node, size_t index, size_t position,
                            OpContext<DeviceTensor> *const context);

 private:
  friend class GraphScheduler;
  void SendOutput(OpContext<DeviceTensor> *const context) const;

  // Formal parameters record the input front nodes; these nodes may be parameters, kernels or call nodes.
  std::vector<KernelWithIndex> formal_parameters_;

  // The backend parameters are used to save the backend nodes corresponding to the device tensors in the stack.
  // When these device tensors are used as outputs, they need to be placed in the nodes of the result arrows,
  // so these nodes need to be saved.
  std::vector<KernelWithIndex> backend_parameters_;

  // Input data.
  std::unordered_map<uuids::uuid *, std::unordered_map<size_t, KernelWithIndex>> input_nodes_;

  // The input data records the device tensors that the stack actor copies from the input nodes and stores in
  // the stack. These device tensors do not belong to any node, and they are cleaned up directly after the
  // stack is popped.
  std::unordered_map<uuids::uuid *, std::unordered_map<size_t, std::stack<DeviceTensor *>>> input_data_;
};

using StackActorPtr = std::shared_ptr<StackActor>;
}  // namespace runtime
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_STACK_ACTOR_H_
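The stack actor's input_data_ keeps one stack per input position, so each recursion depth shadows the value beneath it, and the copied device tensors can simply be dropped when the stack pops. A toy model of those per-position stacks (the map, size_t position key, and int payloads are invented):

    #include <cstdio>
    #include <stack>
    #include <unordered_map>

    // Minimal model of the stack actor's storage: position -> stack of values.
    std::unordered_map<size_t, std::stack<int>> input_data;

    int main() {
      input_data[0].push(100);  // value for the outer call
      input_data[0].push(200);  // value pushed by the recursive call
      // The current recursion depth always sees the top of its stack.
      std::printf("current: %d\n", input_data[0].top());    // 200
      input_data[0].pop();      // popped when the inner call returns
      std::printf("after pop: %d\n", input_data[0].top());  // 100
      return 0;
    }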
mindspore/ccsrc/runtime/framework/actor/control_flow/switch_actor.cc
@@ -0,0 +1,102 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/framework/actor/control_flow/switch_actor.h"
#include "runtime/framework/actor/control_flow/gather_actor.h"
#include "runtime/framework/actor/output_actor.h"
#include "runtime/framework/actor/memory_manager_actor.h"
#include "mindrt/include/async/async.h"
#include "abstract/utils.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace runtime {
void SwitchActor::Init() {
  // Init output data.
  output_data_.resize(output_branch_data_arrows_.size());
  for (size_t i = 0; i < output_branch_data_arrows_.size(); ++i) {
    auto &output_branch_data_arrows = output_branch_data_arrows_[i];
    auto &output_data = output_data_[i];
    for (auto &data_arrow : output_branch_data_arrows) {
      MS_EXCEPTION_IF_NULL(data_arrow);
      auto data = std::make_unique<OpData<DeviceTensor>>(data_arrow->to_op_id_, nullptr, data_arrow->to_input_index_);
      (void)output_data.emplace_back(std::move(data));
    }
  }
}

size_t SwitchActor::GetIndex(const OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);

  const auto &nodes_iter = input_nodes_.find(context->sequential_num_);
  if (nodes_iter == input_nodes_.end()) {
    MS_LOG(ERROR) << "Cannot find input node for switch actor:" << GetAID();
    return 0;
  }
  const auto &index_iter = nodes_iter->second.find(0);
  if (index_iter == nodes_iter->second.end() || index_iter->second.empty()) {
    MS_LOG(ERROR) << "Cannot find index input node for switch actor:" << GetAID();
    return 0;
  }

  const auto &index_node_with_index = index_iter->second[0];
  const auto &index_node = index_node_with_index.first;
  MS_EXCEPTION_IF_NULL(index_node);
  MS_EXCEPTION_IF_NULL(index_node->kernel_info());
  if (!AnfAlgo::OutputAddrExist(index_node, index_node_with_index.second, false)) {
    MS_LOG(ERROR) << "Invalid output index:" << index_node_with_index.second
                  << " for node:" << index_node->DebugString();
    return 0;
  }

  DeviceTensor *device_tensor = AnfAlgo::GetMutableOutputAddr(index_node, index_node_with_index.second, false).get();
  MS_EXCEPTION_IF_NULL(device_tensor);
  TypeId type_id = AnfAlgo::GetOutputInferDataType(index_node, index_node_with_index.second);
  size_t size = abstract::TypeIdSize(type_id);
  if (size > sizeof(int64_t)) {
    MS_LOG(ERROR) << "Index must be Int type.";
    return 0;
  }

  int64_t index = 0;
  char buf[kMaxSwitchCondSize] = {0};
  ShapeVector host_shape;
  if (!device_tensor->SyncDeviceToHost(host_shape, size, type_id, static_cast<void *>(buf))) {
    MS_LOG(ERROR) << GetAID().Name() << " get index from device address failed, type id:" << std::to_string(type_id)
                  << ", device type:" << std::to_string(static_cast<int>(device_contexts_[0]->GetDeviceAddressType()));
    return 0;
  }

  if (type_id == TypeId::kNumberTypeInt32) {
    index = static_cast<int64_t>((static_cast<int32_t *>(static_cast<void *>(buf)))[0]);
  } else if (type_id == TypeId::kNumberTypeInt64) {
    index = (static_cast<int64_t *>(static_cast<void *>(buf)))[0];
  } else if (type_id == TypeId::kNumberTypeBool) {
    bool cond = (static_cast<bool *>(static_cast<void *>(buf)))[0];
    index = static_cast<int64_t>(cond ? 1 : 0);
  } else {
    MS_LOG(ERROR) << "Index must be Int type.";
    return 0;
  }

  // SwitchLayer node support negative index range [-size, -1].
  if (index < 0) {
    index += SizeToInt(input_result_num_ - 1);
  }
  return static_cast<size_t>(index);
}
}  // namespace runtime
}  // namespace mindspore
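GetIndex above copies the condition into a small host buffer and interprets it by type id: int32/int64 values are read directly, a bool maps to branch 0 or 1, and a negative SwitchLayer index wraps around by the branch count. A standalone sketch of just that decoding step; TypeTag and DecodeIndex are invented names, and the real code wraps by input_result_num_ - 1 rather than a plain branch count:

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    enum class TypeTag { kInt32, kInt64, kBool };

    size_t DecodeIndex(const void *buf, TypeTag tag, size_t branch_num) {
      int64_t index = 0;
      if (tag == TypeTag::kInt32) {
        int32_t v; std::memcpy(&v, buf, sizeof(v)); index = v;
      } else if (tag == TypeTag::kInt64) {
        std::memcpy(&index, buf, sizeof(index));
      } else {  // bool condition: false -> branch 0, true -> branch 1
        bool v; std::memcpy(&v, buf, sizeof(v)); index = v ? 1 : 0;
      }
      // SwitchLayer supports negative indices in [-branch_num, -1].
      if (index < 0) index += static_cast<int64_t>(branch_num);
      return static_cast<size_t>(index);
    }

    int main() {
      int32_t raw = -1;
      std::printf("%zu\n", DecodeIndex(&raw, TypeTag::kInt32, 3));  // 2
      bool cond = true;
      std::printf("%zu\n", DecodeIndex(&cond, TypeTag::kBool, 2));  // 1
      return 0;
    }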
mindspore/ccsrc/runtime/framework/actor/control_flow/switch_actor.h
@@ -0,0 +1,80 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_SWITCH_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_SWITCH_ACTOR_H_

#include <vector>
#include <string>
#include <unordered_map>
#include <memory>
#include <utility>
#include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/actor/abstract_actor.h"

namespace mindspore {
namespace runtime {
using mindspore::device::DeviceContext;
using mindspore::session::KernelWithIndex;

constexpr size_t kSwitchCondPos = 1;
constexpr size_t kMaxSwitchCondSize = 8;

// The switch actor is used to execute a branch according to the input condition.
// Switch and SwitchLayer nodes will be converted to switch actors.
class SwitchActor : public AbstractActor {
 public:
  SwitchActor(const std::string &name, const std::vector<KernelWithIndex> &parameters)
      : AbstractActor(name, KernelTransformType::kSwitchActor, nullptr), formal_parameters_(parameters) {
    input_result_num_ = formal_parameters_.size();
  }
  ~SwitchActor() override = default;

  void Init() override;

  // The switch actor collects a single node when it receives the result of a kernel actor.
  void CollectRealParameter(const AnfNodePtr &node, size_t index, size_t position,
                            OpContext<DeviceTensor> *const context);
  // The switch actor collects all real parameters when it receives the output of a gather actor.
  void CollectRealParameters(const std::vector<KernelWithIndex> &real_parameters, size_t position,
                             OpContext<DeviceTensor> *const context);

 private:
  friend class GraphScheduler;
  size_t GetIndex(const OpContext<DeviceTensor> *const context);

  // Formal parameters of the actor, which are the front nodes.
  std::vector<KernelWithIndex> formal_parameters_;

  // Input data.
  std::unordered_map<uuids::uuid *, std::unordered_map<size_t, std::vector<KernelWithIndex>>> input_nodes_;
  // The store nodes record the value node inputs of the switch actor.
  std::vector<std::pair<size_t, AnfNodePtr>> store_nodes_;

  // Output arrows.
  std::vector<std::vector<DataArrowPtr>> output_branch_data_arrows_;
  std::vector<std::vector<DataArrowPtr>> output_branch_result_arrows_;
  std::vector<AID> output_branch_real_parameter_arrows_;
  // The output_data_ corresponds to the output_data_arrows_ one by one.
  std::vector<std::vector<OpDataUniquePtr<DeviceTensor>>> output_data_;
  size_t input_result_num_;
};

using SwitchActorPtr = std::shared_ptr<SwitchActor>;
}  // namespace runtime
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_SWITCH_ACTOR_H_
mindspore/ccsrc/runtime/framework/actor/gather_actor.cc
@@ -1,250 +0,0 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/framework/actor/gather_actor.h"
#include "runtime/framework/actor/output_actor.h"
#include "runtime/framework/actor/switch_actor.h"
#include "runtime/framework/actor/memory_manager_actor.h"
#include "runtime/framework/actor/loop_count_actor.h"
#include "mindrt/include/async/async.h"
#include "abstract/utils.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace runtime {
void GatherActor::Init() {
  input_datas_num_ = data_nodes_.size();
  input_device_tensors_.resize(input_datas_num_);
  output_data_by_output_index_.resize(input_datas_num_);

  for (auto &data_arrow : output_data_arrows_) {
    MS_EXCEPTION_IF_NULL(data_arrow);
    if (IntToSize(data_arrow->from_output_index_) >= input_datas_num_) {
      MS_LOG(EXCEPTION) << "The output index is out of range: " << GetAID().Name();
    }

    auto data = std::make_unique<OpData<DeviceTensor>>(data_arrow->to_op_id_, nullptr, data_arrow->to_input_index_);
    (void)output_data_.emplace_back(data.get());
    (void)output_data_by_output_index_[IntToSize(data_arrow->from_output_index_)].emplace_back(std::move(data));
  }
}

size_t GatherActor::FetchDataNodePosition(const KernelWithIndex &data_node) const {
  const auto &iter = find(data_nodes_.begin(), data_nodes_.end(), data_node);
  if (iter == data_nodes_.end()) {
    MS_LOG(EXCEPTION) << "Data node: " << AnfAlgo::GetNodeDebugString(data_node.first) << " index:" << data_node.second
                      << " is not exist in gather actor:" << GetAID();
  }
  return iter - data_nodes_.begin();
}

void GatherActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) {
  MS_EXCEPTION_IF_NULL(context);
  auto sequential_num = context->sequential_num_;
  input_data_[sequential_num][input_data->index_].push(input_data->data_);

  if (CheckLaunchCondition(context)) {
    FetchInputDeviceTensor(context);
    EraseInput(context);
    SendOutput(context);
  }
}

void GatherActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) {
  MS_EXCEPTION_IF_NULL(context);
  auto &sequential_num = context->sequential_num_;
  (void)input_op_controls_[sequential_num].emplace_back(input_control);

  if (CheckLaunchCondition(context)) {
    FetchInputDeviceTensor(context);
    EraseInput(context);
    SendOutput(context);
  }
}

void GatherActor::CollectBranchId(const int branch_id, OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  auto &sequential_num = context->sequential_num_;
  input_branch_ids_[sequential_num] = branch_id;
  if (CheckLaunchCondition(context)) {
    FetchInputDeviceTensor(context);
    EraseInput(context);
    SendOutput(context);
  }
}

void GatherActor::FetchBackendInputNode(const FuncGraphPtr &func_graph, const ControlNodeParserPtr &parser) {
  for (const auto &input : func_graph->get_inputs()) {
    // Monad input would not send to gather actor.
    if (HasAbstractMonad(input) ||
        (input->isa<Parameter>() && AnfAlgo::IsParameterWeight(input->cast<ParameterPtr>()))) {
      continue;
    }
    front_to_backend_parameter_[input] = parser->GetBackendInputByParameter(input);
  }
}

void GatherActor::SendOutput(OpContext<DeviceTensor> *const context) const {
  MS_EXCEPTION_IF_NULL(context);
  // Must be the execution order: send branch id --> send result --> send data --> send control, avoid the illegal
  // timing problem.
  // 1.Send output branch id.
  if (find(output_branch_arrows_.begin(), output_branch_arrows_.end(), switch_aid_) != output_branch_arrows_.end()) {
    int branch_id = input_branch_id_;
    Async(switch_aid_, &SwitchActor::CollectBranchId, branch_id, context);
  }
  if (find(output_branch_arrows_.begin(), output_branch_arrows_.end(), gather_aid_) != output_branch_arrows_.end()) {
    Async(gather_aid_, &GatherActor::CollectBranchId, local_branch_id_, context);
  }

  // 2.Send output result.
  for (const auto &result_arrow : output_result_arrows_) {
    MS_EXCEPTION_IF_NULL(result_arrow);
    size_t from_index = IntToSize(result_arrow->from_output_index_);
    const auto &front_node = data_nodes_[from_index].first;
    for (const auto &backend_node : front_to_backend_parameter_.at(front_node)) {
      if (AnfAlgo::GetMutableOutputAddr(backend_node.first, backend_node.second, false).get() ==
          input_device_tensors_[from_index]) {
        Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, backend_node.first, backend_node.second,
              result_arrow->to_input_index_, context);
        break;
      }
    }
  }

  // 3.Send output data.
  for (auto &output_data : output_data_) {
    MS_EXCEPTION_IF_NULL(output_data);
    Async(output_data->op_id_, &OpActor::RunOpData, output_data, context);
  }

  // 4.Send output control.
  auto source_aid = const_cast<AID *>(&GetAID());
  for (auto &output_control : output_control_arrows_) {
    Async(output_control, &OpActor::RunOpControl, source_aid, context);
  }
}

void GatherActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  auto data_iter = input_data_.find(context->sequential_num_);
  if (data_iter != input_data_.end()) {
    for (auto &input_data : data_iter->second) {
      input_device_tensors_[input_data.first] = input_data.second.top();
      input_data.second.pop();
    }
  }

  for (const auto &device_tensor_store_key : device_tensor_store_keys_) {
    const auto &device_context = device_contexts_[device_tensor_store_key.first];
    MS_EXCEPTION_IF_NULL(device_context);
    auto device_tensor =
      DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second, device_context->GetDeviceAddressType());
    if (device_tensor == nullptr) {
      std::string error_info =
        GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->fullname_with_scope() +
        ", device type:" + std::to_string(static_cast<int>(device_context->GetDeviceAddressType()));
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    input_device_tensors_[device_tensor_store_key.first] = device_tensor;
  }

  for (size_t i = 0; i < output_data_by_output_index_.size(); ++i) {
    const auto &data = input_device_tensors_[i];
    for (auto &output_data : output_data_by_output_index_[i]) {
      MS_EXCEPTION_IF_NULL(output_data);
      output_data->data_ = data;
    }
  }

  if (need_branch_id_input_) {
    input_branch_id_ = input_branch_ids_[context->sequential_num_];
  }
}

bool GatherActor::CheckLaunchCondition(OpContext<DeviceTensor> *const context) const {
  MS_EXCEPTION_IF_NULL(context);

  // Fetch input data.
  if (input_datas_num_ != 0) {
    auto data_iter = input_data_.find(context->sequential_num_);
    if (data_iter == input_data_.end()) {
      return false;
    }
    if (data_iter->second.size() + device_tensor_store_keys_.size() != input_datas_num_) {
      return false;
    }
    if (std::any_of(data_iter->second.begin(), data_iter->second.end(),
                    [](const auto &input_stack) { return input_stack.second.empty(); })) {
      return false;
    }
  }

  // Fetch input control.
  if (input_controls_num_ != 0) {
    auto control_iter = input_op_controls_.find(context->sequential_num_);
    if (control_iter == input_op_controls_.end()) {
      return false;
    }
    if (control_iter->second.size() != input_controls_num_) {
      return false;
    }
  }

  // Fetch input branch id.
  if (need_branch_id_input_) {
    auto branch_id_iter = input_branch_ids_.find(context->sequential_num_);
    if (branch_id_iter == input_branch_ids_.end()) {
      return false;
    }
  }
  return true;
}

void GatherActor::EraseInput(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);

  // Erase input data.
  auto data_iter = input_data_.find(context->sequential_num_);
  if (data_iter != input_data_.end() &&
      std::all_of(data_iter->second.begin(), data_iter->second.end(),
                  [](const auto &input_data) { return input_data.second.empty(); })) {
    auto ret = input_data_.erase(context->sequential_num_);
    if (ret == 0) {
      std::string error_info = "Erase input data failed: " + GetAID().Name();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
  }

  // Erase input control.
  if (input_controls_num_ != 0) {
    auto ret = input_op_controls_.erase(context->sequential_num_);
    if (ret == 0) {
      std::string error_info = "Erase input controls failed: " + GetAID().Name();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
  }

  // Erase input branch id.
  if (need_branch_id_input_) {
    auto ret = input_branch_ids_.erase(context->sequential_num_);
    if (ret == 0) {
      std::string error_info = "Erase input branch id failed: " + GetAID().Name();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
  }
}
}  // namespace runtime
}  // namespace mindspore
mindspore/ccsrc/runtime/framework/actor/gather_actor.h
@@ -1,143 +0,0 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_GATHER_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_GATHER_ACTOR_H_

#include <vector>
#include <string>
#include <memory>
#include <unordered_map>
#include <stack>
#include <utility>
#include <algorithm>
#include "runtime/framework/device_tensor_store.h"
#include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/control_node_parser.h"
#include "runtime/hardware/device_context.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_graph.h"
#include "ir/tensor.h"

namespace mindspore {
namespace runtime {

constexpr size_t kReturnInputPos = 1;

// Gather actor is used in three places:
// 1. Entrance of sub funcgraph
// 2. call node which input0 is a funcgraph
// 3. There is some call nodes in the inputs of kernel graph.
// Gather actor will be used in the control flow. When the subgraph is called, the real parameters need to be put
// together and sent to the subgraph. At the same time, the entry of the subgraph needs to accept input data.
// Special in recursion, general inputs and call inputs of the kernel graph are used in stack mode, it needs to be
// collected at the entrance of the kernel graph.
class GatherActor : public OpActor<DeviceTensor> {
 public:
  GatherActor(const std::string &name, const std::vector<KernelWithIndex> &parameters, const bool need_branch_id_input,
              const AID switch_aid, const AID gather_aid, const int branch_id)
      : OpActor(name),
        data_nodes_(parameters),
        need_branch_id_input_(need_branch_id_input),
        switch_aid_(switch_aid),
        gather_aid_(gather_aid),
        local_branch_id_(branch_id) {
    device_contexts_.resize(parameters.size());
  }
  ~GatherActor() override = default;

  // Get the index of the parameter, the data_node needs to be the front node.
  size_t FetchDataNodePosition(const KernelWithIndex &data_node) const;

  // The gather actor run when receive the input data.
  void RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) override;
  // The gather actor run when receive the input control.
  void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) override;
  // The gather actor run when receive the input branch id.
  void CollectBranchId(const int branch_id, OpContext<DeviceTensor> *const context);
  void Init() override;

 private:
  friend class GraphScheduler;

  // Collect the inputs of gather actor.
  void FetchBackendInputNode(const FuncGraphPtr &func_graph, const ControlNodeParserPtr &parser);
  void FetchInputDeviceTensor(OpContext<DeviceTensor> *const context);
  // Check whether satisfy the condition for launch.
  bool CheckLaunchCondition(OpContext<DeviceTensor> *const context) const;
  void SendOutput(OpContext<DeviceTensor> *const context) const;
  // Erase input data and input controls when finish gather launch.
  void EraseInput(OpContext<DeviceTensor> *const context);

  // The device tensors for launch.
  std::vector<DeviceTensor *> input_device_tensors_;
  // The branch id for current step.
  int input_branch_id_{kInvalidBranchID};

  // Input data.
  std::unordered_map<uuids::uuid *, std::unordered_map<size_t, std::stack<DeviceTensor *>>> input_data_;
  // Input branch ids is used to record the id corresponding receive from gather actor.
  // In control flow, sub funcgraph may be called in multiple places, and the output must be return to different
  // places. Therefore, the output of each subgraph will be connected to a switch actor, and the caller will send
  // its branch id to the gather actor of the subgraph. Then branch id will be sent by the gather actor to the
  // switch actor connected to the output.
  std::unordered_map<uuids::uuid *, int> input_branch_ids_;

  // Output data.
  // Cache unique output data by output index to modify the output data effectively.
  std::vector<std::vector<OpDataUniquePtr<DeviceTensor>>> output_data_by_output_index_;
  // The output_data_ corresponds to the output_data_arrows_ one by one.
  std::vector<OpData<DeviceTensor> *> output_data_;

  // Output arrows.
  std::vector<DataArrowPtr> output_result_arrows_;
  std::vector<AID> output_branch_arrows_;

  // Parameters of sub funcgraph, which is the front node.
  std::vector<KernelWithIndex> data_nodes_;
  std::vector<DeviceContext *> device_contexts_;
  // Pair<index, anfNode> points to the dependent device tensor store, anfNode is the key of the device tensor store.
  std::vector<std::pair<size_t, AnfNode *>> device_tensor_store_keys_;

  // When the output is a parameter of the subgraph, the gather actor needs to send the anfnode to the output actor,
  // so all the nodes that may send the device tensor to gather actor are recorded. When the anfnode needs to be sent
  // to the output actor, the corresponding backend node will be found from the map.
  std::unordered_map<AnfNodePtr, std::vector<KernelWithIndex>> front_to_backend_parameter_;

  // The dependent input data number.
  size_t input_datas_num_{0};
  // The dependent input controls number.
  size_t input_controls_num_{0};
  // Whether it needs to accept the branch id. When the gather actor is the input of the subgraph, it needs to receive
  // branch id sent by the subgraph caller, which will be true at this time.
  bool need_branch_id_input_;

  // Actor id that needs to send the branch id to it.
  // When the actor is corresponding to call node, the branch id needs to be sent to the input gather actor and output
  // switch actor of the called funcgraph. When the actor is the entrance of the funcgraph, the gather actor id is
  // empty, just need to send branch id to its output switch actor.
  const AID switch_aid_;
  const AID gather_aid_;

  // The branch id corresponding to the funcgraph to which the gather actor belongs.
  int local_branch_id_;
};

using GatherActorPtr = std::shared_ptr<GatherActor>;
}  // namespace runtime
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_GATHER_ACTOR_H_
@ -1,501 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "runtime/framework/actor/switch_actor.h"
|
||||
#include "runtime/framework/actor/output_actor.h"
|
||||
#include "runtime/framework/actor/gather_actor.h"
|
||||
#include "runtime/framework/actor/memory_manager_actor.h"
|
||||
#include "mindrt/include/async/async.h"
|
||||
#include "abstract/utils.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
void SwitchActor::Init() {
|
||||
// Init output data.
|
||||
output_data_.resize(output_branch_arrows_.size());
|
||||
for (size_t i = 0; i < output_branch_arrows_.size(); ++i) {
|
||||
auto &output_branch_arrow = output_branch_arrows_[i];
|
||||
auto &output_data = output_data_[i];
|
||||
for (auto &data_arrow : output_branch_arrow) {
|
||||
MS_EXCEPTION_IF_NULL(data_arrow);
|
||||
auto data = std::make_unique<OpData<DeviceTensor>>(data_arrow->to_op_id_, nullptr, data_arrow->to_input_index_);
|
||||
(void)output_data.emplace_back(std::move(data));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SwitchActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
const auto &sequential_num = context->sequential_num_;
|
||||
auto &input_datas = input_data_[sequential_num];
|
||||
input_datas[input_data->index_].push(input_data->data_);
|
||||
|
||||
if (CheckLaunchCondition(context)) {
|
||||
FetchInputDeviceTensor(context);
|
||||
EraseInput(context);
|
||||
SendOutput(context);
|
||||
}
|
||||
}
|
||||
|
||||
void SwitchActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
auto &sequential_num = context->sequential_num_;
|
||||
if (input_controls_[sequential_num].find(input_control) == input_controls_[sequential_num].end()) {
|
||||
input_controls_[sequential_num][input_control] = 0;
|
||||
}
|
||||
input_controls_[sequential_num][input_control]++;
|
||||
|
||||
if (CheckLaunchCondition(context)) {
|
||||
FetchInputDeviceTensor(context);
|
||||
EraseInput(context);
|
||||
SendOutput(context);
|
||||
}
|
||||
}
|
||||
|
||||
void SwitchActor::CollectBranchId(const int branch_id, OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
auto &sequential_num = context->sequential_num_;
|
||||
input_branch_ids_[sequential_num].push(branch_id);
|
||||
}
|
||||
|
||||
void SwitchActor::ParseInput(const ControlNodeParserPtr &parser) {
|
||||
std::vector<AnfNodePtr> inputs = node_->inputs();
|
||||
|
||||
if (IsPrimitive(inputs[0], prim::kPrimSwitch)) {
|
||||
ParseSwitchInput();
|
||||
} else if (IsPrimitive(inputs[0], prim::kPrimReturn)) {
|
||||
ParseReturnInput(parser);
|
||||
} else {
|
||||
ParseSwitchLayerInput();
|
||||
}
|
||||
backend_parameters_.resize(input_nodes_.size());
|
||||
}
|
||||
|
||||
void SwitchActor::ParsePartialInput(const AnfNodePtr &node, const size_t branch_id) {
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
|
||||
CNodePtr cnode = node->cast<CNodePtr>();
|
||||
|
||||
// The inputs of the Partial node is:
|
||||
// [0] ValueNode<Primitive> kPartial.
|
||||
// [1] ValueNode<FuncGraphPtr>.
|
||||
// [2..] Inputs.
|
||||
auto partial_inputs = cnode->inputs();
|
||||
if (partial_inputs.size() <= kPartialFuncGraphPos) {
|
||||
MS_LOG(EXCEPTION) << "Invalid Partial node:" << AnfAlgo::GetNodeDebugString(cnode);
|
||||
}
|
||||
|
||||
auto func_graph = GetValueNode<FuncGraphPtr>(partial_inputs[kPartialFuncGraphPos]);
|
||||
|
||||
branch_func_graph_[branch_id] = func_graph;
|
||||
for (size_t j = kPartialInputStartPos; j < partial_inputs.size(); ++j) {
|
||||
AddInput(partial_inputs[j], branch_id);
|
||||
}
|
||||
} else if (IsValueNode<FuncGraph>(node)) {
|
||||
const auto func_graph = GetValueNode<FuncGraphPtr>(node);
|
||||
branch_func_graph_[branch_id] = func_graph;
|
||||
} else {
|
||||
AddInput(node, branch_id);
|
||||
}
|
||||
}
|
||||
|
||||
void SwitchActor::InitVectorSize(const size_t num) {
|
||||
branch_inputs_pos_.resize(num);
|
||||
branch_func_graph_.resize(num);
|
||||
output_branch_arrows_.resize(num);
|
||||
output_branch_result_arrows_.resize(num);
|
||||
output_branch_control_arrows_.resize(num);
|
||||
output_branch_branch_arrows_.resize(num);
|
||||
}
|
||||
|
||||
void SwitchActor::ParseReturnInput(const ControlNodeParserPtr &parser) {
|
||||
const auto &func_graph = node_->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
const auto &call_num = parser->GetCallNumByFuncGraph(func_graph);
|
||||
InitVectorSize(call_num);
|
||||
|
||||
AddCommonInput(func_graph->output());
|
||||
}
|
||||
|
||||
void SwitchActor::ParseSwitchInput() {
|
||||
// The inputs of the switch node:
|
||||
// [0] ValueNode<Primitive> kSwitch.
|
||||
// [1] switch condition.
|
||||
// [2] Partial node: true branch.
|
||||
// [3] Partial node: false branch.
|
||||
std::vector<AnfNodePtr> inputs = node_->inputs();
|
||||
if (inputs.size() != kSwitchInputNum) {
|
||||
MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitch->name() << " is not equal 4";
|
||||
}
|
||||
|
||||
InitVectorSize(kSwitchPartialNum);
|
||||
|
||||
const auto cond_node = AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchCondPos], 0);
|
||||
input_nodes_.push_back(cond_node);
|
||||
input_datas_num_++;
|
||||
// Init the two branches of switch node.
|
||||
ParsePartialInput(inputs[kSwitchFalseBranchPos], static_cast<size_t>(false));
|
||||
ParsePartialInput(inputs[kSwitchTrueBranchPos], static_cast<size_t>(true));
|
||||
}
|
||||
|
||||
void SwitchActor::ParseSwitchLayerInput() {
|
||||
// The inputs of the switch node:
|
||||
// [0] ValueNode<Primitive> kSwitchLayer.
|
||||
// [1] switchLayer index.
|
||||
// [2] MakeTuple node: tuple of branches.
|
||||
std::vector<AnfNodePtr> inputs = node_->inputs();
|
||||
if (inputs.size() != kSwitchLayerInputNum) {
|
||||
MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitchLayer->name() << " is not equal 3";
|
||||
}
|
||||
|
||||
const auto cond_node = AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchLayerCondPos], 0);
|
||||
input_nodes_.push_back(cond_node);
|
||||
input_datas_num_++;
|
||||
|
||||
// The second input of SwitchLayer is maketuple node, which includes all branches.
|
||||
auto branch_nodes = inputs[kSwitchLayerBranchPos]->cast<CNodePtr>()->inputs();
|
||||
InitVectorSize(branch_nodes.size() - 1);
|
||||
|
||||
// Parse all branches.
|
||||
for (size_t i = kMakeTupleInputStartPos; i < branch_nodes.size(); ++i) {
|
||||
if (AnfAlgo::CheckPrimitiveType(branch_nodes[i], prim::kPrimPartial)) {
|
||||
ParsePartialInput(branch_nodes[i], i - kMakeTupleInputStartPos);
|
||||
} else if (branch_nodes[i]->isa<ValueNode>()) {
|
||||
branch_func_graph_[i - 1] = GetValueNode<FuncGraphPtr>(branch_nodes[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SwitchActor::AddCommonInput(const AnfNodePtr &node) {
|
||||
for (size_t i = 0; i < branch_inputs_pos_.size(); ++i) {
|
||||
AddInput(node, i);
|
||||
}
|
||||
}
|
||||
|
||||
size_t SwitchActor::FetchDataNodePosition(const AnfNodePtr &data_node) const {
|
||||
const auto data_node_with_index = AnfAlgo::VisitKernelWithReturnType(data_node, 0);
|
||||
const auto &iter = find(input_nodes_.begin(), input_nodes_.end(), data_node_with_index);
|
||||
if (iter == input_nodes_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Data node: " << AnfAlgo::GetNodeDebugString(data_node)
|
||||
<< " is not exist in switch actor:" << GetAID();
|
||||
}
|
||||
return iter - input_nodes_.begin();
|
||||
}
|
||||
|
||||
void SwitchActor::AddInput(const KernelWithIndex node_with_index, const size_t branch) {
|
||||
const auto &node = node_with_index.first;
|
||||
|
||||
// The value node and weight node need to be placed in the device store. The switch actor has three inputs:
|
||||
// 1) The input of the switch is the value node.
|
||||
// 2) There is a weight node or value node in the return of the sub funcgraph.
|
||||
if ((AnfAlgo::CheckPrimitiveType(node_, prim::kPrimReturn) && node->isa<Parameter>() && HasAbstractRef(node)) ||
|
||||
node->isa<ValueNode>()) {
|
||||
const auto iter = find(input_nodes_.begin(), input_nodes_.end(), node_with_index);
|
||||
if (iter != input_nodes_.end()) {
|
||||
branch_inputs_pos_[branch].push_back(iter - input_nodes_.begin());
|
||||
return;
|
||||
}
|
||||
(void)device_tensor_store_keys_.emplace_back(input_nodes_.size(), node.get());
|
||||
branch_inputs_pos_[branch].push_back(input_nodes_.size());
|
||||
input_nodes_.push_back(node_with_index);
|
||||
return;
|
||||
}
|
||||
|
||||
// Output of updatestate node is U, need to be skipped.
|
||||
if (node->isa<Parameter>() && HasAbstractRef(node)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Add parameter.
|
||||
auto iter = find(input_nodes_.begin(), input_nodes_.end(), node_with_index);
|
||||
if (iter == input_nodes_.end()) {
|
||||
branch_inputs_pos_[branch].push_back(input_nodes_.size());
|
||||
input_nodes_.push_back(node_with_index);
|
||||
++input_datas_num_;
|
||||
} else {
|
||||
branch_inputs_pos_[branch].push_back(iter - input_nodes_.begin());
|
||||
}
|
||||
}
|
||||
|
||||
void SwitchActor::AddInput(const AnfNodePtr &node, const size_t branch) {
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState) || HasAbstractMonad(node)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const auto &real_input = AnfAlgo::VisitKernelWithReturnType(node, 0);
|
||||
|
||||
if (AnfAlgo::CheckPrimitiveType(real_input.first, prim::kPrimMakeTuple)) {
|
||||
const auto &inputs = real_input.first->cast<CNodePtr>()->inputs();
|
||||
for (size_t i = kMakeTupleInputStartPos; i < inputs.size(); ++i) {
|
||||
AddInput(inputs[i], branch);
|
||||
}
|
||||
} else if (IsCallNode(real_input.first)) {
|
||||
std::vector<AnfNodePtr> call_nodes;
|
||||
const auto call_output_num = FetchOutputSizebyCallNode(real_input.first, &call_nodes);
|
||||
if (call_output_num == 0) {
|
||||
MS_LOG(EXCEPTION) << "Invalid output num for call input:" << AnfAlgo::GetNodeDebugString(real_input.first);
|
||||
}
|
||||
for (size_t i = 0; i < call_output_num; ++i) {
|
||||
AddInput({real_input.first, i}, branch);
|
||||
}
|
||||
} else if (real_input.first->isa<ValueNode>() && real_input.first->cast<ValueNodePtr>()->value()->isa<ValueTuple>()) {
|
||||
const auto &value = real_input.first->cast<ValueNodePtr>()->value();
|
||||
const auto &tuple_value = value->cast<ValueTuplePtr>();
|
||||
for (size_t i = 0; i < tuple_value->value().size(); ++i) {
|
||||
AddInput({real_input.first, i}, branch);
|
||||
}
|
||||
} else {
|
||||
AddInput(real_input, branch);
|
||||
}
|
||||
}
|
||||
|
||||
size_t SwitchActor::GetIndex(const OpContext<DeviceTensor> *const context) {
|
||||
if (need_branch_id_input_) {
|
||||
if (input_branch_ids_.find(context->sequential_num_) == input_branch_ids_.end() ||
|
||||
input_branch_ids_[context->sequential_num_].empty()) {
|
||||
MS_LOG(ERROR) << "Invalid branch id for actor:" + GetAID().Name();
|
||||
}
|
||||
auto branch_id = input_branch_ids_[context->sequential_num_].top();
|
||||
input_branch_ids_[context->sequential_num_].pop();
|
||||
if (branch_id_to_index_.find(branch_id) == branch_id_to_index_.end()) {
|
||||
MS_LOG(ERROR) << "Invalid branch id for switch actor:" + GetAID().Name() +
|
||||
" branch id:" + std::to_string(branch_id);
|
||||
}
|
||||
return branch_id_to_index_[branch_id];
|
||||
}
|
||||
|
||||
DeviceTensor *device_tensor = input_device_tensors_[0];
|
||||
MS_EXCEPTION_IF_NULL(device_tensor);
|
||||
|
||||
auto inputs = node_->inputs();
|
||||
TypeId type_id = AnfAlgo::GetOutputInferDataType(inputs[kSwitchCondPos], 0);
|
||||
size_t size = abstract::TypeIdSize(type_id);
|
||||
if (size > sizeof(int64_t)) {
|
||||
MS_LOG(ERROR) << "Index must be Int type.";
|
||||
}
|
||||
|
||||
int64_t index = 0;
|
||||
char buf[kMaxSwitchCondSize] = {0};
|
||||
ShapeVector host_shape;
|
||||
if (!device_tensor->SyncDeviceToHost(host_shape, size, type_id, static_cast<void *>(buf))) {
|
||||
MS_LOG(ERROR) << GetAID().Name() << " get index from device address failed, type id:" << std::to_string(type_id)
|
||||
<< ", device type:" << std::to_string(static_cast<int>(device_context_->GetDeviceAddressType()));
|
||||
}
|
||||
|
||||
if (type_id == TypeId::kNumberTypeInt32) {
|
||||
index = static_cast<int64_t>((static_cast<int32_t *>(static_cast<void *>(buf)))[0]);
|
||||
} else if (type_id == TypeId::kNumberTypeInt64) {
|
||||
index = (static_cast<int64_t *>(static_cast<void *>(buf)))[0];
|
||||
} else if (type_id == TypeId::kNumberTypeBool) {
|
||||
bool cond = (static_cast<bool *>(static_cast<void *>(buf)))[0];
|
||||
index = static_cast<int64_t>(cond ? 1 : 0);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Index must be Int type.";
|
||||
}
|
||||
|
||||
// SwitchLayer node support negative index range [-size, -1].
|
||||
if (index < 0) {
|
||||
index += SizeToInt(branch_func_graph_.size());
|
||||
}
|
||||
return static_cast<size_t>(index);
|
||||
}
|
||||
|
||||
bool SwitchActor::CheckLaunchCondition(OpContext<DeviceTensor> *const context) const {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
if (input_datas_num_ != 0) {
|
||||
auto data_iter = input_data_.find(context->sequential_num_);
|
||||
if (data_iter == input_data_.end()) {
|
||||
return false;
|
||||
}
|
||||
if (data_iter->second.size() != input_datas_num_) {
|
||||
return false;
|
||||
}
|
||||
if (std::any_of(data_iter->second.begin(), data_iter->second.end(),
|
||||
[](const auto &input_stack) { return input_stack.second.empty(); })) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (input_controls_num_ != 0) {
|
||||
auto data_iter = input_controls_.find(context->sequential_num_);
|
||||
if (data_iter == input_controls_.end()) {
|
||||
return false;
|
||||
}
|
||||
if (data_iter->second.size() != input_controls_num_) {
|
||||
return false;
|
||||
}
|
||||
if (std::any_of(data_iter->second.begin(), data_iter->second.end(),
|
||||
[](const auto &input_stack) { return input_stack.second == 0; })) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}

void SwitchActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  input_device_tensors_.resize(input_nodes_.size());
  auto data_iter = input_data_.find(context->sequential_num_);
  if (data_iter != input_data_.end()) {
    for (auto &input_data : data_iter->second) {
      input_device_tensors_[input_data.first] = input_data.second.top();
      input_data.second.pop();
    }
  }

  for (const auto &device_tensor_store_key : device_tensor_store_keys_) {
    auto device_tensor =
      DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second, device_context_->GetDeviceAddressType());
    if (device_tensor == nullptr) {
      std::string error_info =
        GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->DebugString() +
        ", device type:" + std::to_string(static_cast<int>(device_context_->GetDeviceAddressType()));
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    input_device_tensors_[device_tensor_store_key.first] = device_tensor;
  }

  auto control_iter = input_controls_.find(context->sequential_num_);
  if (control_iter != input_controls_.end()) {
    (void)for_each(control_iter->second.begin(), control_iter->second.end(),
                   [](auto &input_control) { input_control.second--; });
  }
}

void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
  MS_EXCEPTION_IF_NULL(context);
  auto index = GetIndex(context);
  if (index >= output_branch_arrows_.size()) {
    std::string error_info = "Switch actor:" + GetAID().Name() + " invalid index:" + std::to_string(index);
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
  }

  // The execution order must be: send branch id --> send result --> send data --> send control, to avoid
  // illegal timing problems.
  // 1.Send branch id.
  if (local_branch_id_ >= 0) {
    const auto &branch_arrows = output_branch_branch_arrows_[index];
    for (const auto &branch_arrow : branch_arrows) {
      Async(branch_arrow, &GatherActor::CollectBranchId, local_branch_id_, context);
    }
  }

  // 2.Send result.
  auto &output_branch_result_arrow = output_branch_result_arrows_[index];
  for (size_t i = 0; i < output_branch_result_arrow.size(); ++i) {
    auto &result_arrow = output_branch_result_arrow[i];
    MS_EXCEPTION_IF_NULL(result_arrow);
    if (result_arrow->from_output_index_ >= SizeToInt(branch_inputs_pos_[index].size())) {
      std::string error_info =
        "Invalid from index in switch actor, from index:" + std::to_string(result_arrow->from_output_index_) +
        " total:" + std::to_string(branch_inputs_pos_[index].size()) + " actor:" + GetAID().Name();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    size_t from_index = branch_inputs_pos_[index][IntToSize(result_arrow->from_output_index_)];

    MS_LOG(DEBUG) << "Switch actor:" << GetAID() << " send result addr:" << input_device_tensors_[from_index];
    bool is_send = false;
    for (const auto &backend_node : backend_parameters_[from_index]) {
      for (size_t j = 0; j < AnfAlgo::GetOutputTensorNum(backend_node.first); ++j) {
        if (backend_node.first->kernel_info() != nullptr && AnfAlgo::OutputAddrExist(backend_node.first, j, false) &&
            AnfAlgo::GetMutableOutputAddr(backend_node.first, j, false).get() == input_device_tensors_[from_index]) {
          auto output_index = j;
          Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, backend_node.first, output_index,
                result_arrow->to_input_index_, context);
          is_send = true;
          MS_LOG(DEBUG) << "Switch actor:" << GetAID() << " send result addr:" << input_device_tensors_[from_index]
                        << " succeed";
          break;
        }
      }
    }
    if (!is_send) {
      std::string error_info = "Failed to get backend node of switch actor output, actor:" + GetAID().Name() +
                               " branch:" + std::to_string(index) +
                               " index:" + std::to_string(result_arrow->from_output_index_) + " output pos:" +
                               std::to_string(branch_inputs_pos_[index][IntToSize(result_arrow->from_output_index_)]) +
                               " output index:" + std::to_string(result_arrow->to_input_index_);
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
  }

  // 3.Send data.
  auto &output_branch_arrow = output_branch_arrows_[index];
  auto &output_data = output_data_[index];
  for (size_t i = 0; i < output_branch_arrow.size(); ++i) {
    auto &data_arrow = output_branch_arrow[i];
    auto &data = output_data[i];
    MS_EXCEPTION_IF_NULL(data_arrow);
    MS_EXCEPTION_IF_NULL(data);
    data->data_ = input_device_tensors_[IntToSize(data_arrow->from_output_index_)];
    Async(data_arrow->to_op_id_, &OpActor::RunOpData, data.get(), context);
  }

  // 4.Send output control.
  auto source_aid = const_cast<AID *>(&GetAID());
  for (auto &output_control : output_branch_control_arrows_[index]) {
    Async(output_control, &OpActor::RunOpControl, source_aid, context);
  }
}

void SwitchActor::EraseInput(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  auto data_iter = input_data_.find(context->sequential_num_);
  if (data_iter != input_data_.end() && std::all_of(data_iter->second.begin(), data_iter->second.end(),
                                                    [](const auto &input_data) { return input_data.second.empty(); })) {
    auto ret = input_data_.erase(context->sequential_num_);
    if (ret == 0) {
      std::string error_info = "Erase input data failed: " + GetAID().Name();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
  }

  if (input_controls_num_ != 0) {
    auto control_iter = input_controls_.find(context->sequential_num_);
    if (control_iter != input_controls_.end() &&
        std::all_of(control_iter->second.begin(), control_iter->second.end(),
                    [](const auto &input_control) { return input_control.second == 0; })) {
      auto ret = input_controls_.erase(context->sequential_num_);
      if (ret == 0) {
        std::string error_info = "Erase input control failed: " + GetAID().Name();
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
      }
    }
  }
}

void SwitchActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
  Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &input_device_tensors_, device_context_, context);
}

void SwitchActor::FetchInputNode(const ControlNodeParserPtr &parser) {
  for (size_t i = 0; i < input_nodes_.size(); ++i) {
    const auto &input_node = input_nodes_[i].first;
    if (!(input_node->isa<Parameter>() && HasAbstractRef(input_node))) {
      backend_parameters_[i] = parser->FetchBackendInputNodeByFrontNode(input_node);
      continue;
    }

    const auto &backend_weight = parser->FetchBackendNodebyWeightNode(input_node);
    if (backend_weight == nullptr) {
      MS_LOG(EXCEPTION) << "Cannot find backend node for weight node:" << AnfAlgo::GetNodeDebugString(input_node);
    }
    (void)backend_parameters_[i].emplace(backend_weight, 0);
  }
}
}  // namespace runtime
}  // namespace mindspore

@ -1,180 +0,0 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_SWITCH_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_SWITCH_ACTOR_H_

#include <vector>
#include <string>
#include <set>
#include <memory>
#include <utility>
#include <stack>
#include <unordered_map>
#include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/device_tensor_store.h"
#include "runtime/framework/control_node_parser.h"
#include "mindrt/include/actor/switch_actor.h"
#include "runtime/hardware/device_context.h"

namespace mindspore {
namespace runtime {
using mindspore::device::DeviceContext;
using mindspore::session::KernelWithIndex;

constexpr size_t kSwitchInputNum = 4;
constexpr size_t kSwitchCondPos = 1;
constexpr size_t kSwitchPartialNum = 2;
constexpr size_t kSwitchLayerCondPos = 1;
constexpr size_t kSwitchLayerBranchPos = 2;
constexpr size_t kSwitchLayerInputNum = 3;
constexpr size_t kMaxSwitchCondSize = 8;
constexpr size_t kSwitchTrueBranchPos = 2;
constexpr size_t kSwitchFalseBranchPos = 3;
constexpr size_t kPartialFuncGraphPos = 1;
constexpr size_t kPartialInputStartPos = 2;
constexpr size_t kCallInputStartPos = 1;
constexpr size_t kMakeTupleInputStartPos = 1;

// Switch actor is used to execute the branch according to the input condition.
// Switch and SwitchLayer nodes will be converted to switch actors.
// The execution process is divided into:
// 1. Put the input into the vector.
// 2. Check whether the input condition has been received.
// 3. Check whether all inputs of the branch corresponding to the index have been received.
// 4. Send the data to the corresponding branch.
// 5. Free memory.
class SwitchActor : public SwitchActorBase<DeviceTensor> {
 public:
  SwitchActor(const std::string &name, DeviceContext *device_context, const CNodePtr &node, const int branch_id,
              const bool need_branch_id_input)
      : SwitchActorBase(name),
        device_context_(device_context),
        node_(node),
        local_branch_id_(branch_id),
        need_branch_id_input_(need_branch_id_input) {}
  ~SwitchActor() override = default;

  void Init() override;

  // The switch actor runs when it receives input data.
  void RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *const context);
  // The switch actor runs when it receives an input control.
  void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context);
  // The switch actor runs when it receives an input branch id.
  void CollectBranchId(const int branch_id, OpContext<DeviceTensor> *const context);
  // Parse the input node information of the switch actor according to node_.
  void ParseInput(const ControlNodeParserPtr &parser);
  // Add an input for all branches.
  void AddCommonInput(const AnfNodePtr &node);
  void AddSingleInput(const AnfNodePtr &node, size_t branch) { AddInput(node, branch); }
  // Fetch the input position of the data node.
  size_t FetchDataNodePosition(const AnfNodePtr &data_node) const;

 private:
  friend class GraphScheduler;

  void ParsePartialInput(const AnfNodePtr &node, const size_t branch_id);
  void ParseSwitchInput();
  void ParseSwitchLayerInput();
  // In control flow, the output of each subgraph is connected to a switch actor, and the switch actor is
  // initialized with the return node of the subgraph.
  void ParseReturnInput(const ControlNodeParserPtr &parser);
  // Initialize the size of the vector members.
  void InitVectorSize(const size_t num);
  // Get the index from the DeviceTensor.
  size_t GetIndex(const OpContext<DeviceTensor> *const context);
  // Add an input for the branch.
  void AddInput(const AnfNodePtr &node, size_t branch);
  void AddInput(const KernelWithIndex node_with_index, const size_t branch);

  // Check whether the conditions for sending outputs are satisfied.
  bool CheckLaunchCondition(OpContext<DeviceTensor> *const context) const;
  // Fetch the args of the switch branch.
  void FetchInputDeviceTensor(OpContext<DeviceTensor> *const context);
  void SendOutput(OpContext<DeviceTensor> *const context);
  // Erase the input data and input controls when the switch launch finishes.
  void EraseInput(OpContext<DeviceTensor> *const context);
  void SendMemoryFreeReq(OpContext<DeviceTensor> *const context);

  // Collect all the backend inputs of the switch actor.
  void FetchInputNode(const ControlNodeParserPtr &parser);
  // All inputs of the switch actor, including weights and tensors.
  // Used to receive input data; the first input is the condition of the switch.
  std::vector<KernelWithIndex> input_nodes_;
  // The positions of the branch outputs in input_nodes_.
  std::vector<std::vector<size_t>> branch_inputs_pos_;

  std::unordered_map<uuids::uuid *, std::unordered_map<size_t, std::stack<DeviceTensor *>>> input_data_;

  std::unordered_map<uuids::uuid *, std::unordered_map<AID *, size_t>> input_controls_;

  // Branch ids are used to record the id corresponding to the switch output branch.
  // In control flow, a sub funcgraph may be called in multiple places, and the output must be returned to
  // different places. Therefore, the output of each subgraph is connected to a switch actor, and the caller sends
  // its branch id to the gather actor of the subgraph. The branch id is then sent by the gather actor to the
  // switch actor connected to the output.
  // In a recursive scenario, the switch sequentially receives the branch ids sent by the callers, so the switch
  // actor needs to store the branch ids in a stack and pop them in turn when returning (see the illustrative
  // sketch after this class).
  std::unordered_map<uuids::uuid *, std::stack<int>> input_branch_ids_;

  // Control arrows of different branches.
  std::vector<std::vector<AID>> output_branch_control_arrows_;
  // Branch id arrows of different branches.
  std::vector<std::vector<AID>> output_branch_branch_arrows_;
  // Result arrows of different branches.
  std::vector<std::vector<DataArrowPtr>> output_branch_result_arrows_;

  // When the output is a value node from the switch actor, the actor needs to send the anfnode to the output actor,
  // so all the nodes that may send the device tensor to the switch actor are recorded.
  std::vector<std::set<KernelWithIndex>> backend_parameters_;
  std::vector<std::vector<AnfNodePtr>> branch_total_inputs_;

  std::vector<FuncGraphPtr> branch_func_graph_;

  std::unordered_map<int, size_t> branch_id_to_index_;

  // Pair<index, anfNode> points to the dependent device tensor store; anfNode is the key of the device tensor store.
  std::vector<std::pair<size_t, AnfNode *>> device_tensor_store_keys_;

  std::vector<DeviceTensor *> input_device_tensors_;

  // Save the DeviceContext of input_nodes_, which is used to release the DeviceTensor.
  const DeviceContext *device_context_;

  // The id of the memory manager actor. Messages are sent to it to allocate and free memory.
  const AID memory_manager_aid_;
  // The number of dependent input data.
  size_t input_datas_num_{0};
  // The number of dependent input controls.
  size_t input_controls_num_{0};
  CNodePtr node_;

  // The branch id corresponding to the funcgraph to which the gather actor belongs.
  int local_branch_id_;
  // Whether the branch id input is needed. When the switch actor is the output of a subgraph, it needs to receive
  // the branch id sent by the gather actor of the subgraph, and this is true in that case.
  bool need_branch_id_input_;

  // The output_data_ corresponds to the output_data_arrows_ one by one.
  std::vector<std::vector<OpDataUniquePtr<DeviceTensor>>> output_data_;
};
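// Illustrative sketch (assumption, not part of the original commit) of the
// recursive branch-id protocol described for input_branch_ids_ above: each
// caller pushes its branch id on entry, and the switch pops in LIFO order when
// returning, so nested calls return to the most recent caller first.
namespace {
inline int PopReturnBranchIdSketch(std::stack<int> *ids) {
  // ids holds the callers' branch ids, innermost call on top.
  const int ret_to = ids->top();
  ids->pop();
  return ret_to;  // e.g. after push(1); push(2); the first pop returns 2.
}
}  // namespace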

using SwitchActorPtr = std::shared_ptr<SwitchActor>;
}  // namespace runtime
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_SWITCH_ACTOR_H_

@ -15,8 +15,8 @@
 */

#include "runtime/framework/control_node_parser.h"
#include "runtime/framework/actor/switch_actor.h"
#include "runtime/framework/actor/gather_actor.h"
#include "runtime/framework/actor/control_flow/switch_actor.h"
#include "runtime/framework/actor/control_flow/gather_actor.h"
#include "abstract/utils.h"
#include "ir/tensor.h"

@ -37,6 +37,17 @@ using mindspore::session::KernelWithIndex;
constexpr int kInvalidBranchID = -1;
constexpr int kMainBranchID = 0;
constexpr int kSubBranchStartID = 1;
constexpr size_t kSwitchInputNum = 4;
constexpr size_t kSwitchPartialNum = 2;
constexpr size_t kSwitchLayerCondPos = 1;
constexpr size_t kSwitchLayerBranchPos = 2;
constexpr size_t kSwitchLayerInputNum = 3;
constexpr size_t kSwitchTrueBranchPos = 2;
constexpr size_t kSwitchFalseBranchPos = 3;
constexpr size_t kPartialFuncGraphPos = 1;
constexpr size_t kPartialInputStartPos = 2;
constexpr size_t kCallInputStartPos = 1;
constexpr size_t kMakeTupleInputStartPos = 1;

using FrontToBackendNodeWithContext = std::unordered_map<AnfNodePtr, std::pair<AnfNodePtr, DeviceContext *>>;
using FrontToBackendKernelWithContext = std::map<KernelWithIndex, std::pair<KernelWithIndex, DeviceContext *>>;
@ -368,8 +368,6 @@ ActorSetPtr GraphScheduler::Build(const GraphCompilerInfo &graph_compiler_info)
  actor_set->output_actor_ = BuildOutputActor(graph_compiler_info);
  actor_set->data_prepare_actor_ =
    BuildDataPrepareActor(graph_compiler_info, actor_set->data_source_actors_, host_queue);
  actor_set->switch_actors_ = BuildSwitchActor(graph_compiler_info);
  actor_set->gather_actors_ = BuildGatherActor(graph_compiler_info);

  return actor_set;
}
@ -732,155 +730,6 @@ std::vector<KernelActorPtr> GraphScheduler::BuildNoInputKernelActor(const ActorS
  return no_input_kernel_actors;
}

std::vector<SwitchActorPtr> GraphScheduler::BuildSwitchActor(const GraphCompilerInfo &graph_compiler_info) {
  std::vector<SwitchActorPtr> switch_actors;
  std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_backend_kernel;
  for (const auto &pair : front_node_to_actor_) {
    front_to_backend_kernel[pair.first] = pair.second->kernel_;
  }

  // Build switch actors from Switch and SwitchLayer nodes.
  for (const auto &control_node : graph_compiler_info.control_nodes_) {
    if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch) ||
        AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
      const auto func_graph = control_node->func_graph();
      const auto branch_id = graph_compiler_info.control_node_parser_->GetBranchIDByFuncGraph(func_graph);
      const auto &actor_name = control_node->DebugString();
      auto switch_actor = std::make_shared<SwitchActor>(actor_name, graph_compiler_info.device_contexts_[0],
                                                        control_node->cast<CNodePtr>(), branch_id, false);
      switch_actor->ParseInput(graph_compiler_info.control_node_parser_);

      // Fetch all the input nodes of the switch actor.
      switch_actor->FetchInputNode(graph_compiler_info.control_node_parser_);
      InsertActor(switch_actor.get());
      (void)switch_actors.emplace_back(switch_actor);
    }
  }

  // Build switch actors from return nodes.
  const auto func_graphs_to_call_num = graph_compiler_info.control_node_parser_->func_graph_to_call_num_;
  for (const auto &func_graph_to_call_num : func_graphs_to_call_num) {
    const auto &return_node = func_graph_to_call_num.first->get_return();
    MS_EXCEPTION_IF_NULL(return_node);
    const auto &actor_name = return_node->DebugString();
    auto switch_actor = std::make_shared<SwitchActor>(actor_name, graph_compiler_info.device_contexts_[0],
                                                      return_node->cast<CNodePtr>(), kInvalidBranchID, true);
    switch_actor->ParseInput(graph_compiler_info.control_node_parser_);

    // Fetch all the input nodes of the switch actor.
    switch_actor->FetchInputNode(graph_compiler_info.control_node_parser_);
    InsertActor(switch_actor.get());
    (void)switch_actors.emplace_back(switch_actor);
  }

  return switch_actors;
}

std::vector<GatherActorPtr> GraphScheduler::BuildGatherActor(const GraphCompilerInfo &graph_compiler_info) {
  std::vector<GatherActorPtr> gather_actors;

  const auto &loop_count_actor_name = graph_compiler_info.name_ + "_LoopCountActor";
  const auto &loop_count_actor = FetchActor(loop_count_actor_name);
  if (loop_count_actor == nullptr) {
    return gather_actors;
  }

  const auto &output_actor_name = graph_compiler_info.name_ + "_" + "OutputActor";
  const auto &output_actor = FetchActor(output_actor_name);
  MS_EXCEPTION_IF_NULL(output_actor);

  const auto parser = graph_compiler_info.control_node_parser_;

  bool is_main_return = true;
  // Each funcgraph has a return node; get the funcgraph from the return node and create a gather actor for it.
  std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_backend_kernel;
  for (const auto &pair : front_node_to_actor_) {
    front_to_backend_kernel[pair.first] = pair.second->kernel_;
  }

  for (const auto &control_node : graph_compiler_info.control_nodes_) {
    const auto &func_graph = control_node->func_graph();
    const auto &cnode = control_node->cast<CNodePtr>();
    const auto &inputs = cnode->inputs();
    const auto &return_node = func_graph->get_return();

    if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
      // The root funcgraph does not need a gather actor.
      if (is_main_return) {
        is_main_return = false;
        continue;
      }

      if (AnfAlgo::CheckPrimitiveType(inputs[kReturnInputPos], prim::kPrimPartial)) {
        continue;
      }
      auto actor_name = func_graph->ToString();
      std::vector<KernelWithIndex> parameters;
      for (const auto &parameter : func_graph->get_inputs()) {
        if (HasAbstractMonad(parameter) || HasAbstractRef(parameter)) {
          continue;
        }
        (void)parameters.emplace_back(parameter, 0);
      }

      const auto branch_id = parser->GetBranchIDByFuncGraph(func_graph);

      const auto &output_switch_actor = FetchActor(return_node->DebugString());
      MS_EXCEPTION_IF_NULL(output_switch_actor);
      const auto &output_switch_aid = output_switch_actor->GetAID();

      auto gather_actor =
        std::make_shared<GatherActor>(actor_name, parameters, true, output_switch_aid, AID(), branch_id);
      gather_actor->FetchBackendInputNode(func_graph, graph_compiler_info.control_node_parser_);
      InsertActor(gather_actor.get());
      (void)gather_actors.emplace_back(gather_actor);
    }
  }

  // Create a gather actor for each call node whose input 0 is a funcgraph.
  for (const auto &control_node : graph_compiler_info.control_nodes_) {
    const auto &cnode = control_node->cast<CNodePtr>();
    const auto &inputs = cnode->inputs();

    if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) {
      // Collect the parameters.
      std::vector<KernelWithIndex> parameters;
      for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
        if (HasAbstractMonad(inputs[i]) || (inputs[i]->isa<Parameter>() && HasAbstractRef(inputs[i]))) {
          continue;
        }
        (void)parameters.emplace_back(inputs[i], 0);
      }

      auto func_graph = control_node->func_graph();
      auto actor_name = control_node->DebugString();
      const auto branch_id = parser->GetBranchIDByFuncGraph(func_graph);
      const auto &to_func_graph = GetValueNode<FuncGraphPtr>(inputs[0]);
      const auto &to_actor = FetchActor(to_func_graph->ToString());
      auto gather_actor =
        std::make_shared<GatherActor>(actor_name, parameters, false, AID(), to_actor->GetAID(), branch_id);
      gather_actor->FetchBackendInputNode(func_graph, graph_compiler_info.control_node_parser_);

      InsertActor(gather_actor.get());
      (void)gather_actors.emplace_back(gather_actor);
    }
  }

  // Create a gather actor for each kernel graph that has a call input.
  const auto &graph_with_device_contexts = graph_compiler_info.control_node_parser_->call_input_kernel_graphs_;
  for (const auto &graph_with_device_context : graph_with_device_contexts) {
    const auto &graph = graph_with_device_context.first;
    const auto &parameters = FetchParameterbyKernelGraph(graph);

    auto actor_name = graph->ToString();
    auto gather_actor = std::make_shared<GatherActor>(actor_name, parameters, false, AID(), AID(), kInvalidBranchID);
    InsertActor(gather_actor.get());
    (void)gather_actors.emplace_back(gather_actor);
  }

  return gather_actors;
}

void GraphScheduler::LinkDataArrow(KernelActor *const to_actor, const GraphCompilerInfo &graph_compiler_info,
                                   const KernelGraphPtr &graph, const KernelWithIndex &from_kernel_with_output_idx,
                                   const KernelWithIndex &to_kernel_with_input_idx) {
@ -1547,81 +1396,6 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
  }
}

void GraphScheduler::LinkOutputResultArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info,
                                                         const ActorSet *actor_set) {
  const auto &to_actor = actor_set->output_actor_;
  const auto &loop_count_actor = actor_set->loop_count_actor_;
  if (to_actor == nullptr || loop_count_actor == nullptr) {
    return;
  }

  // When there is a call node in the output, the output is sent to the output actor by the switch actor of
  // the funcgraph called by the call node.
  const auto &outputs = graph_compiler_info.origin_outputs_order_;
  for (const auto &output : outputs) {
    const auto &output_node = output.first.first;
    const auto &output_index = output.first.second;
    const auto output_poses = output.second;

    if (IsCallNode(output_node)) {
      const auto &func_graphs = FetchFuncGraphbyCallNode(output_node);
      for (const auto func_graph : func_graphs) {
        const auto &actor_name = func_graph->get_return()->DebugString();
        auto actor = FetchActor(actor_name);
        MS_EXCEPTION_IF_NULL(actor);
        auto switch_actor = dynamic_cast<SwitchActor *>(actor);
        MS_EXCEPTION_IF_NULL(switch_actor);

        // Set the branch index into the switch actor.
        size_t branch_index = switch_actor->branch_id_to_index_.size();
        if (switch_actor->branch_id_to_index_.find(kMainBranchID) != switch_actor->branch_id_to_index_.end()) {
          branch_index = switch_actor->branch_id_to_index_[kMainBranchID];
        } else {
          switch_actor->branch_id_to_index_[kMainBranchID] = branch_index;
        }

        // Link the output result arrow.
        for (const auto output_pos : output_poses) {
          auto op_arrow = std::make_shared<DataArrow>(output_index, to_actor->GetAID(), output_pos);
          to_actor->device_contexts_[output_pos] = switch_actor->device_context_;
          (void)switch_actor->output_branch_result_arrows_[branch_index].emplace_back(op_arrow);
        }
      }
    }
  }

  const auto &switch_actors = actor_set->switch_actors_;
  for (const auto &from_actor : switch_actors) {
    MS_EXCEPTION_IF_NULL(from_actor);
    auto origin_output_with_index = KernelWithIndex(from_actor->node_, 0);
    const auto &iter = graph_compiler_info.origin_outputs_order_.find(origin_output_with_index);
    if (iter == graph_compiler_info.origin_outputs_order_.end()) {
      continue;
    }

    // If the switch actor is in the output list, the output of the switch actor should be sent to the output actor,
    // and a control arrow needs to be linked to the loop count actor.
    for (const auto pos : iter->second) {
      to_actor->device_contexts_[pos] = from_actor->device_context_;
    }

    for (size_t i = 0; i < from_actor->branch_inputs_pos_.size(); ++i) {
      const auto &input_pos = from_actor->branch_inputs_pos_[i];
      if (input_pos.empty()) {
        MS_LOG(EXCEPTION) << "Invalid input num in switch actor:" << from_actor->GetAID();
      }

      for (const auto pos : iter->second) {
        auto op_arrow = std::make_shared<DataArrow>(0, to_actor->GetAID(), pos);
        (void)from_actor->output_branch_result_arrows_[i].emplace_back(op_arrow);
      }

      (void)from_actor->output_branch_control_arrows_[i].emplace_back(loop_count_actor->GetAID());
    }
    loop_count_actor->input_controls_num_++;
  }
}

void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<KernelActor *> &auto_monad_actors) {
  const size_t kNeedUpdateDeviceTensorStoreNum = 2;
  for (auto &kernel_actor : auto_monad_actors) {
@ -1673,515 +1447,15 @@ void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<Ke
  }
}

void GraphScheduler::PrepareInputNodeForSwitchActor(const std::vector<AnfNodePtr> &control_nodes) {
  for (const auto &node : control_nodes) {
    CNodePtr cnode = node->cast<CNodePtr>();
    auto inputs = cnode->inputs();
    // Before linking the data arrows, the parameters of the call node in switch-call need to be added to the
    // switch actor.
    if (inputs[0]->isa<CNode>()) {
      auto actor = FetchActor(inputs[0]->DebugString());
      MS_EXCEPTION_IF_NULL(actor);
      auto switch_actor = dynamic_cast<SwitchActor *>(actor);
      MS_EXCEPTION_IF_NULL(switch_actor);

      for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
        if (HasAbstractMonad(inputs[i])) {
          continue;
        }
        switch_actor->AddCommonInput(inputs[i]);
      }
    }
  }
}

void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, ActorSet *const actor_set) {
  PrepareInputNodeForSwitchActor(graph_compiler_info.control_nodes_);

  for (const auto &node : graph_compiler_info.control_nodes_) {
    CNodePtr cnode = node->cast<CNodePtr>();
    const auto &from_func_graph = node->func_graph();
    auto inputs = cnode->inputs();
    // Link data arrows for switch nodes.
    if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch) ||
        AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitchLayer)) {
      auto actor = actor_name_to_actor_[node->DebugString()];
      MS_EXCEPTION_IF_NULL(actor);
      auto switch_actor = dynamic_cast<SwitchActor *>(actor);
      MS_EXCEPTION_IF_NULL(switch_actor);
      LinkDataArrowForSwitchActor(graph_compiler_info, switch_actor);
    } else if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) {
      // Link the data arrows for the inputs of the call node.
      const auto &actor_name = node->DebugString();
      auto actor = FetchActor(actor_name);
      MS_EXCEPTION_IF_NULL(actor);
      auto gather_actor = dynamic_cast<GatherActor *>(actor);
      MS_EXCEPTION_IF_NULL(gather_actor);

      const auto &func_graph = GetValueNode<FuncGraphPtr>(inputs[0]);
      MS_EXCEPTION_IF_NULL(func_graph);
      const auto &to_actor = FetchActor(func_graph->ToString());
      MS_EXCEPTION_IF_NULL(to_actor);

      size_t persist_input_num = 0;
      for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
        MS_EXCEPTION_IF_NULL(actor);
        if (inputs[i]->isa<ValueNode>()) {
          const auto &node_value = inputs[i]->cast<ValueNodePtr>()->value();
          if (!node_value->isa<tensor::Tensor>()) {
            persist_input_num++;
            continue;
          }

          (void)gather_actor->device_tensor_store_keys_.emplace_back(i - kCallInputStartPos - persist_input_num,
                                                                     inputs[i].get());
          gather_actor->device_contexts_[i - kCallInputStartPos - persist_input_num] =
            graph_compiler_info.control_node_parser_->GetFrontValueNodeDeviceContext(inputs[i]);
        } else if ((inputs[i]->isa<Parameter>() && HasAbstractRef(inputs[i]->cast<ParameterPtr>())) ||
                   AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimUpdateState) || HasAbstractMonad(inputs[i])) {
          persist_input_num++;
          continue;
        } else {
          const auto &input_with_index = AnfAlgo::VisitKernelWithReturnType(inputs[i], 0);
          LinkDataArrowByControlNode(graph_compiler_info, input_with_index, from_func_graph, actor,
                                     i - kCallInputStartPos - persist_input_num);
        }

        auto op_arrow = std::make_shared<DataArrow>(i - kCallInputStartPos - persist_input_num, to_actor->GetAID(),
                                                    i - kCallInputStartPos - persist_input_num);
        (void)gather_actor->output_data_arrows_.emplace_back(op_arrow);
      }
    }
  }

  // Link arrows for the switch actors of subgraph outputs.
  for (const auto &func_graph_with_call_num : graph_compiler_info.control_node_parser_->func_graph_to_call_num_) {
    const auto &func_graph = func_graph_with_call_num.first;
    MS_EXCEPTION_IF_NULL(func_graph);
    auto actor = FetchActor(func_graph->get_return()->DebugString());
    MS_EXCEPTION_IF_NULL(actor);
    auto switch_actor = dynamic_cast<SwitchActor *>(actor);
    MS_EXCEPTION_IF_NULL(switch_actor);
    LinkDataArrowForSwitchActor(graph_compiler_info, switch_actor);
  }

  // Link arrows for the gather actors of call-input kernel graphs.
  for (const auto &call_input_kernel_graph : graph_compiler_info.control_node_parser_->call_input_kernel_graphs_) {
    const auto &kernel_graph = call_input_kernel_graph.first;
    MS_EXCEPTION_IF_NULL(kernel_graph);
    auto actor = FetchActor(kernel_graph->ToString());
    MS_EXCEPTION_IF_NULL(actor);
    auto gather_actor = dynamic_cast<GatherActor *>(actor);
    MS_EXCEPTION_IF_NULL(gather_actor);

    for (size_t i = 0; i < gather_actor->data_nodes_.size(); ++i) {
      const auto &input_with_index = gather_actor->data_nodes_[i];
      const auto &from_func_graph = kernel_graph->GetFuncGraph();
      LinkDataArrowByControlNode(graph_compiler_info, input_with_index, from_func_graph, gather_actor, i);
    }
  }
  LinkBranchArrowForSwitchActor(graph_compiler_info);

  LinkBranchArrowForGatherActor(graph_compiler_info);

  LinkControlArrowForGatherActor(&(actor_set->kernel_actors_), graph_compiler_info.graphs_,
                                 graph_compiler_info.control_node_parser_);

  LinkControlArrowForSwitchActor(&(actor_set->switch_actors_), actor_set->loop_count_actor_.get(),
                                 graph_compiler_info.origin_outputs_order_);

  LinkOutputResultArrowForSwitchActor(graph_compiler_info, actor_set);
}
void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, ActorSet *const actor_set) {}

void GraphScheduler::LinkDataArrowForGatherActor(GatherActor *const from_actor, KernelActor *const to_actor,
                                                 const KernelWithIndex &front_node_with_index,
                                                 const KernelWithIndex &to_node_with_index) {
  MS_EXCEPTION_IF_NULL(from_actor);
  MS_EXCEPTION_IF_NULL(to_actor);
  MS_EXCEPTION_IF_NULL(front_node_with_index.first);

  auto position = from_actor->FetchDataNodePosition(front_node_with_index);

  auto op_arrow = std::make_shared<DataArrow>(position, to_actor->GetAID(), to_node_with_index.second);
  (void)from_actor->output_data_arrows_.emplace_back(op_arrow);
  to_actor->input_datas_num_++;
}

void GraphScheduler::LinkDataArrowByCallInput(const KernelWithIndex &call_node_with_index,
                                              const ControlNodeParserPtr &parser, const FuncGraphPtr &from_func_graph,
                                              OpActor<DeviceTensor> *const to_actor, const size_t to_index) {
  // Fetch all the funcgraphs that the call node may call.
  const auto cnode = call_node_with_index.first->cast<CNodePtr>();
  std::vector<FuncGraphPtr> func_graphs = FetchFuncGraphbyCallNode(cnode);

  // Collect the output of each funcgraph.
  for (const auto &func_graph : func_graphs) {
    const auto actor_name = func_graph->get_return()->DebugString();
    auto actor = FetchActor(actor_name);
    MS_EXCEPTION_IF_NULL(actor);
    auto switch_actor = dynamic_cast<SwitchActor *>(actor);
    MS_EXCEPTION_IF_NULL(switch_actor);
    const size_t branch_index = switch_actor->branch_id_to_index_.size();

    const auto &func_graph_to_branch_id = parser->func_graph_to_branch_id_;
    const auto &iter = func_graph_to_branch_id.find(from_func_graph);

    int branch_id = kMainBranchID;
    if (iter != func_graph_to_branch_id.end()) {
      branch_id = iter->second;
    }
    if (switch_actor->branch_id_to_index_.find(branch_id) != switch_actor->branch_id_to_index_.end()) {
      LinkDataArrowForSwitchActor(switch_actor, call_node_with_index.second, to_actor, to_index,
                                  switch_actor->branch_id_to_index_[branch_id]);
      continue;
    }
    LinkDataArrowForSwitchActor(switch_actor, call_node_with_index.second, to_actor, to_index, branch_index);
    switch_actor->branch_id_to_index_[branch_id] = branch_index;
  }
}
                                              const KernelWithIndex &to_node_with_index) {}

void GraphScheduler::LinkDataArrowForSwitchActor(SwitchActor *from_actor, const size_t from_index,
                                                 OpActor<DeviceTensor> *to_actor, const size_t to_index,
                                                 const size_t branch_index) {
  MS_EXCEPTION_IF_NULL(from_actor);
  MS_EXCEPTION_IF_NULL(to_actor);
  size_t start_branch = 0;
  size_t max_branch = from_actor->output_branch_arrows_.size();
  if (branch_index != SIZE_MAX) {
    start_branch = branch_index;
    max_branch = branch_index + 1;
  }
  for (size_t i = start_branch; i < max_branch; ++i) {
    if (from_actor->branch_inputs_pos_[i].size() <= from_index) {
      MS_LOG(EXCEPTION) << "No input for switch actor:" << from_actor->GetAID() << " branch:" << i
                        << " from index:" << from_index << " output size:" << from_actor->branch_inputs_pos_[i].size()
                        << " to actor:" << to_actor->GetAID() << " to index:" << to_index;
    }
    auto op_arrow =
      std::make_shared<DataArrow>(from_actor->branch_inputs_pos_[i][from_index], to_actor->GetAID(), to_index);
    (void)from_actor->output_branch_arrows_[i].emplace_back(op_arrow);
  }
}
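// Sketch (illustrative assumption, with <utility> assumed available) of the
// branch-range rule above: SIZE_MAX is the sentinel for "link every branch",
// while a concrete index selects the half-open range [branch_index, branch_index + 1).
namespace {
std::pair<size_t, size_t> BranchRangeSketch(size_t branch_index, size_t branch_num) {
  if (branch_index == SIZE_MAX) {
    return {0, branch_num};  // link all branches
  }
  return {branch_index, branch_index + 1};  // link a single branch
}
}  // namespace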

void GraphScheduler::LinkDataArrowByControlNode(const GraphCompilerInfo &graph_compiler_info,
                                                const KernelWithIndex &input_with_index,
                                                const FuncGraphPtr &from_func_graph,
                                                OpActor<DeviceTensor> *const to_actor, const size_t to_index) {
  const auto &parameters = graph_compiler_info.origin_parameters_order_;
  const auto &front_to_backend_parameter = graph_compiler_info.control_node_parser_->front_to_backend_parameters_;
  const auto &input_node = input_with_index.first;

  if (IsCallNode(input_node)) {
    // The actor input is a call node.
    LinkDataArrowByCallInput(input_with_index, graph_compiler_info.control_node_parser_, from_func_graph, to_actor,
                             to_index);
  } else if (IsGatherActor(input_node, actor_name_to_actor_)) {
    // The actor input is a parameter in a gather actor.
    auto from_actor = dynamic_cast<GatherActor *>(actor_name_to_actor_[input_node->func_graph()->ToString()]);
    auto position = from_actor->FetchDataNodePosition({input_node, 0});
    auto op_arrow = std::make_shared<DataArrow>(position, to_actor->GetAID(), to_index);
    (void)from_actor->output_data_arrows_.emplace_back(op_arrow);
  } else if (IsSwitchActor(input_node)) {
    const auto &actor_name = input_node->DebugString();
    auto actor = FetchActor(actor_name);
    MS_EXCEPTION_IF_NULL(actor);
    LinkDataArrowForSwitchActor(dynamic_cast<SwitchActor *>(actor), 0, to_actor, to_index);
  } else if (IsKernelActor(input_node, graph_compiler_info.strategy_)) {
    // The actor input is a cnode.
    if (front_node_to_actor_.find(input_node) == front_node_to_actor_.end()) {
      const auto &kernel_with_index = AnfAlgo::VisitKernelWithReturnType(input_node, 0);
      const auto &backend_node =
        graph_compiler_info.control_node_parser_->GetBackendKernelByFrontKernel(kernel_with_index);
      if (backend_node.first == nullptr) {
        MS_LOG(EXCEPTION) << "Cannot find actor:" << to_actor->GetAID()
                          << " input_node:" << AnfAlgo::GetNodeDebugString(input_node) << " addr:" << input_node;
      }
      const auto &actor_name = backend_node.first->fullname_with_scope();
      const auto &actor = FetchActor(actor_name);
      MS_EXCEPTION_IF_NULL(actor);
      auto from_actor = dynamic_cast<KernelActor *>(actor);
      MS_EXCEPTION_IF_NULL(from_actor);

      auto op_arrow = std::make_shared<DataArrow>(backend_node.second, to_actor->GetAID(), to_index);
      (void)from_actor->output_data_arrows_.emplace_back(op_arrow);
      auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->kernel_, backend_node.second, false);
      UpdateRefCount(device_tensor.get(), true);
      return;
    }

    auto op_arrow = std::make_shared<DataArrow>(input_with_index.second, to_actor->GetAID(), to_index);
    auto from_actor = front_node_to_actor_[input_node];
    (void)from_actor->output_data_arrows_.emplace_back(op_arrow);
    auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->kernel_, input_with_index.second, false);
    UpdateRefCount(device_tensor.get(), true);
  } else if (find(parameters.begin(), parameters.end(), input_node) != parameters.end()) {
    // The actor input is a parameter in the host data source actor.
    std::string actor_name = graph_compiler_info.name_ + "_HostDSActor";

    auto actor = FetchActor(actor_name);
    MS_EXCEPTION_IF_NULL(actor);
    auto from_actor = dynamic_cast<HostQueueDataSourceActor *>(actor);
    MS_EXCEPTION_IF_NULL(from_actor);

    auto backend_iter = front_to_backend_parameter.find(input_node);
    if (backend_iter == front_to_backend_parameter.end()) {
      MS_LOG(EXCEPTION) << "Cannot find backend node for front node:" << AnfAlgo::GetNodeDebugString(input_node);
    }

    const auto &backend_node = backend_iter->second.first;
    auto iter = from_actor->data_node_position_map_.find(input_node);
    if (iter == from_actor->data_node_position_map_.end()) {
      MS_LOG(EXCEPTION) << "Cannot find data node in data source actor, backend node:"
                        << AnfAlgo::GetNodeDebugString(backend_node)
                        << " front node:" << AnfAlgo::GetNodeDebugString(input_node);
    }

    auto op_arrow = std::make_shared<DataArrow>(iter->second, to_actor->GetAID(), to_index);
    (void)from_actor->output_data_arrows_.emplace_back(op_arrow);
    auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->data_nodes_[iter->second], 0, false);
    UpdateRefCount(device_tensor.get(), true);
  } else {
    MS_LOG(EXCEPTION) << "Cannot find actor of switch input_node:" << AnfAlgo::GetNodeDebugString(input_node)
                      << " to actor:" << to_actor->GetAID();
  }
}

void GraphScheduler::LinkDataArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info,
                                                 SwitchActor *const actor) {
  // Link the switch inputs.
  const auto &inputs = actor->input_nodes_;
  for (size_t i = 0; i < inputs.size(); ++i) {
    auto input = inputs[i];
    if (input.first->isa<ValueNode>() || (input.first->isa<Parameter>() && HasAbstractRef(input.first))) {
      continue;
    }

    const FuncGraphPtr from_func_graph = actor->node_->func_graph();
    LinkDataArrowByControlNode(graph_compiler_info, input, from_func_graph, actor, i);
  }

  // Link the switch outputs.
  for (size_t i = 0; i < actor->branch_func_graph_.size(); ++i) {
    auto func_graph = actor->branch_func_graph_[i];
    if (func_graph == nullptr) {
      continue;
    }

    auto gather_name = func_graph->ToString();
    if (actor_name_to_actor_.find(gather_name) == actor_name_to_actor_.end()) {
      MS_LOG(EXCEPTION) << "Cannot find gather actor for funcgraph:" << gather_name
                        << ", switch input size:" << actor->input_nodes_.size();
    }
    auto to_actor = dynamic_cast<GatherActor *>(actor_name_to_actor_[gather_name]);
    for (size_t j = 0; j < actor->branch_inputs_pos_[i].size(); ++j) {
      auto pos = actor->branch_inputs_pos_[i][j];
      auto to_actor_index = j;
      auto op_arrow = std::make_shared<DataArrow>(pos, to_actor->GetAID(), to_actor_index);
      (void)actor->output_branch_arrows_[i].emplace_back(op_arrow);
    }
  }
}

void GraphScheduler::LinkControlArrowForGatherActor(std::vector<KernelActorPtr> *const kernel_actors,
                                                    const std::vector<KernelGraphPtr> &graphs,
                                                    const ControlNodeParserPtr &parser) {
  // Link control arrows to kernel actors.
  for (size_t i = 0; i < graphs.size(); ++i) {
    const auto &kernel_graph = graphs[i];
    MS_EXCEPTION_IF_NULL(kernel_graph);
    const auto &func_graph = kernel_graph->GetFuncGraph();
    if (func_graph == nullptr) {
      continue;
    }
    const auto &actor = FetchActor(func_graph->ToString());
    if (actor == nullptr) {
      continue;
    }
    const auto &gather_actor = dynamic_cast<GatherActor *>(actor);
    MS_EXCEPTION_IF_NULL(gather_actor);

    // When the gather actor exists, the control arrows of the kernel actors without inputs need to be sent by the
    // gather actor.
    for (const auto &kernel : kernel_graph->execution_order()) {
      if (IsKernelActor(kernel) && (!IsSkippedKernelActor(kernel))) {
        const auto &kernel_actor = dynamic_cast<KernelActor *>(FetchActor(kernel->fullname_with_scope()));
        MS_EXCEPTION_IF_NULL(kernel_actor);

        if ((kernel_actor->input_datas_num_ == 0) && (kernel_actor->input_controls_num_ == 0)) {
          (void)gather_actor->output_control_arrows_.emplace_back(kernel_actor->GetAID());
          kernel_actor->input_controls_num_ = 1;
        }
      }
    }
  }

  for (auto &kernel_actor : *kernel_actors) {
    MS_EXCEPTION_IF_NULL(kernel_actor);

    if ((kernel_actor->output_data_arrows_.size() == 0) && (kernel_actor->output_control_arrows_.size() == 0) &&
        !parser->IsKernelInRootFuncGraph(kernel_actor->kernel_)) {
      // Check whether the kernel actor belongs to the root graph.
      // In general, all no-output nodes belong to the root funcgraph, and the corresponding output switch actor
      // should be empty. In control flow, the control arrow of a no-output node in a sub funcgraph should be
      // sent to the output switch actor.
      const auto &graph = kernel_actor->kernel_->func_graph();
      OpActor<DeviceTensor> *actor = nullptr;

      if (graph != nullptr) {
        const auto &kernel_graph = dynamic_cast<KernelGraph *>(graph.get());
        const auto func_graph = kernel_graph->GetFuncGraph();
        if (func_graph != nullptr) {
          actor = FetchActor(func_graph->get_return()->DebugString());
          if (actor != nullptr) {
            auto switch_actor = dynamic_cast<SwitchActor *>(actor);
            MS_EXCEPTION_IF_NULL(switch_actor);

            (void)kernel_actor->output_control_arrows_.emplace_back(switch_actor->GetAID());
            switch_actor->input_controls_num_++;
          }
        }
      }
    }
  }

  // Link input auto-monad control arrows from kernel actors to gather actors.
  const auto &monad_nodes = parser->kernel_to_call_nodes_;
  for (const auto node_pair : monad_nodes) {
    const auto &kernel_actor_name = node_pair.first->fullname_with_scope();
    const auto &gather_actor_name = node_pair.second->DebugString();
    auto kernel_op_actor = FetchActor(kernel_actor_name);
    auto gather_op_actor = FetchActor(gather_actor_name);
    if (kernel_op_actor == nullptr || gather_op_actor == nullptr) {
      continue;
    }
    auto kernel_actor = dynamic_cast<KernelActor *>(kernel_op_actor);
    auto gather_actor = dynamic_cast<GatherActor *>(gather_op_actor);
    (void)kernel_actor->output_control_arrows_.emplace_back(gather_actor->GetAID());
    gather_actor->input_controls_num_++;
  }
}

void GraphScheduler::LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr> *const switch_actors,
                                                    LoopCountActor *const to_actor,
                                                    const KernelMapPosition &origin_outputs_order) {
  if (to_actor == nullptr || (*switch_actors).empty()) {
    return;
  }

  // If there is no output from a switch actor branch, the subgraph has no input, and a control arrow needs to be
  // connected to the corresponding gather actor.
  for (auto &switch_actor : (*switch_actors)) {
    if (AnfAlgo::CheckPrimitiveType(switch_actor->node_, prim::kPrimReturn)) {
      const auto &func_graph = switch_actor->node_->func_graph();
      if (func_graph->output()->isa<ValueNode>()) {
        const auto &actor_name = func_graph->ToString();
        auto actor = FetchActor(actor_name);
        MS_EXCEPTION_IF_NULL(actor);
        auto gather_actor = dynamic_cast<GatherActor *>(actor);
        MS_EXCEPTION_IF_NULL(gather_actor);
        (void)gather_actor->output_control_arrows_.emplace_back(switch_actor->GetAID());
        switch_actor->input_controls_num_++;
      }
    }

    for (size_t i = 0; i < switch_actor->output_branch_arrows_.size(); ++i) {
      const auto &arrows = switch_actor->output_branch_arrows_[i];
      if (arrows.empty() && switch_actor->branch_func_graph_[i] != nullptr) {
        const auto &actor_name = switch_actor->branch_func_graph_[i]->ToString();
        const auto &actor = FetchActor(actor_name);
        if (actor != nullptr) {
          const auto &gather_actor = dynamic_cast<GatherActor *>(actor);
          MS_EXCEPTION_IF_NULL(gather_actor);
          (void)switch_actor->output_branch_control_arrows_[i].emplace_back(gather_actor->GetAID());
          gather_actor->input_controls_num_++;
        }
      }
    }
  }

  // Collect all the call nodes in the outputs.
  std::set<AnfNodePtr> call_nodes;
  for (const auto &output : origin_outputs_order) {
    if (IsCallNode(output.first.first)) {
      (void)call_nodes.insert(output.first.first);
    }
  }
  to_actor->input_controls_num_ += call_nodes.size();

  // Link the output switch actors of the subgraphs to the output actor.
  for (const auto &call_node : call_nodes) {
    const auto &func_graphs = FetchFuncGraphbyCallNode(call_node);
    for (const auto func_graph : func_graphs) {
      MS_EXCEPTION_IF_NULL(func_graph);
      const auto &actor_name = func_graph->get_return()->DebugString();
      auto actor = FetchActor(actor_name);
      MS_EXCEPTION_IF_NULL(actor);
      auto switch_actor = dynamic_cast<SwitchActor *>(actor);
      MS_EXCEPTION_IF_NULL(switch_actor);

      size_t branch_index = switch_actor->branch_id_to_index_.size();
      if (switch_actor->branch_id_to_index_.find(kMainBranchID) != switch_actor->branch_id_to_index_.end()) {
        branch_index = switch_actor->branch_id_to_index_[kMainBranchID];
      } else {
        switch_actor->branch_id_to_index_[kMainBranchID] = branch_index;
      }

      (void)switch_actor->output_branch_control_arrows_[branch_index].emplace_back(to_actor->GetAID());
    }
  }
}

void GraphScheduler::LinkBranchArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info) {
  for (const auto &control_node : graph_compiler_info.control_nodes_) {
    if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch) ||
        AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
      const auto &actor_name = control_node->DebugString();
      auto actor = FetchActor(actor_name);
      MS_EXCEPTION_IF_NULL(actor);
      auto switch_actor = dynamic_cast<SwitchActor *>(actor);
      MS_EXCEPTION_IF_NULL(switch_actor);

      for (size_t i = 0; i < switch_actor->branch_func_graph_.size(); ++i) {
        const auto &func_graph = switch_actor->branch_func_graph_[i];
        if (func_graph == nullptr) {
          continue;
        }

        const auto &gather_actor = FetchActor(func_graph->ToString());
        MS_EXCEPTION_IF_NULL(gather_actor);
        (void)switch_actor->output_branch_branch_arrows_[i].emplace_back(gather_actor->GetAID());
      }
    }
  }
}

void GraphScheduler::LinkBranchArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info) {
  if (graph_compiler_info.control_nodes_.empty()) {
    return;
  }

  // Link branch arrows from gather actors to gather actors.
  for (const auto &control_node : graph_compiler_info.control_nodes_) {
    const auto &cnode = control_node->cast<CNodePtr>();
    const auto &inputs = cnode->inputs();
    if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) {
      const auto &actor_name = control_node->DebugString();
      auto actor = FetchActor(actor_name);
      MS_EXCEPTION_IF_NULL(actor);
      auto gather_actor = dynamic_cast<GatherActor *>(actor);
      MS_EXCEPTION_IF_NULL(gather_actor);
      (void)gather_actor->output_branch_arrows_.emplace_back(gather_actor->gather_aid_);
    }
  }

  // Link branch arrows from gather actors to switch actors.
  for (const auto &func_graph_with_call_num : graph_compiler_info.control_node_parser_->func_graph_to_call_num_) {
    const auto &actor_name = func_graph_with_call_num.first->ToString();
    auto actor = FetchActor(actor_name);
    MS_EXCEPTION_IF_NULL(actor);
    auto gather_actor = dynamic_cast<GatherActor *>(actor);
    MS_EXCEPTION_IF_NULL(gather_actor);
    (void)gather_actor->output_branch_arrows_.emplace_back(gather_actor->switch_aid_);
  }
}
                                 const size_t branch_index) {}

bool GraphScheduler::CheckActorValid(const ActorSet *actor_set, GraphExecutionStrategy strategy) const {
  MS_EXCEPTION_IF_NULL(actor_set);
@ -2681,90 +1955,11 @@ void GraphScheduler::DumpDeviceTensorStore(const GraphCompilerInfo &graph_compil
|
|||
void GraphScheduler::DumpGatherActor(const GatherActor *actor, std::ofstream &ofs) const {
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
ofs << "\tactor_name:" << actor->GetAID().Name() << '\n';
|
||||
|
||||
ofs << "\t\tactor input num:" << actor->data_nodes_.size() << "\n";
|
||||
for (const auto &node : actor->data_nodes_) {
|
||||
ofs << "\t\t\t" << AnfAlgo::GetNodeDebugString(node.first) << "\tindex:" << node.second << '\n';
|
||||
}
|
||||
|
||||
ofs << "\t\tactor front to backend node:\n";
|
||||
for (const auto &front_to_backend_parameter : actor->front_to_backend_parameter_) {
|
||||
ofs << "\t\t\tfront node:" << AnfAlgo::GetNodeDebugString(front_to_backend_parameter.first) << '\n';
|
||||
for (const auto node_with_index : front_to_backend_parameter.second) {
|
||||
ofs << "\t\t\t\tbackend node:" << AnfAlgo::GetNodeDebugString(node_with_index.first)
|
||||
<< "\tindex:" << node_with_index.second << '\n';
|
||||
}
|
||||
}
|
||||
|
||||
ofs << "\t\tactor output data arrow:\n";
|
||||
for (const auto &data_arrow : actor->output_data_arrows_) {
|
||||
MS_EXCEPTION_IF_NULL(data_arrow);
|
||||
ofs << "\t\t\tfrom_output_index:" << data_arrow->from_output_index_
|
||||
<< "\tto_actor_name:" << data_arrow->to_op_id_.Name() << "\tto_input_index:" << data_arrow->to_input_index_
|
||||
<< "\n";
|
||||
}
|
||||
|
||||
ofs << "\t\tactor output result arrow:\n";
|
||||
for (const auto &result_arrow : actor->output_result_arrows_) {
|
||||
MS_EXCEPTION_IF_NULL(result_arrow);
|
||||
ofs << "\t\t\tfrom_output_index:" << result_arrow->from_output_index_
|
||||
<< "\tto_actor_name:" << result_arrow->to_op_id_.Name() << "\tto_input_index:" << result_arrow->to_input_index_
|
||||
<< "\n";
|
||||
}
|
||||
|
||||
ofs << "\t\tactor output control arrow:\n";
|
||||
for (const auto &control_arrow : actor->output_control_arrows_) {
|
||||
ofs << "\t\t\tto_actor_name:" << control_arrow;
|
||||
}
|
||||
ofs << "\n";
|
||||
}
|
||||
|
||||
void GraphScheduler::DumpSwitchActor(const SwitchActor *actor, std::ofstream &ofs) const {
  MS_EXCEPTION_IF_NULL(actor);
  ofs << "\tactor_name:" << actor->GetAID().Name() << '\n';

  ofs << "\t\tactor input num:" << actor->input_nodes_.size() << "\n";
  for (const auto &node : actor->input_nodes_) {
    ofs << "\t\t\t" << AnfAlgo::GetNodeDebugString(node.first) << '\t' << node.second << '\n';
  }

  ofs << "\t\tactor input pos:\n";
  for (size_t i = 0; i < actor->branch_inputs_pos_.size(); ++i) {
    ofs << "\t\t\tbranch " << i << " input pos:";
    for (const auto pos : actor->branch_inputs_pos_[i]) {
      ofs << pos << '\t';
    }
    ofs << '\n';
  }

  ofs << "\t\tactor output data arrow:\n";
  for (size_t i = 0; i < actor->output_branch_arrows_.size(); ++i) {
    ofs << "\t\t\tbranch " << i << " output data:\n";
    for (const auto &arrow : actor->output_branch_arrows_[i]) {
      MS_EXCEPTION_IF_NULL(arrow);
      ofs << "\t\t\t\tfrom_output_index:" << arrow->from_output_index_ << "\tto_actor_name:" << arrow->to_op_id_
          << "\tto_input_index:" << arrow->to_input_index_ << '\n';
    }
  }

  ofs << "\t\tactor output result arrow:\n";
  for (size_t i = 0; i < actor->output_branch_result_arrows_.size(); ++i) {
    ofs << "\t\t\tbranch " << i << " output result:\n";
    for (const auto &arrow : actor->output_branch_result_arrows_[i]) {
      MS_EXCEPTION_IF_NULL(arrow);
      ofs << "\t\t\t\tfrom_output_index:" << arrow->from_output_index_ << "\tto_actor_name:" << arrow->to_op_id_
          << "\tto_input_index:" << arrow->to_input_index_ << '\n';
    }
  }

  ofs << "\t\tactor output control arrow:\n";
  for (size_t i = 0; i < actor->output_branch_control_arrows_.size(); ++i) {
    ofs << "\t\t\tbranch " << i << " output control:\n";
    for (const auto &arrow : actor->output_branch_control_arrows_[i]) {
      ofs << "\t\t\t\tto_actor_name:" << arrow << '\n';
    }
  }
  ofs << "\n";
}
}  // namespace runtime
}  // namespace mindspore
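For orientation, a hedged sketch of how these dump helpers might compose into one entry point (the wrapper itself and the ActorSet members are assumptions for illustration):

// Hypothetical wrapper, mirroring the per-actor helpers defined above.
void GraphScheduler::DumpControlActors(const ActorSet *actor_set, std::ofstream &ofs) const {
  MS_EXCEPTION_IF_NULL(actor_set);
  for (const auto &switch_actor : actor_set->switch_actors_) {  // assumed ActorSet member
    DumpSwitchActor(switch_actor.get(), ofs);
  }
  for (const auto &gather_actor : actor_set->gather_actors_) {  // assumed ActorSet member
    DumpGatherActor(gather_actor.get(), ofs);
  }
}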
@ -33,9 +33,9 @@
#include "runtime/framework/actor/loop_count_actor.h"
#include "runtime/framework/actor/kernel_actor.h"
#include "runtime/framework/actor/output_actor.h"
#include "runtime/framework/actor/switch_actor.h"
#include "runtime/framework/actor/gather_actor.h"
#include "runtime/framework/actor/copy_actor.h"
#include "runtime/framework/actor/control_flow/switch_actor.h"
#include "runtime/framework/actor/control_flow/gather_actor.h"
#include "thread/actor_threadpool.h"

namespace mindspore {
@ -131,8 +131,6 @@ class GraphScheduler {
                                                   const std::vector<DataSourceActorPtr> &data_source_actors,
                                                   const HostTensorQueuePtr &host_queue);
  std::vector<KernelActorPtr> BuildNoInputKernelActor(const ActorSet *actor_set, GraphExecutionStrategy strategy);
  std::vector<SwitchActorPtr> BuildSwitchActor(const GraphCompilerInfo &graph_compiler_info);
  std::vector<GatherActorPtr> BuildGatherActor(const GraphCompilerInfo &graph_compiler_info);

  // Cache the information of graph output node to actor between "build" and "link", for linking between the tail of
  // previous graph and the head of next graph.
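A hedged sketch of the kind of cache that comment describes (the member name and value type are assumptions, not taken from this diff):

  // Hypothetical member: front output node (with index) -> producing actor and
  // backend output node, filled during "build" and consumed during "link".
  std::map<KernelWithIndex, std::pair<OpActor<DeviceTensor> *, KernelWithIndex>> graph_output_to_actor_;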
@ -195,34 +193,10 @@ class GraphScheduler {
  void LinkDataArrowForGatherActor(GatherActor *const from_actor, KernelActor *const to_actor,
                                   const KernelWithIndex &front_node_with_index,
                                   const KernelWithIndex &to_node_with_index);
  void LinkDataArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info, SwitchActor *const actor);
  // Connect the input of the actor.
  void LinkDataArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, const KernelWithIndex &input_node,
                                  const FuncGraphPtr &from_func_graph, OpActor<DeviceTensor> *const to_actor,
                                  const size_t to_index);
  // When the input of the actor is a call node, the output of the funcgraph called by the call node needs to be
  // connected.
  void LinkDataArrowByCallInput(const KernelWithIndex &call_node_with_index, const ControlNodeParserPtr &parser,
                                const FuncGraphPtr &from_func_graph, OpActor<DeviceTensor> *const to_actor,
                                const size_t to_index);
  void LinkDataArrowForSwitchActor(SwitchActor *const from_actor, const size_t from_index,
                                   OpActor<DeviceTensor> *const to_actor, const size_t to_index,
                                   const size_t branch_index = SIZE_MAX);

  void LinkControlArrowForGatherActor(std::vector<KernelActorPtr> *const kernel_actors,
                                      const std::vector<KernelGraphPtr> &graphs, const ControlNodeParserPtr &parser);

  void LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr> *const switch_actors, LoopCountActor *const to_actor,
                                      const KernelMapPosition &origin_outputs_order);
  // In control flow, there are scenarios where there are multi-branch outputs, and the gather actor needs to
  // send the branch id to the loop count actor.
  void LinkBranchArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info);
  void LinkBranchArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info);
  void LinkOutputResultArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
  // Add input for switch actor. Since part of the inputs of the funcgraph are on the call node, these inputs need
  // to be added to the switch actor.
  void PrepareInputNodeForSwitchActor(const std::vector<AnfNodePtr> &control_nodes);

  // Check whether the actor set is valid.
  bool CheckActorValid(const ActorSet *actor_set,
                       GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline) const;
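For orientation, a hedged sketch of how the control-flow link steps declared above might be driven from the link phase (the driver function itself is hypothetical, and the GraphCompilerInfo/ActorSet members referenced are assumptions; the helper names are taken from the declarations):

// Hypothetical driver for the control-flow portion of the link phase.
void GraphScheduler::LinkControlFlow(const GraphCompilerInfo &graph_compiler_info, ActorSet *const actor_set) {
  PrepareInputNodeForSwitchActor(graph_compiler_info.control_nodes_);         // assumed GraphCompilerInfo member
  LinkBranchArrowForSwitchActor(graph_compiler_info);
  LinkBranchArrowForGatherActor(graph_compiler_info);
  LinkControlArrowForSwitchActor(&actor_set->switch_actors_, actor_set->loop_count_actor_.get(),
                                 graph_compiler_info.origin_outputs_order_);  // assumed members
  LinkOutputResultArrowForSwitchActor(graph_compiler_info, actor_set);
}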