!26752 control flow support call to call

Merge pull request !26752 from gaoyong10/runtime_second12
i-robot 2021-11-26 03:40:33 +00:00 committed by Gitee
commit 1b0a82fc30
11 changed files with 671 additions and 108 deletions


@ -208,9 +208,14 @@ void ControlActor::SendOutput(OpContext<DeviceTensor> *const context) {
// Send Partial.
for (const auto &partial_arrow : output_partial_arrows_) {
MS_EXCEPTION_IF_NULL(partial_arrow);
MS_EXCEPTION_IF_NULL(output_partial_.first);
ActorDispatcher::Send(partial_arrow->to_op_id_, &ControlActor::RunOpPartial, output_partial_.first,
output_partial_.second, IntToSize(partial_arrow->to_input_index_), context);
if (IntToSize(partial_arrow->from_output_index_) >= input_partials_.size()) {
MS_LOG(ERROR) << "Invalid partial input:" << partial_arrow->from_output_index_
<< " current:" << input_partials_.size() << " for actor:" << GetAID();
}
auto output_partial = input_partials_[partial_arrow->from_output_index_];
MS_EXCEPTION_IF_NULL(output_partial.first);
ActorDispatcher::Send(partial_arrow->to_op_id_, &ControlActor::RunOpPartial, output_partial.first,
output_partial.second, IntToSize(partial_arrow->to_input_index_), context);
}
}
} // namespace runtime


@ -72,6 +72,22 @@ void ExitActor::SendOutput(OpContext<DeviceTensor> *const context) {
ActorDispatcher::Send(control_arrow, &OpActor::RunOpControl, source_aid, context);
}
}
// 3.Send output partial in output branch.
const auto &partial_iter = output_branch_partial_arrows_.find(output_branch_id_);
if (partial_iter != output_branch_partial_arrows_.end()) {
for (const auto &partial_arrow : partial_iter->second) {
MS_EXCEPTION_IF_NULL(partial_arrow);
if (IntToSize(partial_arrow->from_output_index_) >= input_partials_.size()) {
MS_LOG(ERROR) << "Invalid partial input:" << partial_arrow->from_output_index_
<< " current:" << input_partials_.size() << " for actor:" << GetAID();
}
auto output_partial = input_partials_[partial_arrow->from_output_index_];
MS_EXCEPTION_IF_NULL(output_partial.first);
Async(partial_arrow->to_op_id_, &ControlActor::RunOpPartial, output_partial.first, output_partial.second,
IntToSize(partial_arrow->to_input_index_), context);
}
}
}
void ExitActor::CopyDeviceAddress() {


@ -29,23 +29,22 @@ void GatherActor::FetchInput(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
ControlActor::FetchInput(context);
output_partial_ = input_partials_[0];
MS_EXCEPTION_IF_NULL(output_partial_.first);
MS_EXCEPTION_IF_NULL(input_partials_[0].first);
// Put other real parameter in partial.
for (const auto &device_tensor : input_device_tensors_) {
if (device_tensor != nullptr) {
output_partial_.second.emplace_back(device_tensor);
input_partials_[0].second.emplace_back(device_tensor);
}
}
}
void GatherActor::SendOutput(OpContext<DeviceTensor> *const context) {
// Send data with branch id.
const auto &iter = output_data_with_branch_id_arrows_.find(output_partial_.first);
const auto &iter = output_data_with_branch_id_arrows_.find(input_partials_[0].first);
if (iter != output_data_with_branch_id_arrows_.end()) {
for (const auto &data_with_branch_id_arrow : iter->second) {
ActorDispatcher::Send(data_with_branch_id_arrow, &EntranceActor::RunOpDataWithBranchID, output_partial_.second,
ActorDispatcher::Send(data_with_branch_id_arrow, &EntranceActor::RunOpDataWithBranchID, input_partials_[0].second,
output_branch_id_, context);
}
}


@ -27,25 +27,56 @@ StackActor::StackActor(const std::string &name, const std::vector<KernelWithInde
void StackActor::Init() {
ControlActor::Init();
for (const auto &formal_parameter : formal_parameters_) {
if (AnfAlgo::IsCallNode(formal_parameter.first)) {
break;
}
++input_parameter_data_num_;
}
input_datas_num_ = formal_parameters_.size() - input_parameter_data_num_;
if (input_parameter_data_num_ < device_tensor_store_keys_.size()) {
// The stack actor has 6 kinds of input:
// 1. Direct input data.
// 2. Direct input partial.
// 3. Weight.
// 4. Local tensor.
// 5. Call input data.
// 6. Call input partial.
input_datas_num_ = formal_parameters_.size() - input_parameter_data_num_ - input_parameter_partial_num_;
if (input_parameter_data_num_ < device_tensor_store_keys_.size() + local_device_tensors_.size()) {
MS_LOG(EXCEPTION) << "Invalid input parameter data num:" << input_parameter_data_num_
<< " device store num:" << device_tensor_store_keys_.size() << " for actor:" << GetAID();
<< " device store num:" << device_tensor_store_keys_.size() << " local device tensor num"
<< local_device_tensors_.size() << " for actor:" << GetAID();
}
// Fetch the total number of input partials.
int total_partials_num = 0;
for (const auto &formal_parameter : formal_parameters_) {
const auto &abstract = formal_parameter.first->abstract();
MS_EXCEPTION_IF_NULL(abstract);
const auto &real_abstract = FetchAbstractByIndex(abstract, formal_parameter.second);
MS_EXCEPTION_IF_NULL(real_abstract);
if (real_abstract->isa<abstract::AbstractFunction>()) {
total_partials_num++;
}
}
// Fetch the call input data num.
input_datas_num_ = formal_parameters_.size() - total_partials_num - input_parameter_data_num_;
// Fetch the call input partial num.
input_partials_num_ = total_partials_num - input_parameter_partial_num_;
// Exclude the device tensor store keys and local device tensors from the input parameter data num.
input_parameter_data_num_ -= (device_tensor_store_keys_.size() + local_device_tensors_.size());
// Check if the input num is valid.
if (input_parameter_data_num_ + input_parameter_partial_num_ + input_datas_num_ + input_partials_num_ +
device_tensor_store_keys_.size() + local_device_tensors_.size() !=
formal_parameters_.size()) {
MS_LOG(EXCEPTION) << "Invalid input num, input parameter data num:" << input_parameter_data_num_
<< " input parameter partial num:" << input_parameter_partial_num_
<< " input data num:" << input_datas_num_ << " input partial num:" << input_partials_num_
<< " device tensor store size:" << device_tensor_store_keys_.size()
<< " need total size:" << formal_parameters_.size() << " for actor:" << GetAID();
}
input_parameter_data_num_ -= device_tensor_store_keys_.size();
}
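
A worked example with hypothetical numbers (a standalone sketch, not MindSpore code) of the six-way partition that the consistency check at the end of StackActor::Init() enforces: direct input data, direct input partials, call input data, call input partials, device tensor store keys and local tensors must add up to the formal parameter count.

#include <cassert>
#include <cstddef>

int main() {
  std::size_t formal_parameters = 10;
  std::size_t input_parameter_data_num = 3;     // 1. Direct input data.
  std::size_t input_parameter_partial_num = 1;  // 2. Direct input partial.
  std::size_t device_tensor_store_keys = 2;     // 3. Weights.
  std::size_t local_device_tensors = 1;         // 4. Local tensors.
  std::size_t input_datas_num = 2;              // 5. Call input data.
  std::size_t input_partials_num = 1;           // 6. Call input partial.
  // The same check as the MS_LOG(EXCEPTION) branch above: the six parts must cover every formal parameter.
  assert(input_parameter_data_num + input_parameter_partial_num + input_datas_num + input_partials_num +
             device_tensor_store_keys + local_device_tensors ==
         formal_parameters);
  return 0;
}
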
void StackActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(input_data);
MS_EXCEPTION_IF_NULL(input_data->data_);
// The parameters from the inside of the subgraph need to be put into the stack.
if (IntToSize(input_data->index_) < input_parameter_data_num_ + device_tensor_store_keys_.size()) {
if (IntToSize(input_data->index_) < input_parameter_data_num_ + device_tensor_store_keys_.size() +
input_parameter_partial_num_ + local_device_tensors_.size()) {
FillStack(input_data, context);
} else {
// The outputs of call nodes are placed directly in the input data.
@ -56,6 +87,22 @@ void StackActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<Dev
}
}
void StackActor::RunOpPartial(FuncGraph *func_graph, std::vector<DeviceTensor *> input_data, size_t position,
OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(func_graph);
// The parameters from the inside of the subgraph need to be put into the stack.
if (IntToSize(position) < input_parameter_data_num_ + device_tensor_store_keys_.size() +
input_parameter_partial_num_ + local_device_tensors_.size()) {
input_parameter_partial_[context->sequential_num_][position].push(OpPartial(func_graph, input_data));
} else {
input_op_partials_[context->sequential_num_].emplace_back(position, OpPartial(func_graph, input_data));
}
if (CheckRunningCondition(context)) {
Run(context);
}
}
void StackActor::FillStack(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(input_data);
@ -122,6 +169,26 @@ bool StackActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) c
return false;
}
}
if (input_parameter_partial_num_ != 0) {
const auto &partial_iter = input_parameter_partial_.find(context->sequential_num_);
if (partial_iter == input_parameter_partial_.end()) {
return false;
}
if (partial_iter->second.size() != input_parameter_partial_num_) {
return false;
}
auto iter = input_branch_ids_.find(context->sequential_num_);
if (iter == input_branch_ids_.end() || iter->second.empty()) {
MS_LOG(ERROR) << "There is no branch id for actor:" << GetAID();
}
size_t branch_id_size = iter->second.size();
if (std::any_of(partial_iter->second.begin(), partial_iter->second.end(),
[branch_id_size](const auto &one_stack) { return one_stack.second.size() != branch_id_size; })) {
return false;
}
}
return true;
}
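
The partial branch of CheckRunningCondition above requires that every expected position has arrived and that each per-position stack holds exactly one partial per received branch id. A minimal standalone sketch of that std::any_of pattern, with an int standing in for OpPartial (illustrative types only, not the MindSpore API):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <map>
#include <stack>

int main() {
  std::map<std::size_t, std::stack<int>> input_parameter_partial;  // Input position -> stacked partials.
  input_parameter_partial[0].push(11);
  input_parameter_partial[1].push(21);
  input_parameter_partial[1].push(22);
  std::size_t branch_id_size = 1;  // Number of branch ids received for this step.
  bool not_ready = std::any_of(input_parameter_partial.begin(), input_parameter_partial.end(),
                               [branch_id_size](const auto &one_stack) {
                                 return one_stack.second.size() != branch_id_size;
                               });
  std::cout << std::boolalpha << "ready: " << !not_ready << std::endl;  // ready: false, position 1 has 2 entries.
  return 0;
}
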
@ -133,13 +200,29 @@ void StackActor::FetchInput(OpContext<DeviceTensor> *const context) {
MS_LOG(ERROR) << "Invalid input for actor:" << GetAID();
}
for (const auto &one_stack : data_iter->second) {
if (one_stack.first >= input_parameter_data_num_ + device_tensor_store_keys_.size()) {
if (one_stack.first >= input_parameter_data_num_ + device_tensor_store_keys_.size() +
local_device_tensors_.size() + input_parameter_partial_num_) {
MS_LOG(ERROR) << "Invalid input index:" << one_stack.first << " need:" << input_parameter_data_num_
<< " for actor:" << GetAID();
}
input_device_tensors_[one_stack.first] = one_stack.second.top();
}
}
if (input_parameter_partial_num_ != 0) {
const auto &partial_iter = input_parameter_partial_.find(context->sequential_num_);
if (partial_iter == input_parameter_partial_.end()) {
MS_LOG(ERROR) << "Invalid input for actor:" << GetAID();
}
for (const auto &one_stack : partial_iter->second) {
if (one_stack.first >= input_parameter_data_num_ + device_tensor_store_keys_.size() +
local_device_tensors_.size() + input_parameter_partial_num_) {
MS_LOG(ERROR) << "Invalid input index:" << one_stack.first << " need:" << input_parameter_partial_
<< " for actor:" << GetAID();
}
input_partials_[one_stack.first] = one_stack.second.top();
}
}
ControlActor::FetchInput(context);
}
@ -160,6 +243,20 @@ void StackActor::EraseInput(const OpContext<DeviceTensor> *const context) {
one_stack.second.pop();
}
}
if (input_parameter_partial_num_ != 0) {
const auto &partial_iter = input_parameter_partial_.find(context->sequential_num_);
if (partial_iter == input_parameter_partial_.end()) {
MS_LOG(ERROR) << "Invalid input for actor:" << GetAID();
}
for (auto &one_stack : partial_iter->second) {
if (one_stack.second.empty()) {
MS_LOG(ERROR) << "Input index:" << one_stack.first << " is null in actor:" << GetAID();
}
one_stack.second.pop();
}
}
}
} // namespace runtime
} // namespace mindspore


@ -39,9 +39,11 @@ class StackActor : public ControlActor {
~StackActor() override = default;
void Init() override;
// The input data of the stack actor needs to be pushed into the stack according to the input index, so it is
// implemented separately.
// The input data and partials of the stack actor need to be pushed into the stack according to the input index,
// so they are implemented separately.
void RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) override;
void RunOpPartial(FuncGraph *func_graph, std::vector<DeviceTensor *> input_data, size_t position,
OpContext<DeviceTensor> *const context) override;
protected:
void FetchInput(OpContext<DeviceTensor> *const context) override;
@ -56,12 +58,15 @@ class StackActor : public ControlActor {
// The device tensors created and stored by the stack.
std::vector<DeviceTensorPtr> created_device_tensors_;
// The input data records that the stack actor is copied from the input nodes and needs to be stored in the
// device tensor in the stack.
// Records of the input data and partials that the stack actor copies from the input nodes; they need to be
// stored as device tensors in the stack.
mindspore::HashMap<int, mindspore::HashMap<size_t, std::stack<DeviceTensor *>>> input_parameter_data_;
// Input parameter data num represents the number of actor's input come from funcgraph itself, these inputs
// will be ranked at the front of input.
mindspore::HashMap<int, mindspore::HashMap<size_t, std::stack<OpPartial>>> input_parameter_partial_;
// Input parameter num represents the number of the actor's inputs that come from the funcgraph itself; these
// inputs are ranked at the front of the inputs.
size_t input_parameter_data_num_{0};
size_t input_parameter_partial_num_{0};
// The backend parameter is used to save the backend node corresponding to the device tensor in the stack.
// When these device tensors are used as output, they need to be placed in the node of the result arrow,
// so these nodes need to be saved.


@ -54,7 +54,7 @@ void SwitchActor::FetchInput(OpContext<DeviceTensor> *const context) {
if (!output_partial_arrows_.empty()) {
auto func_graph = input_partials_[index + kSwitchCondPos].first;
MS_EXCEPTION_IF_NULL(func_graph);
output_partial_ = input_partials_[index + kSwitchCondPos];
input_partials_[0] = input_partials_[index + kSwitchCondPos];
}
for (auto &output_data : output_data_by_output_index_[kSwitchDefaultOutputNum - 1]) {


@ -261,7 +261,7 @@ size_t HostQueueDataSourceActor::FetchNodePosition(const AnfNodePtr &data_node)
MS_EXCEPTION_IF_NULL(data_node);
const auto &iter = data_node_position_map_.find(data_node);
if (iter == data_node_position_map_.end()) {
MS_LOG(EXCEPTION) << "Data node: " << data_node->fullname_with_scope() << " is not exist.";
MS_LOG(EXCEPTION) << "Data node: " << data_node->DebugString() << " is not exist.";
}
return iter->second;
}


@ -261,6 +261,65 @@ void CreateDeviceTensorForFrontNode(const KernelWithIndex &front_node_with_index
MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(node) << " addr:" << address;
AnfAlgo::SetOutputAddr(address, front_node_with_index.second, node.get());
}
// Check whether there is a recursive call between funcgraphs; if a calls b, b calls c, and c calls a, it is a recursive call.
bool IsRecursionFunction(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *checked_funcgraphs,
const std::unordered_map<FuncGraphPtr, std::set<FuncGraphPtr>> &func_graph_call_relation) {
MS_EXCEPTION_IF_NULL(func_graph);
if (checked_funcgraphs->find(func_graph) != checked_funcgraphs->end()) {
return true;
}
checked_funcgraphs->emplace(func_graph);
auto iter = func_graph_call_relation.find(func_graph);
if (iter == func_graph_call_relation.end()) {
return false;
}
for (const auto &called_func_graph : iter->second) {
MS_EXCEPTION_IF_NULL(called_func_graph);
if (IsRecursionFunction(called_func_graph, checked_funcgraphs, func_graph_call_relation)) {
return true;
}
}
return false;
}
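
A standalone sketch of the cycle check performed by IsRecursionFunction, with plain strings standing in for FuncGraphPtr and a hypothetical call relation a -> b -> c -> a. The checked set starts with the funcgraph that contains the call node, so reaching it again while walking the callees marks the call as recursive.

#include <iostream>
#include <map>
#include <set>
#include <string>

bool IsRecursion(const std::string &graph, std::set<std::string> *checked,
                 const std::map<std::string, std::set<std::string>> &call_relation) {
  if (!checked->insert(graph).second) {
    return true;  // Already checked: we walked back onto the caller chain.
  }
  auto iter = call_relation.find(graph);
  if (iter == call_relation.end()) {
    return false;
  }
  for (const auto &callee : iter->second) {
    if (IsRecursion(callee, checked, call_relation)) {
      return true;
    }
  }
  return false;
}

int main() {
  // a calls b, b calls c, c calls a, so a call to b made from inside a is recursive.
  std::map<std::string, std::set<std::string>> call_relation = {{"a", {"b"}}, {"b", {"c"}}, {"c", {"a"}}};
  std::set<std::string> checked = {"a"};  // The funcgraph that contains the call node.
  std::cout << std::boolalpha << IsRecursion("b", &checked, call_relation) << std::endl;  // true
  return 0;
}
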
// Fetch all inputs of node.
std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) {
// The node is divided into the following types:
// 1. depend and load.
const auto &node_with_index =
AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple});
auto real_node = node_with_index.first;
MS_EXCEPTION_IF_NULL(real_node);
std::vector<KernelWithIndex> results;
// 2. MakeTuple.
if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimMakeTuple)) {
const auto &make_tuple_cnode = real_node->cast<CNodePtr>();
const auto &make_tuple_inputs = make_tuple_cnode->inputs();
for (size_t i = kMakeTupleInputStartPos; i < make_tuple_inputs.size(); ++i) {
const auto &sub_results = FetchInputNodeByNode(make_tuple_inputs[i]);
results.insert(results.end(), sub_results.begin(), sub_results.end());
}
return results;
}
// 3. One output node.
const auto &abstract = real_node->abstract();
if (abstract == nullptr || (!abstract->isa<abstract::AbstractTuple>())) {
if (abstract == nullptr) {
MS_LOG(WARNING) << "Empty abstract for node:" << real_node->DebugString();
}
return {AnfAlgo::VisitKernelWithReturnType(real_node, 0)};
}
// 4. Abstract is Tuple.
size_t output_num = AnfAlgo::GetOutputNumByAbstract(abstract);
for (size_t i = 0; i < output_num; ++i) {
results.emplace_back(real_node, i);
}
return results;
}
} // namespace
bool HasAbstractRef(const AnfNodePtr &node) {
@ -285,6 +344,72 @@ KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, Kernel
return front_node_with_index;
}
std::vector<KernelWithIndex> FetchInputNodeByCNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
return {};
}
std::vector<KernelWithIndex> results;
// The first input of a normal cnode is the primitive of the node, and the real inputs start from the second input,
// but in control flow, the call node has no primitive, and its 0th input is a funcgraph or partial.
size_t input_start_pos = kCNodeInputStartPos;
if (AnfAlgo::IsCallNode(node)) {
input_start_pos = 0;
}
const auto &cnode = node->cast<CNodePtr>();
const auto inputs = cnode->inputs();
// In the inputs of the switch node, the first branch is the true branch and the second is the false branch.
// But in the switch actor, since the false value is 0, it corresponds to the first branch. Therefore, the two
// branch inputs of the switch node need to exchange positions, so they are handled separately.
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) {
if (inputs.size() != kSwitchInputNum) {
MS_LOG(EXCEPTION) << "Invalid switch node:" << node->DebugString();
}
results.emplace_back(AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchCondPos], 0));
results.emplace_back(AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchFalseBranchPos], 0));
results.emplace_back(AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchTrueBranchPos], 0));
return results;
}
for (size_t i = input_start_pos; i < inputs.size(); ++i) {
MS_EXCEPTION_IF_NULL(inputs[i]);
// skip monad node.
if (HasAbstractMonad(inputs[i])) {
continue;
}
const auto &sub_results = FetchInputNodeByNode(inputs[i]);
results.insert(results.end(), sub_results.begin(), sub_results.end());
}
return results;
}
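
A minimal illustration of why the switch handling above stores the false branch before the true branch: assuming the switch actor selects its output branch by converting the boolean condition to an index, false (0) must select the first stored branch and true (1) the second (hypothetical standalone code, not the actor implementation).

#include <array>
#include <cstddef>
#include <iostream>
#include <string>

int main() {
  // Branches stored in the swapped order used above: {false branch, true branch}.
  std::array<std::string, 2> branches = {"false_branch_partial", "true_branch_partial"};
  for (bool cond : {false, true}) {
    std::size_t index = static_cast<std::size_t>(cond);  // false -> 0, true -> 1.
    std::cout << std::boolalpha << cond << " selects " << branches[index] << std::endl;
  }
  return 0;
}
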
abstract::AbstractBasePtr FetchAbstractByIndex(const AbstractBasePtr &abstract, size_t index) {
MS_EXCEPTION_IF_NULL(abstract);
if (!abstract->isa<abstract::AbstractTuple>()) {
if (index != 0) {
MS_LOG(EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString();
}
return abstract;
}
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
const auto &sub_abstracts = tuple_abstract->elements();
size_t real_index = index;
for (const auto &sub_abstract : sub_abstracts) {
size_t tmp_index = AnfAlgo::GetOutputNumByAbstract(sub_abstract);
if (real_index >= tmp_index) {
real_index -= tmp_index;
continue;
}
return FetchAbstractByIndex(sub_abstract, real_index);
}
MS_LOG(EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString();
return nullptr;
}
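
A standalone analogue of FetchAbstractByIndex (simple structs instead of AbstractBasePtr, shown only to illustrate the index arithmetic): a flattened output index is resolved by walking the tuple elements and subtracting each element's flattened output count until the index falls inside one element.

#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <vector>

struct Abs {
  std::string leaf;                            // Non-empty for a leaf abstract.
  std::vector<std::shared_ptr<Abs>> elements;  // Non-empty for a tuple abstract.
};

std::size_t OutputNum(const Abs &abs) {
  if (abs.elements.empty()) {
    return 1;
  }
  std::size_t num = 0;
  for (const auto &element : abs.elements) {
    num += OutputNum(*element);
  }
  return num;
}

const Abs &FetchByIndex(const Abs &abs, std::size_t index) {
  if (abs.elements.empty()) {
    assert(index == 0);
    return abs;
  }
  for (const auto &element : abs.elements) {
    std::size_t element_num = OutputNum(*element);
    if (index >= element_num) {
      index -= element_num;  // The index belongs to a later element.
      continue;
    }
    return FetchByIndex(*element, index);
  }
  assert(false && "Invalid abstract index.");
  return abs;
}

int main() {
  // ((a, b), c): flattened index 0 -> a, 1 -> b, 2 -> c.
  auto leaf = [](const std::string &name) { return std::make_shared<Abs>(Abs{name, {}}); };
  Abs root{"", {std::make_shared<Abs>(Abs{"", {leaf("a"), leaf("b")}}), leaf("c")}};
  assert(FetchByIndex(root, 1).leaf == "b");
  assert(FetchByIndex(root, 2).leaf == "c");
  return 0;
}
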
void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts, const FuncGraphPtr &root_graph,
const FuncGraphToKernelGraph &func_graph_to_kernel_graphs) {
@ -309,12 +434,18 @@ void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, cons
ParseCallNodeToFuncGraph(control_nodes);
FetchNeedStackControlNode(control_nodes);
ParseUnRecursionCallNode();
FetchFrontNodeToKernelGraph(graphs);
ParseFormalToRealParameter(control_nodes);
ParseFrontToBackendParameter(graphs, device_contexts);
CreateDeviceTensorForRootGraphParameter(device_contexts[0]);
FetchFrontToBackendKernel(graphs, device_contexts);
FetchHostParameterToWeight();
@ -323,7 +454,7 @@ void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, cons
ParseDeviceContext(control_nodes, graphs, device_contexts, func_graph_to_kernel_graphs);
FetchFrontValueNode(device_contexts[0]);
FetchFrontValueNode(control_nodes, device_contexts[0]);
FetchControlNodeParameter(control_nodes);
@ -363,6 +494,11 @@ bool ControlNodeParser::IsRootGraphParameter(const AnfNodePtr &node) {
return find(root_graph_parameters_.begin(), root_graph_parameters_.end(), node) != root_graph_parameters_.end();
}
bool ControlNodeParser::IsRecursionCallNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
return find(unrecursion_call_nodes_.begin(), unrecursion_call_nodes_.end(), node) == unrecursion_call_nodes_.end();
}
void ControlNodeParser::ParseDeviceContext(const std::vector<AnfNodePtr> &control_nodes,
const std::vector<KernelGraphPtr> &kernel_graphs,
const std::vector<DeviceContext *> &device_contexts,
@ -604,6 +740,16 @@ void ControlNodeParser::ParseDeviceContextForReturnNode(const DeviceContext *def
void ControlNodeParser::FetchFrontNodeToKernelGraph(const std::vector<KernelGraphPtr> &graphs) {
for (const auto &graph : graphs) {
if (graph->execution_order().empty()) {
continue;
}
for (auto &kernel : graph->execution_order()) {
auto front_node = graph->GetFrontAnfByBackendAnf(kernel);
if (front_node != nullptr) {
front_node_to_kernel_graph_[front_node] = graph;
}
}
const auto &graph_outputs = graph->graph_output_map();
for (const auto &backend_to_front : graph_outputs) {
front_node_to_kernel_graph_[backend_to_front.second.first] = graph;
@ -654,7 +800,8 @@ FuncGraphPtr ControlNodeParser::FetchFuncGraphByKernelGraph(const KernelGraph *c
return nullptr;
}
void ControlNodeParser::FetchFrontValueNode(DeviceContext *default_context) {
void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &control_nodes,
DeviceContext *default_context) {
MS_EXCEPTION_IF_NULL(default_context);
for (const auto &formal_to_real_parameter : formal_to_real_parameters_) {
@ -688,6 +835,27 @@ void ControlNodeParser::FetchFrontValueNode(DeviceContext *default_context) {
front_value_nodes_.emplace(front_to_backend_parameters.first, device_context);
}
}
// Create device tensors for those value nodes that are directly returned by a return node.
for (const auto &control_node : control_nodes) {
if (!AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
continue;
}
auto input_with_indexs = FetchInputNodeByCNode(control_node);
auto iter = control_node_to_device_contexts_.find(control_node);
if (iter == control_node_to_device_contexts_.end() || iter->second.size() != input_with_indexs.size()) {
MS_LOG(EXCEPTION) << "Invalid device context for control node:" << control_node->DebugString();
}
for (size_t i = 0; i < input_with_indexs.size(); ++i) {
const auto &input_with_index = input_with_indexs[i];
if (input_with_index.first->isa<ValueNode>() && (!IsValueNode<FuncGraph>(input_with_index.first)) &&
front_value_nodes_.find({input_with_index, iter->second[i]}) == front_value_nodes_.end()) {
CreateDeviceTensorForFrontNode(input_with_index, iter->second[i]);
front_value_nodes_.emplace(input_with_index, iter->second[i]);
}
}
}
}
void ControlNodeParser::ParseFormalToRealParameter(const std::vector<AnfNodePtr> &control_nodes) {
@ -977,7 +1145,9 @@ void ControlNodeParser::FetchFrontToBackendKernel(const std::vector<KernelGraphP
const auto graph_output_map = graph->graph_output_map();
for (const auto &output_pair : graph_output_map) {
front_to_backend_kernels_[output_pair.second] = {output_pair.first, device_context};
if (output_pair.first.first->isa<CNode>()) {
front_to_backend_kernels_[output_pair.second] = {output_pair.first, device_context};
}
}
}
}
@ -1039,5 +1209,57 @@ void ControlNodeParser::ParseFirstControlNodeForFuncGraph(const std::vector<AnfN
}
}
}
void ControlNodeParser::ParseUnRecursionCallNode() {
std::unordered_map<FuncGraphPtr, std::set<FuncGraphPtr>> func_graph_call_relation;
// Collect the call relationship between funcgraphs.
for (const auto &call_node_to_func_graphs : call_node_to_func_graphs_) {
const auto &call_node = call_node_to_func_graphs.first;
MS_EXCEPTION_IF_NULL(call_node);
const auto &func_graph = call_node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
func_graph_call_relation[func_graph].insert(call_node_to_func_graphs.second.begin(),
call_node_to_func_graphs.second.end());
}
for (const auto &call_node_to_func_graphs : call_node_to_func_graphs_) {
const auto &call_node = call_node_to_func_graphs.first;
std::set<FuncGraphPtr> checked_func_graphs{call_node->func_graph()};
bool is_recursion_call_node = false;
if (std::any_of(call_node_to_func_graphs.second.begin(), call_node_to_func_graphs.second.end(),
[&is_recursion_call_node, &checked_func_graphs, &func_graph_call_relation](const auto &func_graph) {
return IsRecursionFunction(func_graph, &checked_func_graphs, func_graph_call_relation);
})) {
is_recursion_call_node = true;
}
if (!is_recursion_call_node && need_stack_control_nodes_.find(call_node) == need_stack_control_nodes_.end()) {
unrecursion_call_nodes_.emplace(call_node);
}
}
}
void ControlNodeParser::FetchNeedStackControlNode(const std::vector<AnfNodePtr> &control_nodes) {
for (const auto &control_node : control_nodes) {
MS_EXCEPTION_IF_NULL(control_node);
if (AnfAlgo::IsCallNode(control_node)) {
auto input_with_indexs = FetchInputNodeByCNode(control_node);
if (std::any_of(input_with_indexs.begin(), input_with_indexs.end(),
[](const auto &input_with_index) { return AnfAlgo::IsCallNode(input_with_index.first); })) {
need_stack_control_nodes_.emplace(control_node);
}
}
}
}
void ControlNodeParser::CreateDeviceTensorForRootGraphParameter(DeviceContext *default_context) {
MS_EXCEPTION_IF_NULL(default_context);
for (const auto &parameter : root_graph_parameters_) {
KernelWithIndex parameter_with_index(parameter, 0);
if (front_to_backend_parameters_.find(parameter_with_index) == front_to_backend_parameters_.end()) {
CreateDeviceTensorForFrontNode(parameter_with_index, default_context);
front_to_backend_parameters_[parameter_with_index].emplace(parameter, default_context);
}
}
}
} // namespace runtime
} // namespace mindspore


@ -24,6 +24,7 @@
#include <queue>
#include <map>
#include <utility>
#include <unordered_map>
#include <algorithm>
#include "utils/hash_map.h"
#include "runtime/hardware/device_context.h"
@ -76,7 +77,10 @@ bool HasAbstractRef(const AnfNodePtr &node);
// Get the front node corresponding to the backend node; if the front node is not a parameter node, return the
// corresponding cnode.
KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, KernelGraph *const graph);
// Get all the real inputs of the frontend node, skipping virtual nodes like maketuple and tuplegetitem.
std::vector<KernelWithIndex> FetchInputNodeByCNode(const AnfNodePtr &node);
// Fetch the sub abstract from the top abstract by the index.
abstract::AbstractBasePtr FetchAbstractByIndex(const AbstractBasePtr &abstract, size_t index);
// ControlNodeParser is used to parse control nodes, and get the edges between nodes.
class ControlNodeParser {
public:
@ -94,6 +98,7 @@ class ControlNodeParser {
// 2. In the kernel graph with call node input, the data arrow needs to be connected to the stack actor.
bool IsControlFlowDataArrow(const KernelGraphPtr &graph, const AnfNodePtr &node);
bool IsRootGraphParameter(const AnfNodePtr &node);
bool IsRecursionCallNode(const AnfNodePtr &node);
const std::vector<AnfNodePtr> &control_node_parameters() const { return control_node_parameters_; }
const FrontToBackendNodeWithContext &front_to_backend_parameters() const { return front_to_backend_parameters_; }
@ -117,7 +122,7 @@ class ControlNodeParser {
// value nodes will not enter the kernel graph, so these nodes need to be saved separately, and space is allocated for
// them separately during initialization.
// The interface is initialized by finding the backend node in the kernel graph that the front node finally sends to.
void FetchFrontValueNode(DeviceContext *default_context);
void FetchFrontValueNode(const std::vector<AnfNodePtr> &control_nodes, DeviceContext *default_context);
// Create branch id for all call node in the control flow.
void CreateBranchIDForCallNode(const std::vector<AnfNodePtr> &control_nodes);
@ -138,6 +143,8 @@ class ControlNodeParser {
const FormalToRealParameter &formal_to_real_parameters,
std::set<KernelWithIndex> *total_real_parameters,
std::set<AnfNodePtr> *invalid_real_parameter);
// Get all the call nodes without a recursive call relation.
void ParseUnRecursionCallNode();
// Parse the device context of the control node. In a heterogeneous scenario, different device contexts need to be
// copied between different device memories. The analysis steps:
@ -179,7 +186,12 @@ class ControlNodeParser {
void FetchAutoMonadNode(const std::vector<AnfNodePtr> &control_nodes);
// Fetch the formal parameter in root graph by parameters in subgraph.
AnfNodePtr FetchRootGraphFrontNodeBySubFrontNode(const AnfNodePtr &sub_front_node);
// Get the control nodes that need a stack actor to be added for them.
// When a control node has an input that is a call node, a stack actor needs to be added for it.
void FetchNeedStackControlNode(const std::vector<AnfNodePtr> &control_nodes);
// When the parameter is directly used as the condition of the switch, there will be no back-end node, and a device
// tensor needs to be created for it.
void CreateDeviceTensorForRootGraphParameter(DeviceContext *default_context);
// In control flow, funcgraph will be cut into multiple kernel graphs for execution, and this relationship is recorded
// in this map.
FuncGraphToKernelGraph func_graph_to_kernel_graphs_;
@ -220,6 +232,12 @@ class ControlNodeParser {
mindspore::HashMap<AnfNodePtr, AnfNodePtr> kernel_to_call_nodes_;
// Control nodes without a control node input in the topological sorting of funcgraph.
mindspore::HashMap<FuncGraphPtr, std::set<AnfNodePtr>> func_graph_to_first_control_nodes_;
// Call nodes without a recursive call: the funcgraphs called by the call node do not call back into the funcgraph
// where the call node belongs.
std::set<AnfNodePtr> unrecursion_call_nodes_;
// Control nodes that need a corresponding stack actor to be created; when there is a call node in the inputs of
// the control node, the stack actor is needed to collect these inputs.
std::set<AnfNodePtr> need_stack_control_nodes_;
// In heterogeneous scenario, each parameter has its own device context type, so the device context corresponding
// to the type needs to be parsed in advance so that it can add some copy operation in the scheduler.


@ -35,66 +35,41 @@ std::string GetActorName(const AnfNodePtr &node) {
}
}
// Get all the real input of the frontend node, skip the virtual node like maketuple, tuplegetitem.
std::vector<KernelWithIndex> FetchInputNodeByCNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
return {};
// Fetch the depend nodes according to the monad node.
void FetchRealDependNodeByAutoMonad(const AnfNodePtr &node, std::set<AnfNodePtr> *depend_nodes) {
// Find the real input node, including the monad node and the make tuple node.
const std::vector<PrimitivePtr> return_types = {prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad,
prim::kPrimMakeTuple};
const auto &node_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, false, return_types);
auto real_node = node_with_index.first;
MS_EXCEPTION_IF_NULL(real_node);
if (!real_node->isa<CNode>()) {
return;
}
std::vector<KernelWithIndex> results;
// The first input of normal cnode is the primitive of node, and the real input starts from the second input,
// but in control flow, the call node has no primitive, and the 0th input is funcgraph or partial.
size_t input_start_pos = kCNodeInputStartPos;
if (AnfAlgo::IsCallNode(node)) {
input_start_pos = 0;
}
const auto &cnode = node->cast<CNodePtr>();
const auto inputs = cnode->inputs();
const auto &real_cnode = real_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(real_cnode);
const auto &real_inputs = real_cnode->inputs();
// The first branch of the input of the switch node is the true branch, and the second is the false branch.
// But in switch actor, since the false value is 0, it corresponds to the first branch. Therefore, the input
// of the switch node needs to exchange the positions of the two branches. So deal separately.
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) {
if (inputs.size() != kSwitchInputNum) {
MS_LOG(EXCEPTION) << "Invalid switch node:" << node->DebugString();
// Make tuple node needs to be expanded.
if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimMakeTuple)) {
for (size_t i = 1; i < real_inputs.size(); ++i) {
MS_EXCEPTION_IF_NULL(real_inputs[i]);
FetchRealDependNodeByAutoMonad(real_inputs[i], depend_nodes);
}
results.emplace_back(AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchCondPos], 0));
results.emplace_back(AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchFalseBranchPos], 0));
results.emplace_back(AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchTrueBranchPos], 0));
return results;
return;
}
for (size_t i = input_start_pos; i < inputs.size(); ++i) {
MS_EXCEPTION_IF_NULL(inputs[i]);
// skip monad node.
if (HasAbstractMonad(inputs[i])) {
continue;
}
const auto &node_with_index =
AnfAlgo::VisitKernelWithReturnType(inputs[i], 0, false, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple});
MS_EXCEPTION_IF_NULL(node_with_index.first);
size_t output_num = AnfAlgo::GetOutputTensorNum(node_with_index.first);
for (size_t j = 0; j < output_num; ++j) {
if (AnfAlgo::CheckPrimitiveType(node_with_index.first, prim::kPrimMakeTuple)) {
const auto &make_tuple_cnode = node_with_index.first->cast<CNodePtr>();
const auto &make_tuple_inputs = make_tuple_cnode->inputs();
if (make_tuple_inputs.size() <= j + kMakeTupleInputStartPos) {
MS_LOG(EXCEPTION) << "Invalid input:" << j + kMakeTupleInputStartPos
<< " for make tuple node:" << make_tuple_cnode->DebugString();
}
results.emplace_back(AnfAlgo::VisitKernelWithReturnType(make_tuple_inputs[j + kMakeTupleInputStartPos], 0));
} else if (node_with_index.first->isa<ValueNode>()) {
// When the value node is a value tuple, the value node will have multiple outputs, which need to be directly
// put into the vector, and the output cannot be obtained through the VisitKernelWithReturnType interface.
results.emplace_back(node_with_index.first, j);
} else {
results.emplace_back(AnfAlgo::VisitKernelWithReturnType(node_with_index.first, j));
}
if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimDepend) ||
AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimLoad)) {
FetchRealDependNodeByAutoMonad(real_inputs[kDependAttachNodeIndex], depend_nodes);
} else if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimUpdateState)) {
for (size_t i = kUpdateStateRealInput; i < real_inputs.size(); ++i) {
FetchRealDependNodeByAutoMonad(real_inputs[i], depend_nodes);
}
} else {
depend_nodes->emplace(real_node);
}
return results;
}
} // namespace
@ -283,6 +258,8 @@ std::vector<StackActorPtr> ControlNodeScheduler::BuildStackActor(const GraphComp
// Create a corresponding stack actor for each kernel graph that has a call node as input.
for (const auto &graph_with_context : parser->call_input_kernel_graphs_) {
std::vector<KernelWithIndex> formal_parameters;
size_t input_parameter_data_num = 0;
const auto &graph = graph_with_context.first;
const auto &device_context = graph_with_context.second;
MS_EXCEPTION_IF_NULL(graph);
@ -303,10 +280,12 @@ std::vector<StackActorPtr> ControlNodeScheduler::BuildStackActor(const GraphComp
}
// If the input comes from inside funcgraph, put it at the front of the vector, otherwise put it at the end.
if (AnfAlgo::IsCallNode(front_node_with_index.first)) {
if (AnfAlgo::IsCallNode(front_node_with_index.first) &&
(parser->IsRecursionCallNode(front_node_with_index.first))) {
formal_parameters.emplace_back(front_node_with_index);
} else {
formal_parameters.insert(formal_parameters.begin(), front_node_with_index);
input_parameter_data_num++;
}
}
@ -315,11 +294,79 @@ std::vector<StackActorPtr> ControlNodeScheduler::BuildStackActor(const GraphComp
stack_actors.emplace_back(stack_actor);
stack_actor->device_contexts_.insert(stack_actor->device_contexts_.begin(), formal_parameters.size(),
device_context);
stack_actor->input_parameter_data_num_ = input_parameter_data_num;
InsertActor(stack_actor.get());
}
// Create stack actors for control nodes.
BuildStackActorForControlNode(graph_compiler_info, &stack_actors);
return stack_actors;
}
void ControlNodeScheduler::BuildStackActorForControlNode(const GraphCompilerInfo &graph_compiler_info,
std::vector<StackActorPtr> *stack_actors) {
const auto &parser = graph_compiler_info.control_node_parser_;
MS_EXCEPTION_IF_NULL(parser);
for (const auto &need_stack_control_node : parser->need_stack_control_nodes_) {
MS_EXCEPTION_IF_NULL(need_stack_control_node);
if (!AnfAlgo::IsCallNode(need_stack_control_node)) {
continue;
}
std::vector<KernelWithIndex> formal_parameters;
std::vector<const DeviceContext *> device_contexts;
size_t input_parameter_data_num = 0;
size_t input_parameter_partials_num = 0;
// Fetch the control actor of control node.
auto gather_actor_name = GetActorName(need_stack_control_node);
auto actor = FetchActor(gather_actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto gather_actor = dynamic_cast<GatherActor *>(actor);
MS_EXCEPTION_IF_NULL(gather_actor);
if (gather_actor->formal_parameters_.size() > gather_actor->device_contexts_.size()) {
MS_LOG(EXCEPTION) << "Invalid device context size:" << gather_actor->device_contexts_.size()
<< " and formal parameter size:" << gather_actor->formal_parameters_.size()
<< " for actor:" << gather_actor->GetAID();
}
// Collect formal parameters and device contexts, skipping the value nodes.
for (size_t i = 0; i < gather_actor->formal_parameters_.size(); ++i) {
const auto &parameter = gather_actor->formal_parameters_[i];
auto device_context = gather_actor->device_contexts_[i];
if (AnfAlgo::IsCallNode(parameter.first) && (parser->IsRecursionCallNode(parameter.first))) {
formal_parameters.emplace_back(parameter);
device_contexts.emplace_back(device_context);
} else if (parameter.first->isa<ValueNode>()) {
continue;
} else {
formal_parameters.insert(formal_parameters.begin(), parameter);
device_contexts.insert(device_contexts.begin(), device_context);
const auto &abstract = parameter.first->abstract();
MS_EXCEPTION_IF_NULL(abstract);
const auto &real_abstract = FetchAbstractByIndex(abstract, parameter.second);
MS_EXCEPTION_IF_NULL(real_abstract);
if (real_abstract->isa<abstract::AbstractFunction>()) {
input_parameter_partials_num++;
} else {
input_parameter_data_num++;
}
}
}
// Create stack actor.
const auto &stack_actor_name = GetActorName(need_stack_control_node) + kStackActorNameSuffix;
const auto &stack_actor = std::make_shared<StackActor>(stack_actor_name, formal_parameters);
stack_actor->device_contexts_ = device_contexts;
stack_actor->input_parameter_data_num_ = input_parameter_data_num;
stack_actor->input_parameter_partial_num_ = input_parameter_partials_num;
InsertActor(stack_actor.get());
stack_actors->emplace_back(stack_actor);
}
}
void ControlNodeScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) {
MS_EXCEPTION_IF_NULL(actor_set);
MS_EXCEPTION_IF_NULL(actor_set->control_actors_);
@ -362,11 +409,18 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr
}
for (auto &gather_actor : control_actor_set->gather_actors_) {
for (size_t i = 0; i < gather_actor->formal_parameters_.size(); ++i) {
LinkArrowbyFormalParameter(gather_actor.get(), gather_actor->formal_parameters_[i], {gather_actor->node_, i},
parser);
MS_EXCEPTION_IF_NULL(gather_actor->node_);
if (parser->need_stack_control_nodes_.find(gather_actor->node_) == parser->need_stack_control_nodes_.end()) {
for (size_t i = 0; i < gather_actor->formal_parameters_.size(); ++i) {
LinkArrowbyFormalParameter(gather_actor.get(), gather_actor->formal_parameters_[i], {gather_actor->node_, i},
parser);
}
} else {
// If the control actor has a corresponding stack actor, the input should be linked to the stack actor.
LinkArrowFromStackActor(gather_actor.get());
}
}
for (auto &entrance_actor : control_actor_set->entrance_actors_) {
for (const auto &call_node : entrance_actor->call_nodes_) {
LinkArrowbyFormalParameter(entrance_actor.get(), call_node, {entrance_actor->node_, 0}, parser);
@ -387,6 +441,38 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr
}
}
void ControlNodeScheduler::LinkArrowFromStackActor(ControlActor *to_actor) {
MS_EXCEPTION_IF_NULL(to_actor->node_);
auto stack_actor_name = GetActorName(to_actor->node_) + kStackActorNameSuffix;
auto actor = FetchActor(stack_actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto stack_actor = dynamic_cast<StackActor *>(actor);
MS_EXCEPTION_IF_NULL(stack_actor);
for (size_t to_index = 0; to_index < to_actor->formal_parameters_.size(); ++to_index) {
const auto &formal_parameter = to_actor->formal_parameters_[to_index];
const auto &from_node = formal_parameter.first;
if (from_node->isa<ValueNode>()) {
LinkArrowByValueNode(from_node, to_actor, formal_parameter.second, to_index);
continue;
}
// Fetch the arrow type of input.
size_t from_index = stack_actor->FetchNodePosition(formal_parameter);
const auto &abstract = formal_parameter.first->abstract();
MS_EXCEPTION_IF_NULL(abstract);
const auto &real_abstract = FetchAbstractByIndex(abstract, formal_parameter.second);
MS_EXCEPTION_IF_NULL(real_abstract);
// Link arrow according to abstract.
if (real_abstract->isa<abstract::AbstractFunction>()) {
LinkPartialArrow(stack_actor, to_actor, from_index, to_index);
} else {
LinkDataArrow(stack_actor, to_actor, from_index, to_index);
}
}
}
void ControlNodeScheduler::LinkArrowbyFormalParameter(ControlActor *const to_actor,
const KernelWithIndex &from_node_with_index,
const KernelWithIndex &to_node_with_index,
@ -394,20 +480,7 @@ void ControlNodeScheduler::LinkArrowbyFormalParameter(ControlActor *const to_act
const auto &from_node = from_node_with_index.first;
MS_EXCEPTION_IF_NULL(from_node);
if (from_node->isa<ValueNode>()) {
if (IsValueNode<FuncGraph>(from_node)) {
// Link local partial.
const auto &func_graph = GetValueNode<FuncGraphPtr>(from_node);
to_actor->local_partials_[to_node_with_index.second] = OpPartial(func_graph.get(), {});
} else {
// Link device store value node.
if (!AnfAlgo::OutputAddrExist(from_node, from_node_with_index.second)) {
MS_LOG(EXCEPTION) << "Invalid output address index:" << from_node_with_index.second
<< " for value node:" << from_node->DebugString();
}
to_actor->local_device_tensors_[to_node_with_index.second] =
AnfAlgo::GetMutableOutputAddr(from_node, from_node_with_index.second, false).get();
to_actor->local_device_tensors_[to_node_with_index.second]->SetNodeIndex(from_node, from_node_with_index.second);
}
LinkArrowByValueNode(from_node, to_actor, from_node_with_index.second, to_node_with_index.second);
} else if (from_node->isa<Parameter>()) {
LinkArrowByParameter(from_node, to_actor, from_node_with_index, to_node_with_index, parser);
} else if (AnfAlgo::IsCallNode(from_node)) {
@ -436,6 +509,26 @@ void ControlNodeScheduler::LinkArrowbyFormalParameter(ControlActor *const to_act
}
}
void ControlNodeScheduler::LinkArrowByValueNode(const AnfNodePtr &value_node, ControlActor *const to_actor,
size_t from_index, size_t to_index) {
MS_EXCEPTION_IF_NULL(value_node);
MS_EXCEPTION_IF_NULL(to_actor);
if (IsValueNode<FuncGraph>(value_node)) {
// Link local partial.
const auto &func_graph = GetValueNode<FuncGraphPtr>(value_node);
to_actor->local_partials_[to_index] = OpPartial(func_graph.get(), {});
} else {
// Link device store value node.
if (!AnfAlgo::OutputAddrExist(value_node, from_index)) {
MS_LOG(EXCEPTION) << "Invalid output address index:" << from_index
<< " for value node:" << value_node->DebugString();
}
to_actor->local_device_tensors_[to_index] = AnfAlgo::GetMutableOutputAddr(value_node, from_index, false).get();
to_actor->local_device_tensors_[to_index]->SetNodeIndex(value_node, from_index);
}
}
void ControlNodeScheduler::LinkArrowByParameter(const AnfNodePtr &parameter, ControlActor *const to_actor,
const KernelWithIndex &from_node_with_index,
const KernelWithIndex &to_node_with_index,
@ -467,6 +560,11 @@ void ControlNodeScheduler::LinkArrowByCallNode(const AnfNodePtr &call_node, Cont
if (to_actor->type_ != KernelTransformType::kEntranceActor) {
// Link arrow from exit actor to control actor.
const auto &abstract = call_node->abstract();
MS_EXCEPTION_IF_NULL(abstract);
const auto &real_abstract = FetchAbstractByIndex(abstract, from_node_with_index.second);
MS_EXCEPTION_IF_NULL(real_abstract);
const auto &func_graphs = AnfAlgo::GetFuncGraphbyCallNode(from_node);
for (const auto &func_graph : func_graphs) {
MS_EXCEPTION_IF_NULL(func_graph);
@ -475,10 +573,19 @@ void ControlNodeScheduler::LinkArrowByCallNode(const AnfNodePtr &call_node, Cont
MS_EXCEPTION_IF_NULL(actor);
auto exit_actor = dynamic_cast<ExitActor *>(actor);
size_t branch_id = parser->FetchBranchIDByCallNode(from_node);
LinkDataArrowForExitActor(exit_actor, to_actor, from_node_with_index.second, to_node_with_index.second,
branch_id);
if (real_abstract->isa<abstract::AbstractFunction>()) {
LinkPartialArrowForExitActor(exit_actor, to_actor, from_node_with_index.second, to_node_with_index.second,
branch_id);
} else {
LinkDataArrowForExitActor(exit_actor, to_actor, from_node_with_index.second, to_node_with_index.second,
branch_id);
}
}
if (abstract->isa<abstract::AbstractFunction>()) {
to_actor->input_partials_num_++;
} else {
to_actor->input_datas_num_++;
}
to_actor->input_datas_num_++;
} else {
// Link arrow from gather actor to entrance actor.
const auto &actor_name = GetActorName(from_node);
@ -604,6 +711,34 @@ void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor
LinkControlArrow(entrance_actor, control_actor);
}
}
// Link auto monad control arrow for control actor.
std::vector<ControlActor *> control_actors;
(void)std::transform(control_actor_set->switch_actors_.begin(), control_actor_set->switch_actors_.end(),
std::back_inserter(control_actors), [](auto &switch_actor) { return switch_actor.get(); });
(void)std::transform(control_actor_set->gather_actors_.begin(), control_actor_set->gather_actors_.end(),
std::back_inserter(control_actors), [](auto &gather_actor) { return gather_actor.get(); });
(void)std::transform(control_actor_set->exit_actors_.begin(), control_actor_set->exit_actors_.end(),
std::back_inserter(control_actors), [](auto &exit_actor) { return exit_actor.get(); });
for (auto control_actor : control_actors) {
MS_EXCEPTION_IF_NULL(control_actor);
const auto &node = control_actor->node_;
if (node == nullptr) {
continue;
}
const auto &cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
const auto &inputs = cnode->inputs();
for (const auto &input : inputs) {
MS_EXCEPTION_IF_NULL(input);
if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimUpdateState) ||
AnfAlgo::CheckPrimitiveType(input, prim::kPrimDepend) ||
AnfAlgo::CheckPrimitiveType(input, prim::kPrimLoad)) {
LinkControlArrowByAutoMonad(control_actor, input, parser);
}
}
}
}
void ControlNodeScheduler::LinkControlArrowForKernelActor(ActorSet *const actor_set,
@ -613,6 +748,13 @@ void ControlNodeScheduler::LinkControlArrowForKernelActor(ActorSet *const actor_
// Link control arrow from entrance actors or stack actors to no input kernel actors.
for (const auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) {
// In control flow, when an input of the kernel actor is a parameter, this input needs to be linked to the
// control actor, so the no-input kernel actors collected by the graph scheduler may also contain this actor,
// and it needs to be skipped here.
if ((no_input_kernel_actor->input_datas_num_ != 0) || (no_input_kernel_actor->input_controls_num_ != 0)) {
continue;
}
KernelGraphPtr kernel_graph = nullptr;
if (no_input_kernel_actor->type_ == KernelTransformType::kSuperKernelActor) {
const auto &super_kernel_actor = dynamic_cast<SuperKernelActor *>(no_input_kernel_actor.get());
@ -665,6 +807,46 @@ void ControlNodeScheduler::LinkControlArrowForKernelActor(ActorSet *const actor_
}
}
void ControlNodeScheduler::LinkControlArrowByAutoMonad(ControlActor *to_actor, const AnfNodePtr &from_node,
const ControlNodeParserPtr &parser) {
MS_EXCEPTION_IF_NULL(to_actor);
MS_EXCEPTION_IF_NULL(from_node);
std::set<AnfNodePtr> depend_nodes;
FetchRealDependNodeByAutoMonad(from_node, &depend_nodes);
for (const auto &depend_node : depend_nodes) {
MS_EXCEPTION_IF_NULL(depend_node);
auto from_actor = FetchActor(depend_node->DebugString());
if (AnfAlgo::IsCallNode(depend_node)) {
int branch_id = parser->FetchBranchIDByCallNode(depend_node);
const auto &func_graphs = parser->FetchFuncGraphbyCallNode(depend_node);
if (func_graphs.empty()) {
MS_LOG(EXCEPTION) << "Failed to get funcgraph by call node:" << depend_node->DebugString();
}
for (const auto func_graph : func_graphs) {
auto exit_actor_name = func_graph->ToString() + kExitActorNameSuffix;
auto actor = FetchActor(exit_actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto exit_actor = dynamic_cast<ExitActor *>(actor);
MS_EXCEPTION_IF_NULL(exit_actor);
LinkControlArrowForExitActor(exit_actor, to_actor, branch_id);
}
to_actor->input_controls_num_ -= (func_graphs.size() - 1);
} else if (from_actor != nullptr) {
LinkControlArrow(from_actor, to_actor);
} else {
auto graph = parser->FetchKernelGraphByFrontNode(depend_node);
if (graph == nullptr) {
MS_LOG(EXCEPTION) << "Failed to find actor for node:" << depend_node->DebugString();
}
from_actor = FetchActor(graph->ToString() + kExitActorNameSuffix);
MS_EXCEPTION_IF_NULL(from_actor);
LinkControlArrow(from_actor, to_actor);
}
}
}
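
The input_controls_num_ decrement above suggests the following accounting (an assumption, since LinkControlArrowForExitActor is not shown in this hunk): a call node that may enter k funcgraphs gets one control arrow per exit actor, but at run time only the exit actor of the branch actually taken fires, so the expected control count is reduced by k - 1 and a single arrival is enough.

#include <cassert>
#include <cstddef>

int main() {
  std::size_t input_controls_num = 0;
  std::size_t func_graph_num = 3;              // A depended-on call node that may enter 3 funcgraphs.
  input_controls_num += func_graph_num;        // One LinkControlArrowForExitActor per possible exit actor (assumed).
  input_controls_num -= (func_graph_num - 1);  // Only one branch runs, so expect exactly one control signal.
  assert(input_controls_num == 1);
  return 0;
}
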
void ControlNodeScheduler::LinkBranchIDArrowForControlActor(ControlActorSet *const control_actor_set) {
MS_EXCEPTION_IF_NULL(control_actor_set);
@ -886,6 +1068,15 @@ void ControlNodeScheduler::LinkDataArrowForExitActor(ExitActor *const exit_actor
(void)to_actor->input_data_arrow_aids_.emplace_back(exit_actor->GetAID());
}
void ControlNodeScheduler::LinkPartialArrowForExitActor(ExitActor *const exit_actor, ControlActor *const to_actor,
size_t from_index, size_t to_index, int branch_id) {
MS_EXCEPTION_IF_NULL(exit_actor);
MS_EXCEPTION_IF_NULL(to_actor);
auto partial_arrow = std::make_shared<DataArrow>(from_index, to_actor->GetAID(), to_index);
(void)exit_actor->output_branch_partial_arrows_[branch_id].emplace_back(partial_arrow);
(void)to_actor->input_partial_arrow_aids_.emplace_back(exit_actor->GetAID());
}
void ControlNodeScheduler::LinkControlArrowForExitActor(ExitActor *from_actor, AbstractActor *to_actor, int branch_id) {
MS_EXCEPTION_IF_NULL(from_actor);
MS_EXCEPTION_IF_NULL(to_actor);


@ -50,7 +50,8 @@ class ControlNodeScheduler {
std::vector<EntranceActorPtr> BuildEntranceActor(const GraphCompilerInfo &graph_compiler_info);
std::vector<ExitActorPtr> BuildExitActor(const GraphCompilerInfo &graph_compiler_info);
std::vector<StackActorPtr> BuildStackActor(const GraphCompilerInfo &graph_compiler_info);
void BuildStackActorForControlNode(const GraphCompilerInfo &graph_compiler_info,
std::vector<StackActorPtr> *stack_actors);
// Interface to link control actors.
void LinkControlArrowForControlActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info);
void LinkBranchIDArrowForControlActor(ControlActorSet *const control_actor_set);
@ -67,6 +68,10 @@ class ControlNodeScheduler {
void LinkArrowByParameter(const AnfNodePtr &parameter, ControlActor *const to_actor,
const KernelWithIndex &from_node_with_index, const KernelWithIndex &to_node_with_index,
const ControlNodeParserPtr &parser);
void LinkArrowByValueNode(const AnfNodePtr &value_node, ControlActor *const to_actor, size_t from_index,
size_t to_index);
// Link arrow from stack actor to control actor.
void LinkArrowFromStackActor(ControlActor *to_actor);
// Link data arrow between control actor and actor in frame, including kernel actor, output actor, data source actor.
void LinkDataArrowForKernelActor(const GraphCompilerInfo &graph_compiler_info);
@ -75,6 +80,9 @@ class ControlNodeScheduler {
void LinkDataArrowForOutputActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info);
void LinkDataArrowForHostDSActor(const GraphCompilerInfo &graph_compiler_info);
void LinkControlArrowForKernelActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info);
void LinkControlArrowByAutoMonad(ControlActor *to_actor, const AnfNodePtr &from_node,
const ControlNodeParserPtr &parser);
// Interface tool to link arrows between actors.
void LinkControlArrow(AbstractActor *from_actor, AbstractActor *to_actor);
// The data arrow with branch id only exists from the gather actor to the entrance actor.
@ -90,6 +98,8 @@ class ControlNodeScheduler {
void LinkControlArrowForExitActor(ExitActor *from_actor, AbstractActor *to_actor, int branch_id);
void LinkDataArrowForExitActor(ExitActor *const exit_actor, AbstractActor *const to_actor, size_t from_index,
size_t to_index, int branch_id);
void LinkPartialArrowForExitActor(ExitActor *const exit_actor, ControlActor *const to_actor, size_t from_index,
size_t to_index, int branch_id);
bool IsNoInputActor(const ControlActor *control_actor);
};
} // namespace runtime