forked from mindspore-Ecosystem/mindspore
Unified runtime support mixed precision.
This commit is contained in:
parent
f86d707126
commit
bf6528645c
|
@ -612,7 +612,12 @@ AnfNodePtr SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, Kernel
|
|||
if (!pre_graph_out.empty() && !AnfAlgo::IsRealKernel(node)) {
|
||||
pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem, prim::kPrimUpdateState});
|
||||
}
|
||||
for (const auto ¶meter : parameters) {
|
||||
|
||||
for (size_t i = 0; i < parameters.size(); ++i) {
|
||||
const auto ¶meter = parameters[i];
|
||||
// In control flow, if the input of the cnode is a call node, it will be processed as a make_tuple input,
|
||||
// which needs to be linked when processing the internal node.
|
||||
graph->CacheInternalParameterToFrontNode(parameter, {node, i});
|
||||
auto valid_inputs = graph->MutableValidInputs();
|
||||
MS_EXCEPTION_IF_NULL(valid_inputs);
|
||||
auto graph_inputs = graph->MutableInputs();
|
||||
|
|
|
@ -44,6 +44,8 @@ bool IsDeviceQueueDSActor(const AnfNodePtr &node) {
|
|||
return false;
|
||||
}
|
||||
|
||||
bool IsSwitchActor(const AnfNodePtr &node) { return AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch); }
|
||||
|
||||
bool IsHostQueueDSActor(const AnfNodePtr &node, const KernelGraphPtr &graph, const TensorPtr &tensor,
|
||||
const std::vector<AnfNodePtr> &host_parameters) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
|
|
@ -59,6 +59,7 @@ bool IsDeviceQueueDSActor(const AnfNodePtr &node);
|
|||
bool IsHostQueueDSActor(const AnfNodePtr &node, const KernelGraphPtr &graph = nullptr,
|
||||
const TensorPtr &tensor = nullptr, const std::vector<AnfNodePtr> &host_parameters = {});
|
||||
bool IsKernelActor(const AnfNodePtr &node);
|
||||
bool IsSwitchActor(const AnfNodePtr &node);
|
||||
// The skip kernel doesn't run, it exists in the inplace optimizer.
|
||||
bool IsSkippedKernelActor(const AnfNodePtr &node);
|
||||
|
||||
|
|
|
@ -58,8 +58,8 @@ void GatherActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTe
|
|||
|
||||
if (CheckLaunchCondition(context)) {
|
||||
FetchInputDeviceTensor(context);
|
||||
EraseInput(context);
|
||||
SendOutput(context);
|
||||
input_op_datas_.erase(context->sequential_num_);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -86,8 +86,8 @@ void GatherActor::FetchBackendInputNode(const FuncGraphPtr &func_graph,
|
|||
// 2. Output the kernel actor.
|
||||
for (const auto parameters : func_iter->second) {
|
||||
if (parameters.size() != graph_inputs.size()) {
|
||||
MS_LOG(EXCEPTION) << "Parameters num is invalid, current:" << parameters.size()
|
||||
<< " need:" << graph_inputs.size();
|
||||
MS_LOG(EXCEPTION) << "Parameters num is invalid, current:" << parameters.size() << " need:" << graph_inputs.size()
|
||||
<< " func_graph:" << func_iter->first->ToString();
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < parameters.size(); ++i) {
|
||||
|
@ -133,20 +133,11 @@ void GatherActor::FetchBackendInputNode(const FuncGraphPtr &func_graph,
|
|||
void GatherActor::SendOutput(OpContext<DeviceTensor> *context) const {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
|
||||
// Send output data.
|
||||
for (auto &output_data : output_data_) {
|
||||
MS_EXCEPTION_IF_NULL(output_data);
|
||||
Async(output_data->op_id_, &OpActor::RunOpData, output_data, context);
|
||||
}
|
||||
|
||||
// Send output control.
|
||||
auto source_aid = const_cast<AID *>(&GetAID());
|
||||
for (auto &output_control : output_control_arrows_) {
|
||||
Async(output_control, &OpActor::RunOpControl, source_aid, context);
|
||||
}
|
||||
|
||||
// Branch arrow and result arrow must be executed before the data arrow and control arrow, otherwise the output
|
||||
// actor may receive the loop count message first and cause the output to be abnormal.
|
||||
if (branch_id_ > kInvalidBranchID) {
|
||||
Async(loop_count_aid_, &LoopCountActor::CollectBranchId, branch_id_, context);
|
||||
Async(output_aid_, &OutputActor::CollectBranchId, branch_id_, context);
|
||||
}
|
||||
|
||||
// Send graph output result.
|
||||
|
@ -160,9 +151,22 @@ void GatherActor::SendOutput(OpContext<DeviceTensor> *context) const {
|
|||
input_device_tensors_[from_index]) {
|
||||
Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, backend_node.first, backend_node.second,
|
||||
result_arrow->to_input_index_, context);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Send output data.
|
||||
for (auto &output_data : output_data_) {
|
||||
MS_EXCEPTION_IF_NULL(output_data);
|
||||
Async(output_data->op_id_, &OpActor::RunOpData, output_data, context);
|
||||
}
|
||||
|
||||
// Send output control.
|
||||
auto source_aid = const_cast<AID *>(&GetAID());
|
||||
for (auto &output_control : output_control_arrows_) {
|
||||
Async(output_control, &OpActor::RunOpControl, source_aid, context);
|
||||
}
|
||||
}
|
||||
|
||||
void GatherActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
|
||||
|
@ -209,5 +213,23 @@ bool GatherActor::CheckLaunchCondition(OpContext<DeviceTensor> *context) const {
|
|||
return true;
|
||||
}
|
||||
|
||||
void GatherActor::EraseInput(OpContext<DeviceTensor> *context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
if (input_datas_num_ != 0) {
|
||||
auto ret = input_op_datas_.erase(context->sequential_num_);
|
||||
if (ret == 0) {
|
||||
std::string error_info = "Erase input data failed: " + GetAID().Name();
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||
}
|
||||
}
|
||||
|
||||
if (input_controls_num_ != 0) {
|
||||
auto ret = input_op_controls_.erase(context->sequential_num_);
|
||||
if (ret == 0) {
|
||||
std::string error_info = "Erase input controls failed: " + GetAID().Name();
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -37,8 +37,9 @@ namespace runtime {
|
|||
// Gather actor is the entrance of sub funcgraph. Graph input is sent to it and sent to other actors by gather actor.
|
||||
class GatherActor : public OpActor<DeviceTensor> {
|
||||
public:
|
||||
GatherActor(const std::string &name, const std::vector<AnfNodePtr> ¶meters, const AID loop_count_aid)
|
||||
: OpActor(name), data_nodes_(parameters), loop_count_aid_(loop_count_aid) {}
|
||||
GatherActor(const std::string &name, const std::vector<AnfNodePtr> ¶meters, const AID loop_count_aid,
|
||||
const AID output_aid)
|
||||
: OpActor(name), data_nodes_(parameters), loop_count_aid_(loop_count_aid), output_aid_(output_aid) {}
|
||||
~GatherActor() override = default;
|
||||
|
||||
// Get the index of the parameter, the data_node needs to be the front node.
|
||||
|
@ -61,6 +62,8 @@ class GatherActor : public OpActor<DeviceTensor> {
|
|||
// Check whether satisfy the condition for launch.
|
||||
bool CheckLaunchCondition(OpContext<DeviceTensor> *context) const;
|
||||
void SendOutput(OpContext<DeviceTensor> *context) const;
|
||||
// Erase input data and input controls when finish gather launch.
|
||||
void EraseInput(OpContext<DeviceTensor> *context);
|
||||
|
||||
// The device tensors for launch.
|
||||
std::vector<DeviceTensor *> input_device_tensors_;
|
||||
|
@ -83,6 +86,7 @@ class GatherActor : public OpActor<DeviceTensor> {
|
|||
size_t input_controls_num_{0};
|
||||
|
||||
const AID loop_count_aid_;
|
||||
const AID output_aid_;
|
||||
|
||||
// Cache unique output data by output index to modify the output data effectively.
|
||||
std::vector<std::vector<OpDataUniquePtr<DeviceTensor>>> output_data_by_output_index_;
|
||||
|
|
|
@ -30,8 +30,8 @@ void LoopCountActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *c
|
|||
auto sequential_num = context->sequential_num_;
|
||||
input_op_controls_[sequential_num].emplace_back(input_control);
|
||||
|
||||
if (CheckExecuteCondition(context)) {
|
||||
Execute(context);
|
||||
if (CheckLoopCountIncreaseCondition(context)) {
|
||||
IncreaseLoopCount(context);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -39,8 +39,8 @@ void LoopCountActor::CollectBranchId(const int branch_id, OpContext<DeviceTensor
|
|||
MS_EXCEPTION_IF_NULL(context);
|
||||
branch_id_ = branch_id;
|
||||
|
||||
if (CheckExecuteCondition(context)) {
|
||||
Execute(context);
|
||||
if (CheckLoopCountIncreaseCondition(context)) {
|
||||
IncreaseLoopCount(context);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -53,7 +53,7 @@ void LoopCountActor::OnDebugFinish(OpContext<DeviceTensor> *context) {
|
|||
SendOutput(context);
|
||||
}
|
||||
|
||||
void LoopCountActor::Execute(OpContext<DeviceTensor> *context) {
|
||||
void LoopCountActor::IncreaseLoopCount(OpContext<DeviceTensor> *context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
auto sequential_num = context->sequential_num_;
|
||||
auto ret = input_op_controls_.erase(sequential_num);
|
||||
|
@ -100,7 +100,7 @@ void LoopCountActor::SendOutput(OpContext<DeviceTensor> *context) {
|
|||
}
|
||||
}
|
||||
|
||||
bool LoopCountActor::CheckExecuteCondition(OpContext<DeviceTensor> *context) {
|
||||
bool LoopCountActor::CheckLoopCountIncreaseCondition(OpContext<DeviceTensor> *context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
auto sequential_num = context->sequential_num_;
|
||||
if (branch_id_ == kInvalidBranchID) {
|
||||
|
@ -108,8 +108,7 @@ bool LoopCountActor::CheckExecuteCondition(OpContext<DeviceTensor> *context) {
|
|||
}
|
||||
|
||||
if (branch_id_ >= SizeToInt(branch_id_to_input_controls_num_.size())) {
|
||||
MS_LOG(ERROR) << "Branch id is invalid, id:" << branch_id_
|
||||
<< " total branch num:" << branch_id_to_input_controls_num_.size();
|
||||
MS_LOG(ERROR) << "Branch id is invalid, id:" << branch_id_;
|
||||
}
|
||||
return input_op_controls_[sequential_num].size() == branch_id_to_input_controls_num_[branch_id_];
|
||||
}
|
||||
|
|
|
@ -60,10 +60,10 @@ class LoopCountActor : public DebugAwareActor {
|
|||
private:
|
||||
friend class GraphScheduler;
|
||||
|
||||
void Execute(OpContext<DeviceTensor> *context);
|
||||
void IncreaseLoopCount(OpContext<DeviceTensor> *context);
|
||||
void SendOutput(OpContext<DeviceTensor> *context);
|
||||
|
||||
bool CheckExecuteCondition(OpContext<DeviceTensor> *context);
|
||||
bool CheckLoopCountIncreaseCondition(OpContext<DeviceTensor> *context);
|
||||
// The loop count is constant, the current count is increased after each step running finished.
|
||||
size_t loop_count_;
|
||||
size_t current_count_;
|
||||
|
@ -87,7 +87,7 @@ class LoopCountActor : public DebugAwareActor {
|
|||
|
||||
// When the result of the graph is sent to the output actor, the gather actor of the graph needs
|
||||
// to send branch_id to the output actor to determine the corresponding weight.
|
||||
int branch_id_{kInvalidBranchID};
|
||||
int branch_id_{kMainBranchID};
|
||||
};
|
||||
|
||||
using LoopCountActorPtr = std::shared_ptr<LoopCountActor>;
|
||||
|
|
|
@ -46,17 +46,21 @@ TensorPtr CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index,
|
|||
|
||||
void OutputActor::CollectLoopCount(size_t loop_count, OpContext<DeviceTensor> *context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
if (branch_id_ == kInvalidBranchID) {
|
||||
MS_LOG(EXCEPTION) << "Invalid branch id for output actor.";
|
||||
}
|
||||
current_count_ = loop_count;
|
||||
if (loop_count_ == current_count_) {
|
||||
if (current_outputs_num_ + device_tensor_store_keys_.size() != outputs_num_) {
|
||||
std::string error_info = "The outputs num is wrong, the total outputs num: " + std::to_string(outputs_num_) +
|
||||
", the current outputs num: " + std::to_string(current_outputs_num_) +
|
||||
", the device tensor store num: " + std::to_string(device_tensor_store_keys_.size());
|
||||
if (current_outputs_num_ + device_tensor_store_keys_[branch_id_].size() != outputs_num_) {
|
||||
std::string error_info =
|
||||
"The outputs num is wrong, the total outputs num: " + std::to_string(outputs_num_) +
|
||||
", the current outputs num: " + std::to_string(current_outputs_num_) +
|
||||
", the device tensor store num: " + std::to_string(device_tensor_store_keys_[branch_id_].size());
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||
}
|
||||
|
||||
// Because device tensor store can't send data, so fetch the output result of device tensor store in running end.
|
||||
for (const auto &device_tensor_store_key : device_tensor_store_keys_) {
|
||||
for (const auto &device_tensor_store_key : device_tensor_store_keys_[branch_id_]) {
|
||||
if (device_tensor_store_key.first >= outputs_.size()) {
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The input index is of range.");
|
||||
}
|
||||
|
@ -96,6 +100,11 @@ void OutputActor::CollectLoopCount(size_t loop_count, OpContext<DeviceTensor> *c
|
|||
}
|
||||
}
|
||||
|
||||
void OutputActor::CollectBranchId(const int branch_id, OpContext<DeviceTensor> *context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
branch_id_ = branch_id;
|
||||
}
|
||||
|
||||
void OutputActor::CollectOutput(const AnfNodePtr &output_node, size_t output_index, size_t output_position,
|
||||
OpContext<DeviceTensor> *context) {
|
||||
MS_EXCEPTION_IF_NULL(output_node);
|
||||
|
@ -115,7 +124,7 @@ void OutputActor::CollectOutput(const AnfNodePtr &output_node, size_t output_ind
|
|||
// Save the output nodes to clear the device tensor in the running end.
|
||||
output_nodes_[output_position] = KernelWithIndex(output_node, output_index);
|
||||
// There is no loop count actor in step mode, need trigger call CollectLoopCount to replace old output device tensors.
|
||||
if (!need_loop_count_ && (current_outputs_num_ + device_tensor_store_keys_.size() == outputs_num_)) {
|
||||
if (!need_loop_count_ && (current_outputs_num_ + device_tensor_store_keys_[branch_id_].size() == outputs_num_)) {
|
||||
CollectLoopCount(++current_count_, context);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,6 +22,8 @@
|
|||
#include <memory>
|
||||
#include <utility>
|
||||
#include <algorithm>
|
||||
#include <unordered_map>
|
||||
#include "runtime/framework/control_node_parser.h"
|
||||
#include "runtime/framework/device_tensor_store.h"
|
||||
#include "runtime/framework/actor/actor_common.h"
|
||||
#include "runtime/hardware/device_context.h"
|
||||
|
@ -47,6 +49,7 @@ class OutputActor : public OpActor<DeviceTensor> {
|
|||
outputs_.resize(outputs_num);
|
||||
output_nodes_.resize(outputs_num);
|
||||
device_contexts_.resize(outputs_num);
|
||||
device_tensor_store_keys_[kMainBranchID] = std::vector<std::pair<size_t, AnfNodePtr>>();
|
||||
}
|
||||
~OutputActor() override = default;
|
||||
|
||||
|
@ -57,6 +60,8 @@ class OutputActor : public OpActor<DeviceTensor> {
|
|||
void CollectOutput(const AnfNodePtr &output_node, size_t output_index, size_t output_position,
|
||||
OpContext<DeviceTensor> *context);
|
||||
|
||||
void CollectBranchId(const int branch_id, OpContext<DeviceTensor> *context);
|
||||
|
||||
std::vector<TensorPtr> &outputs() { return outputs_; }
|
||||
|
||||
private:
|
||||
|
@ -74,9 +79,13 @@ class OutputActor : public OpActor<DeviceTensor> {
|
|||
size_t outputs_num_;
|
||||
size_t current_outputs_num_;
|
||||
bool need_loop_count_;
|
||||
int branch_id_{kMainBranchID};
|
||||
|
||||
// Pair<index, anfNode> points to the dependent device tensor store, anfNode is the key of the device tensor store.
|
||||
std::vector<std::pair<size_t, AnfNodePtr>> device_tensor_store_keys_;
|
||||
// Pair<branch_id, <index, node>> points to the dependent device tensor store, branch_id is the output branch id.
|
||||
// In general, the branch id is 0, which means there is only one output branch in the actor set. When there are
|
||||
// multiple possible output branches in the actor set, different branch ids correspond to their own related nodes.
|
||||
// The index is the position of node in the output, node is the key of the device tensor store.
|
||||
std::unordered_map<size_t, std::vector<std::pair<size_t, AnfNodePtr>>> device_tensor_store_keys_;
|
||||
};
|
||||
|
||||
using OutputActorPtr = std::shared_ptr<OutputActor>;
|
||||
|
|
|
@ -44,6 +44,7 @@ void SwitchActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTe
|
|||
// When all the inputs are collected, then allocate memory and callback launch.
|
||||
if (CheckLaunchCondition(context)) {
|
||||
FetchInputDeviceTensor(context);
|
||||
EraseInput(context);
|
||||
SendOutput(context);
|
||||
}
|
||||
}
|
||||
|
@ -238,6 +239,25 @@ void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
|
|||
}
|
||||
}
|
||||
|
||||
void SwitchActor::EraseInput(OpContext<DeviceTensor> *context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
if (input_datas_num_ != 0) {
|
||||
auto ret = input_op_datas_.erase(context->sequential_num_);
|
||||
if (ret == 0) {
|
||||
std::string error_info = "Erase input data failed: " + GetAID().Name();
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||
}
|
||||
}
|
||||
|
||||
if (input_controls_num_ != 0) {
|
||||
auto ret = input_op_controls_.erase(context->sequential_num_);
|
||||
if (ret == 0) {
|
||||
std::string error_info = "Erase input controls failed: " + GetAID().Name();
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SwitchActor::SendMemoryFreeReq(OpContext<DeviceTensor> *context) {
|
||||
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &input_device_tensors_, device_context_, context);
|
||||
}
|
||||
|
|
|
@ -87,6 +87,8 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
|
|||
// Fetch the args of switch branch.
|
||||
void FetchInputDeviceTensor(OpContext<DeviceTensor> *context);
|
||||
void SendOutput(OpContext<DeviceTensor> *context);
|
||||
// Erase input data and input controls when finish switch launch.
|
||||
void EraseInput(OpContext<DeviceTensor> *context);
|
||||
void SendMemoryFreeReq(OpContext<DeviceTensor> *context);
|
||||
|
||||
// All inputs of the switch actor, excluding weight and tensor.
|
||||
|
@ -107,6 +109,8 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
|
|||
const AID memory_manager_aid_;
|
||||
// The dependent input data number.
|
||||
size_t input_datas_num_{0};
|
||||
// The dependent input controls number.
|
||||
size_t input_controls_num_{0};
|
||||
CNodePtr node_;
|
||||
|
||||
// The output_data_ corresponds to the output_data_arrows_ one by one.
|
||||
|
|
|
@ -983,12 +983,15 @@ std::vector<GatherActorPtr> GraphScheduler::BuildGatherActor(const GraphCompiler
|
|||
}
|
||||
}
|
||||
|
||||
auto loop_count_actor_name = graph_compiler_info.name_ + "_LoopCountActor";
|
||||
auto actor = FetchActor(loop_count_actor_name);
|
||||
if (actor == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Cannot find loop count actor by name:" << loop_count_actor_name;
|
||||
}
|
||||
auto gather_actor = std::make_shared<GatherActor>(actor_name, parameters, actor->GetAID());
|
||||
const auto &loop_count_actor_name = graph_compiler_info.name_ + "_LoopCountActor";
|
||||
const auto &loop_count_actor = FetchActor(loop_count_actor_name);
|
||||
MS_EXCEPTION_IF_NULL(loop_count_actor);
|
||||
const auto &output_actor_name = graph_compiler_info.name_ + "_" + "OutputActor";
|
||||
const auto &output_actor = FetchActor(output_actor_name);
|
||||
MS_EXCEPTION_IF_NULL(output_actor);
|
||||
|
||||
auto gather_actor =
|
||||
std::make_shared<GatherActor>(actor_name, parameters, loop_count_actor->GetAID(), output_actor->GetAID());
|
||||
gather_actor->FetchBackendInputNode(func_graph, graph_compiler_info.origin_parameters_order_,
|
||||
graph_compiler_info.front_to_backend_parameters_,
|
||||
graph_compiler_info.func_graph_to_parameters_, front_to_backend_kernel);
|
||||
|
@ -1066,7 +1069,7 @@ void GraphScheduler::LinkDataArrowForInternalParameter(const AnfNodePtr &interna
|
|||
to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, front_output_node.get());
|
||||
return;
|
||||
}
|
||||
if (graph_output_to_actor_.count(front_output_with_index) == 0) {
|
||||
if (graph_output_to_actor_.count(front_output_with_index) == 0 && (!IsSwitchActor(front_output_node))) {
|
||||
MS_LOG(EXCEPTION) << "Can't find actor by front node:" << front_output_node->fullname_with_scope()
|
||||
<< ", internal parameter:" << internal_parameter->fullname_with_scope();
|
||||
}
|
||||
|
@ -1076,6 +1079,9 @@ void GraphScheduler::LinkDataArrowForInternalParameter(const AnfNodePtr &interna
|
|||
auto from_actor = dynamic_cast<DeviceQueueDataSourceActor *>(actor_pair.first);
|
||||
auto from_kernel_with_output_idx = KernelWithIndex(from_actor->data_kernel_, actor_pair.second);
|
||||
LinkDataArrowForDeviceDSActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
|
||||
} else if (IsSwitchActor(front_output_node)) {
|
||||
const auto &from_actor = dynamic_cast<SwitchActor *>(FetchActor(front_output_node->fullname_with_scope()));
|
||||
MS_LOG(ERROR) << "Need link to switch actor:" << from_actor->GetAID();
|
||||
} else if (IsKernelActor(front_output_node)) {
|
||||
auto from_actor = dynamic_cast<KernelActor *>(actor_pair.first);
|
||||
auto from_kernel_with_output_idx = KernelWithIndex(from_actor->kernel_, actor_pair.second);
|
||||
|
@ -1447,13 +1453,15 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
|
|||
if (iter == graph_compiler_info.origin_outputs_order_.end()) {
|
||||
continue;
|
||||
}
|
||||
to_actor->device_contexts_[iter->second] = graph_compiler_info.device_contexts_[number - 1];
|
||||
|
||||
to_actor->device_contexts_[iter->second.second] = graph_compiler_info.device_contexts_[number - 1];
|
||||
// The device tensor of graph out need be taken over by host tensor, so set the max reference count.
|
||||
UpdateRefCount(output_with_index.first, output_with_index.second, true);
|
||||
|
||||
// The graph output is from device tensor store.
|
||||
if (IsPersistentDeviceTensor(output_with_index.first)) {
|
||||
to_actor->device_tensor_store_keys_.emplace_back(iter->second, output_with_index.first);
|
||||
to_actor->device_tensor_store_keys_[iter->second.first].emplace_back(iter->second.second,
|
||||
output_with_index.first);
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -1462,7 +1470,7 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
|
|||
const auto &from_actor =
|
||||
dynamic_cast<KernelActor *>(FetchActor(output_with_index.first->fullname_with_scope()));
|
||||
MS_EXCEPTION_IF_NULL(from_actor);
|
||||
auto op_arrow = std::make_shared<DataArrow>(output_with_index.second, to_actor->GetAID(), iter->second);
|
||||
auto op_arrow = std::make_shared<DataArrow>(output_with_index.second, to_actor->GetAID(), iter->second.second);
|
||||
from_actor->output_result_arrows_.emplace_back(op_arrow);
|
||||
continue;
|
||||
}
|
||||
|
@ -1492,7 +1500,7 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
|
|||
}
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(from_actor);
|
||||
auto op_arrow = std::make_shared<DataArrow>(from_actor_output_index, to_actor->GetAID(), iter->second);
|
||||
auto op_arrow = std::make_shared<DataArrow>(from_actor_output_index, to_actor->GetAID(), iter->second.second);
|
||||
from_actor->output_result_arrows_.emplace_back(op_arrow);
|
||||
}
|
||||
}
|
||||
|
@ -1644,7 +1652,8 @@ void GraphScheduler::LinkDataArrowByCallInput(const GraphCompilerInfo &graph_com
|
|||
auto from_actor = iter->second;
|
||||
auto op_arrow = std::make_shared<DataArrow>(output_with_index.second, to_actor->GetAID(), to_index);
|
||||
from_actor->output_data_arrows_.emplace_back(op_arrow);
|
||||
|
||||
auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->kernel_, output_with_index.second, false);
|
||||
UpdateRefCount(device_tensor.get(), true);
|
||||
} else if (output_with_index.first->isa<Parameter>()) {
|
||||
// Input is a parameter from gather actor.
|
||||
const auto &actor_name = func_graph->ToString();
|
||||
|
@ -1693,7 +1702,8 @@ void GraphScheduler::LinkDataArrowByControlNode(const GraphCompilerInfo &graph_c
|
|||
auto op_arrow = std::make_shared<DataArrow>(input_witch_index.second, to_actor->GetAID(), to_index);
|
||||
auto from_actor = front_node_to_actor_[input_witch_index.first];
|
||||
from_actor->output_data_arrows_.emplace_back(op_arrow);
|
||||
UpdateRefCount(from_actor->kernel_, input_witch_index.second);
|
||||
auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->kernel_, input_witch_index.second, false);
|
||||
UpdateRefCount(device_tensor.get(), true);
|
||||
} else if (find(parameters.begin(), parameters.end(), input_node) != parameters.end()) {
|
||||
// The actor input is a parameter in host data source actor.
|
||||
std::string actor_name = graph_compiler_info.name_ + "_HostDSActor";
|
||||
|
@ -1717,7 +1727,8 @@ void GraphScheduler::LinkDataArrowByControlNode(const GraphCompilerInfo &graph_c
|
|||
|
||||
auto op_arrow = std::make_shared<DataArrow>(iter->second, to_actor->GetAID(), to_index);
|
||||
from_actor->output_data_arrows_.emplace_back(op_arrow);
|
||||
UpdateRefCount(from_actor->data_nodes_[iter->second], 0);
|
||||
auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->data_nodes_[iter->second], 0, false);
|
||||
UpdateRefCount(device_tensor.get(), true);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Cannot find actor of switch input_node:" << AnfAlgo::GetNodeDebugString(input_node);
|
||||
}
|
||||
|
@ -1737,9 +1748,15 @@ void GraphScheduler::LinkDataArrowForSwitchActor(const GraphCompilerInfo &graph_
|
|||
if (func_graph == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (func_graph->output()->isa<ValueNode>()) {
|
||||
actor->AddInput(func_graph->output(), 0);
|
||||
}
|
||||
|
||||
auto gather_name = func_graph->ToString();
|
||||
if (actor_name_to_actor_.find(gather_name) == actor_name_to_actor_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Cannot find gather actor for funcgraph:" << gather_name;
|
||||
MS_LOG(EXCEPTION) << "Cannot find gather actor for funcgraph:" << gather_name
|
||||
<< ",switch input size:" << actor->input_nodes_.size();
|
||||
}
|
||||
auto to_actor = dynamic_cast<GatherActor *>(actor_name_to_actor_[gather_name]);
|
||||
for (size_t j = 0; j < actor->branch_inputs_pos_[i].size(); ++j) {
|
||||
|
@ -1771,17 +1788,20 @@ void GraphScheduler::LinkBranchArrowForGatherActor(const GraphCompilerInfo &grap
|
|||
if (graph_compiler_info.control_nodes_.empty()) {
|
||||
return;
|
||||
}
|
||||
const auto func_graph = graph_compiler_info.control_nodes_[0]->func_graph();
|
||||
|
||||
const auto func_graph = graph_compiler_info.control_nodes_[0]->func_graph();
|
||||
const auto &loop_count_actor = actor_set->loop_count_actor_.get();
|
||||
const auto &output_actor = actor_set->output_actor_.get();
|
||||
|
||||
// If there is only one branch output, set the branch id of the loop count to 0, no need to send the branch id.
|
||||
auto outputs = graph_compiler_info.front_output_nodes_;
|
||||
if (outputs.size() == 1) {
|
||||
loop_count_actor->branch_id_ = kMainBranchID;
|
||||
return;
|
||||
}
|
||||
|
||||
loop_count_actor->branch_id_ = kInvalidBranchID;
|
||||
output_actor->branch_id_ = kInvalidBranchID;
|
||||
|
||||
std::vector<FuncGraphPtr> output_func_graphs;
|
||||
for_each(outputs.begin(), outputs.end(),
|
||||
[&output_func_graphs](const AnfNodePtr &output) { output_func_graphs.push_back(output->func_graph()); });
|
||||
|
@ -1866,8 +1886,11 @@ void GraphScheduler::LinkOutputResultArrowForGatherActor(const GraphCompilerInfo
|
|||
if (iter == graph_compiler_info.origin_outputs_order_.end()) {
|
||||
continue;
|
||||
}
|
||||
MS_LOG(INFO) << "Link output node:" << AnfAlgo::GetNodeDebugString(origin_output_with_index.first)
|
||||
<< " branch id:" << iter->second.first << " index:" << iter->second.second
|
||||
<< " for gather actor:" << gather_actor->GetAID();
|
||||
|
||||
auto op_arrow = std::make_shared<DataArrow>(i, to_actor->GetAID(), iter->second);
|
||||
auto op_arrow = std::make_shared<DataArrow>(i, to_actor->GetAID(), iter->second.second);
|
||||
gather_actor->output_result_arrows_.emplace_back(op_arrow);
|
||||
const auto &backend_nodes = gather_actor->front_to_backend_parameter_[front_node];
|
||||
if (backend_nodes.empty()) {
|
||||
|
@ -1882,13 +1905,14 @@ void GraphScheduler::LinkOutputResultArrowForGatherActor(const GraphCompilerInfo
|
|||
MS_LOG(EXCEPTION) << "Cannot find backend node in host data source actor, node:"
|
||||
<< AnfAlgo::GetNodeDebugString(backend_node);
|
||||
}
|
||||
to_actor->device_contexts_[iter->second] = host_ds_actor->device_contexts_[node_iter - data_nodes.begin()];
|
||||
to_actor->device_contexts_[iter->second.second] =
|
||||
host_ds_actor->device_contexts_[node_iter - data_nodes.begin()];
|
||||
} else {
|
||||
auto actor_base = FetchActor(backend_node->fullname_with_scope());
|
||||
MS_EXCEPTION_IF_NULL(actor_base);
|
||||
auto kernel_actor = dynamic_cast<KernelActor *>(actor_base);
|
||||
MS_EXCEPTION_IF_NULL(kernel_actor);
|
||||
to_actor->device_contexts_[iter->second] = kernel_actor->device_context_;
|
||||
to_actor->device_contexts_[iter->second.second] = kernel_actor->device_context_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2212,8 +2236,8 @@ void GraphScheduler::DumpOutputActor(const OutputActor *actor, std::ofstream &of
|
|||
ofs << "\tactor_name:" << actor->GetAID().Name() << "\tloop_count:" << actor->loop_count_
|
||||
<< "\toutputs_num:" << actor->outputs_num_ << "\n";
|
||||
|
||||
ofs << "\t\tdevice_tensor_store_keys:" << actor->device_tensor_store_keys_.size() << "\n ";
|
||||
for (const auto &device_tensor_store_key : actor->device_tensor_store_keys_) {
|
||||
ofs << "\t\tdevice_tensor_store_keys:" << actor->device_tensor_store_keys_.at(kMainBranchID).size() << "\n ";
|
||||
for (const auto &device_tensor_store_key : actor->device_tensor_store_keys_.at(kMainBranchID)) {
|
||||
MS_EXCEPTION_IF_NULL(device_tensor_store_key.second);
|
||||
ofs << "\t\t\toutput_node_position:" << device_tensor_store_key.first
|
||||
<< "\toutput_node_name:" << device_tensor_store_key.second->fullname_with_scope() << "\n";
|
||||
|
|
|
@ -42,7 +42,10 @@ namespace runtime {
|
|||
using mindspore::device::DeviceContext;
|
||||
using mindspore::session::KernelGraph;
|
||||
using mindspore::session::KernelWithIndex;
|
||||
using KernelMapPosition = std::map<KernelWithIndex, size_t, session::KernelWithIndexCmp>;
|
||||
// Position of kernel with index, the value pair<branch_id, pos> means the branch id of the kernel and the pos of
|
||||
// the kernel. Generally, there is only one branch, and the branch id is 0 at this time. In control flow, there
|
||||
// are multiple branch scenarios, and pos represents the position of the kernel in the branch.
|
||||
using KernelMapPosition = std::map<KernelWithIndex, std::pair<int, size_t>, session::KernelWithIndexCmp>;
|
||||
using ActorInfo = std::string;
|
||||
|
||||
// The second element of pair represents the output index of op actor corresponding to the graph output node.
|
||||
|
|
|
@ -512,12 +512,18 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(con
|
|||
runtime::KernelMapPosition outputs_order;
|
||||
size_t outputs_num = 0;
|
||||
const auto &all_branch_output = ControlNodeParser::FetchAllBranchOutputs(root_graph);
|
||||
for (const auto &branch_output : all_branch_output) {
|
||||
for (int j = 0; j < SizeToInt(all_branch_output.size()); ++j) {
|
||||
// In general, there is only one output branch, and the branch id is 0 at this time. In the control flow,
|
||||
// there are multi-branch output scenarios. Different branches may have different weight nodes. When output
|
||||
// actor run, the corresponding weight node needs to be obtained according to different branches. Therefore,
|
||||
// the branch of the output nodes needs to be recorded.
|
||||
const int branch_id = ((all_branch_output.size() == 1 ? runtime::kMainBranchID : (j + runtime::kSubBranchStartID)));
|
||||
const auto &branch_output = all_branch_output[j];
|
||||
size_t position = 0;
|
||||
auto outputs = AnfAlgo::GetAllOutputWithIndex(branch_output);
|
||||
outputs_num = outputs.size();
|
||||
for (const auto &output : outputs) {
|
||||
outputs_order.emplace(output, position++);
|
||||
outputs_order[output] = {branch_id, position++};
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -562,7 +568,7 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(
|
|||
|
||||
auto outputs = AnfAlgo::GetAllOutputWithIndex(graph->output());
|
||||
for (const auto &output : outputs) {
|
||||
outputs_order.emplace(output, position++);
|
||||
outputs_order[output] = {runtime::kMainBranchID, position++};
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -43,6 +43,7 @@ using ControlNodeParser = runtime::ControlNodeParser;
|
|||
using FrontToBackendNodeWithContext = runtime::FrontToBackendNodeWithContext;
|
||||
using FuncGraphToParameter = runtime::FuncGraphToParameter;
|
||||
using HostParameterToWeight = runtime::HostParameterToWeight;
|
||||
|
||||
enum SwitchCondStatus {
|
||||
kCondOk = 0,
|
||||
kCondAlreadyRun,
|
||||
|
|
Loading…
Reference in New Issue