!19044 Fetch backend node for switch actor.

Merge pull request !19044 from gaoyong10/new_runtime14
This commit is contained in:
i-robot 2021-06-29 14:17:11 +00:00 committed by Gitee
commit ec7a1da2c9
5 changed files with 204 additions and 85 deletions

View File

@ -199,7 +199,8 @@ void SwitchActor::AddInput(const KernelWithIndex node_with_index, const size_t b
const auto &node = node_with_index.first;
// Add weight and value node.
if ((AnfAlgo::CheckPrimitiveType(node_, prim::kPrimReturn) && HasAbstractRef(node)) || node->isa<ValueNode>()) {
if ((AnfAlgo::CheckPrimitiveType(node_, prim::kPrimReturn) && node->isa<Parameter>() && HasAbstractRef(node)) ||
node->isa<ValueNode>()) {
const auto iter = find(input_nodes_.begin(), input_nodes_.end(), node_with_index);
if (iter != input_nodes_.end()) {
branch_inputs_pos_[branch].push_back(iter - input_nodes_.begin());
@ -212,7 +213,7 @@ void SwitchActor::AddInput(const KernelWithIndex node_with_index, const size_t b
}
// Output of updatestate node is U, need to be skipped.
if (HasAbstractRef(node)) {
if (node->isa<Parameter>() && HasAbstractRef(node)) {
return;
}
@ -447,7 +448,7 @@ void SwitchActor::SendMemoryFreeReq(OpContext<DeviceTensor> *context) {
void SwitchActor::FetchInputNode(const ControlNodeParserPtr &parser) {
for (size_t i = 0; i < input_nodes_.size(); ++i) {
const auto &input_node = input_nodes_[i].first;
if (!HasAbstractRef(input_node)) {
if (!(input_node->isa<Parameter>() && HasAbstractRef(input_node))) {
backend_parameters_[i] = parser->FetchBackendInputNodeByFrontNode(input_node);
continue;
}
@ -456,7 +457,7 @@ void SwitchActor::FetchInputNode(const ControlNodeParserPtr &parser) {
if (backend_weight == nullptr) {
MS_LOG(EXCEPTION) << "Cannot find backend node for weight node:" << AnfAlgo::GetNodeDebugString(input_node);
}
backend_parameters_[i].push_back({backend_weight, 0});
backend_parameters_[i].insert({backend_weight, 0});
}
}
} // namespace runtime

View File

@ -19,6 +19,7 @@
#include <vector>
#include <string>
#include <set>
#include <memory>
#include <utility>
#include <stack>
@ -138,7 +139,7 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
// When the output is a value node from switch actor, the actor needs to send the anfnode to the output actor,
// so all the nodes that may send the device tensor to switch actor are recorded.
std::vector<std::vector<KernelWithIndex>> backend_parameters_;
std::vector<std::set<KernelWithIndex>> backend_parameters_;
std::vector<std::vector<AnfNodePtr>> branch_total_inputs_;
std::vector<FuncGraphPtr> branch_func_graph_;

View File

@ -17,12 +17,14 @@
#include "runtime/framework/control_node_parser.h"
#include "runtime/framework/actor/switch_actor.h"
#include "runtime/framework/actor/gather_actor.h"
#include "abstract/utils.h"
#include "ir/tensor.h"
namespace mindspore {
namespace runtime {
namespace {
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
// Fetch all the weight parameters related to node. It runs like this:
// if we have a map like {{a, {b, c}}, {b, {d, e}}}, finally we will get {{a, {b, c, d, e}}, {b, {d, e}}}.
void FetchWeightbyHostParameter(const AnfNodePtr &node, std::vector<AnfNodePtr> *dest_nodes,
@ -167,6 +169,30 @@ void CreateDeviceTensorForValueNode(const AnfNodePtr &front_node, const AnfNodeP
AnfAlgo::SetOutputAddr(address, 0, front_node.get());
}
// Create a device tensor for a front parameter.
// When the condition input of a switch or switchlayer, or the output of a subgraph, is a parameter, there is no
// corresponding backend node for this parameter, so a device tensor needs to be created for it.
//
// If the node has no kernel info yet, a kernel info is attached whose build info declares one output in the default
// format with the inferred data type of output 0. A device address sized by the memory size of output 0 is then
// created through the device context and set as output address 0 of the node.
//
// Raises an exception if node or device_context is null, or if the device address cannot be created.
void CreateDeviceTensorForFrontParameter(const AnfNodePtr &node, const DeviceContext *device_context) {
  // Both arguments are dereferenced below, so reject null pointers up front.
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(device_context);

  TypeId type_id = AnfAlgo::GetOutputInferDataType(node, 0);
  if (node->kernel_info() == nullptr) {
    // The kernel info is attached before GetOutputTensorMemSize is called on the node below.
    auto kernel_info = std::make_shared<device::KernelInfo>();
    auto builder = std::make_shared<KernelBuildInfoBuilder>();
    builder->SetOutputsFormat({kOpFormat_DEFAULT});
    builder->SetOutputsDeviceType({type_id});
    kernel_info->set_select_kernel_build_info(builder->Build());
    node->set_kernel_info(kernel_info);
  }
  size_t size = AnfAlgo::GetOutputTensorMemSize(node, 0);

  // Create device tensor. NOTE(review): nullptr is passed as the first argument, presumably meaning no device memory
  // is bound at this point - confirm against DeviceContext::CreateDeviceAddress.
  device::DeviceAddressPtr address = device_context->CreateDeviceAddress(nullptr, size, kOpFormat_DEFAULT, type_id);
  MS_EXCEPTION_IF_NULL(address);
  AnfAlgo::SetOutputAddr(address, 0, node.get());
}
// Find the corresponding backend parameter for the front_node. If the front_node does not have the corresponding
// backend parameter, then recursively find the backend parameters of other front parameters corresponding to the
// front_node.
@ -547,41 +573,60 @@ FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node) {
void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts, const FuncGraphPtr &root_graph) {
if (graphs.size() != device_contexts.size()) {
MS_LOG(EXCEPTION) << "Graph num is not equal to device context, graph:" << graphs.size()
<< " device context num:" << device_contexts.size();
}
if (graphs.empty()) {
return;
}
root_func_graph_ = root_graph;
root_graph_parameters_ = root_graph->parameters();
CreateBranchIDForFuncGraph(control_nodes);
FetchFrontToBackendParameter(graphs, device_contexts, control_nodes);
RealToFormalNode real_to_formal_front_parameters;
FetchFrontToFrontParameter(control_nodes, &real_to_formal_front_parameters);
RealToFormalNode formal_to_real_front_parameters;
for (const auto real_to_formal_front_parameter : real_to_formal_front_parameters) {
for (const auto formal_parameter : real_to_formal_front_parameter.second) {
formal_to_real_front_parameters[formal_parameter].emplace_back(real_to_formal_front_parameter.first);
}
}
FetchFrontToBackendParameter(graphs, device_contexts, control_nodes, real_to_formal_front_parameters,
formal_to_real_front_parameters);
FetchFuncGraphToParameter(control_nodes);
FetchHostParameterToWeight(control_nodes);
FetchHostParameterToWeight(real_to_formal_front_parameters);
FetchFrontValueNode(control_nodes, graphs, device_contexts);
FetchFrontToBackendKernel(graphs, device_contexts);
control_node_parameters_ = FetchControlNodeParameter(control_nodes);
control_node_parameters_ = FetchControlNodeParameter(control_nodes, device_contexts[0]);
FetchFuncGraphCallNum(control_nodes);
FetchCallInputKernelGraph(graphs, device_contexts);
FetchBackendInputNode();
front_output_nodes_ = FetchAllBranchOutputs(root_graph);
FetchBackendInputNode(graphs, device_contexts, real_to_formal_front_parameters, formal_to_real_front_parameters);
}
// Return a copy of the backend inputs recorded in formal_to_real_parameters_ for the given formal parameter.
// NOTE(review): the lookup goes through operator[]; if formal_to_real_parameters_ is a standard map, a parameter that
// has no entry yet gets an empty entry created as a side effect - confirm this is intended.
std::vector<KernelWithIndex> ControlNodeParser::GetBackendInputByParameter(const AnfNodePtr &parameter) {
  const auto &backend_inputs = formal_to_real_parameters_[parameter];
  return backend_inputs;
}
std::vector<KernelWithIndex> ControlNodeParser::FetchBackendInputNodeByFrontNode(const AnfNodePtr &front_output) {
std::set<KernelWithIndex> ControlNodeParser::FetchBackendInputNodeByFrontNode(const AnfNodePtr &front_output) {
std::set<AnfNodePtr> call_nodes;
std::set<AnfNodePtr> switch_nodes;
return FetchBackendOutputByFrontOutput(front_output, &call_nodes, &switch_nodes);
std::set<KernelWithIndex> results;
FetchBackendOutputByFrontOutput(front_output, &call_nodes, &switch_nodes, &results);
return results;
}
int ControlNodeParser::GetBranchIDByFuncGraph(const FuncGraphPtr &func_graph) {
@ -762,7 +807,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
}
}
void ControlNodeParser::FetchFrontToFrontParameterMap(
void ControlNodeParser::FetchFrontToFrontParameter(
const std::vector<AnfNodePtr> &control_nodes,
std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> *front_to_front_parameter) {
// Function used to collect the input of call node.
@ -844,7 +889,8 @@ void ControlNodeParser::FetchFrontToFrontParameterMap(
}
}
std::vector<AnfNodePtr> ControlNodeParser::FetchControlNodeParameter(const std::vector<AnfNodePtr> &control_nodes) {
std::vector<AnfNodePtr> ControlNodeParser::FetchControlNodeParameter(const std::vector<AnfNodePtr> &control_nodes,
DeviceContext *device_context) {
std::vector<AnfNodePtr> parameters;
for (const auto &control_node : control_nodes) {
@ -864,6 +910,29 @@ std::vector<AnfNodePtr> ControlNodeParser::FetchControlNodeParameter(const std::
parameters.emplace_back(inputs[i]);
}
}
} else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch)) {
if (inputs.size() != kSwitchInputNum) {
MS_LOG(EXCEPTION) << "Invalid switch node:" << AnfAlgo::GetNodeDebugString(control_node);
}
if (inputs[kSwitchCondPos]->isa<Parameter>()) {
parameters.emplace_back(inputs[kSwitchCondPos]);
}
} else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
if (inputs.size() != kSwitchLayerInputNum) {
MS_LOG(EXCEPTION) << "Invalid switch node:" << AnfAlgo::GetNodeDebugString(control_node);
}
if (inputs[kSwitchLayerCondPos]->isa<Parameter>()) {
parameters.emplace_back(inputs[kSwitchLayerCondPos]);
}
}
}
for (const auto &parameter : parameters) {
auto backend_iter = front_to_backend_parameters_.find(parameter);
if (backend_iter == front_to_backend_parameters_.end()) {
CreateDeviceTensorForFrontParameter(parameter, device_context);
front_to_backend_parameters_[parameter] = {parameter, device_context};
front_parameters_.push_back({parameter, device_context});
}
}
@ -896,7 +965,7 @@ void ControlNodeParser::FetchCallInputKernelGraph(const std::vector<KernelGraphP
const auto &graph = graphs[i];
const auto &device_context = device_contexts[i];
const auto inputs = graph->parameters();
const auto inputs = graph->input_nodes();
for (const auto &input : inputs) {
const auto &internal_parameter_with_index = graph->GetFrontNodeByInternalParameter(input);
if (internal_parameter_with_index.first != nullptr && IsCallNode(internal_parameter_with_index.first)) {
@ -983,7 +1052,8 @@ std::vector<AnfNodePtr> FetchParameterbyKernelGraph(const KernelGraphPtr &graph)
continue;
}
if (HasAbstractRef(AnfAlgo::VisitKernelWithReturnType(front_node_with_index.first, 0).first) ||
const auto real_front_node = AnfAlgo::VisitKernelWithReturnType(front_node_with_index.first, 0).first;
if ((real_front_node->isa<Parameter>() && HasAbstractRef(real_front_node)) ||
HasAbstractMonad(front_node_with_index.first)) {
continue;
}
@ -1002,7 +1072,9 @@ std::vector<AnfNodePtr> FetchParameterbyKernelGraph(const KernelGraphPtr &graph)
void ControlNodeParser::FetchFrontToBackendParameter(const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts,
const std::vector<AnfNodePtr> &control_nodes) {
const std::vector<AnfNodePtr> &control_nodes,
const RealToFormalNode &real_to_formal_front_parameters,
const RealToFormalNode &formal_to_real_front_parameters) {
if (graphs.size() != device_contexts.size()) {
MS_LOG(EXCEPTION) << "Graph num is not equal to device context num.";
}
@ -1026,7 +1098,7 @@ void ControlNodeParser::FetchFrontToBackendParameter(const std::vector<KernelGra
for (size_t i = 0; i < graphs.size(); ++i) {
const auto &graph = graphs[i];
auto device_context = device_contexts[i];
for (const auto &parameter : graph->parameters()) {
for (const auto &parameter : graph->input_nodes()) {
const auto &internal_front_node = graph->GetFrontNodeByInternalParameter(parameter);
if (internal_front_node.first != nullptr) {
@ -1043,21 +1115,6 @@ void ControlNodeParser::FetchFrontToBackendParameter(const std::vector<KernelGra
}
}
// Fetch the mapping relationship between front parameters and backend parameters in the control nodes. First
// fetch the mapping relationship of the front parameter. When the input of the call node or the partial node
// is a parameter node, it means that the parameter is directly transmitted. If a parameter does not have a
// corresponding backend node, then recursively find whether the front parameter corresponding to the parameter
// has one.
std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> real_to_formal_front_parameters;
FetchFrontToFrontParameterMap(control_nodes, &real_to_formal_front_parameters);
std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> formal_to_real_front_parameters;
for (const auto real_to_formal_front_parameter : real_to_formal_front_parameters) {
for (const auto formal_parameter : real_to_formal_front_parameter.second) {
formal_to_real_front_parameters[formal_parameter].emplace_back(real_to_formal_front_parameter.first);
}
}
for (const auto &front_pair : real_to_formal_front_parameters) {
std::set<AnfNodePtr> invalid_node;
const auto &backend_node =
@ -1071,14 +1128,10 @@ void ControlNodeParser::FetchFrontToBackendParameter(const std::vector<KernelGra
}
}
void ControlNodeParser::FetchHostParameterToWeight(const std::vector<AnfNodePtr> &control_nodes) {
std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> front_to_front_parameter;
FetchFrontToFrontParameterMap(control_nodes, &front_to_front_parameter);
for (const auto &pair : front_to_front_parameter) {
void ControlNodeParser::FetchHostParameterToWeight(const RealToFormalNode &front_to_front_parameters) {
for (const auto &pair : front_to_front_parameters) {
std::vector<AnfNodePtr> dest_nodes;
FetchWeightbyHostParameter(pair.first, &dest_nodes, front_to_front_parameter);
FetchWeightbyHostParameter(pair.first, &dest_nodes, front_to_front_parameters);
host_parameter_to_weights_[pair.first] = dest_nodes;
}
}
@ -1136,18 +1189,20 @@ void ControlNodeParser::FetchFrontToBackendKernel(const std::vector<KernelGraphP
}
}
std::vector<KernelWithIndex> ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_output,
std::set<AnfNodePtr> *call_nodes,
std::set<AnfNodePtr> *switch_nodes) {
std::vector<KernelWithIndex> backend_outputs;
void ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_output,
std::set<AnfNodePtr> *call_nodes,
std::set<AnfNodePtr> *switch_nodes,
std::set<KernelWithIndex> *results) {
if (front_output->isa<ValueNode>()) {
backend_outputs.push_back({front_output, 0});
(*results).insert({front_output, 0});
} else if (front_output->isa<Parameter>()) {
// Output is a parameter.
const auto iter = formal_to_real_parameters_.find(front_output);
if (iter != formal_to_real_parameters_.end()) {
backend_outputs.insert(backend_outputs.end(), iter->second.begin(), iter->second.end());
for (const auto &node : iter->second) {
(*results).insert(node);
}
} else {
MS_LOG(EXCEPTION) << "Cannot find backend node for front parameter:" << AnfAlgo::GetNodeDebugString(front_output);
}
@ -1156,16 +1211,14 @@ std::vector<KernelWithIndex> ControlNodeParser::FetchBackendOutputByFrontOutput(
const auto &switch_outputs = FetchOutputBySwitchNode(front_output, call_nodes, switch_nodes);
for (const auto &switch_output : switch_outputs) {
const auto outputs = FetchBackendOutputByFrontOutput(switch_output, call_nodes, switch_nodes);
backend_outputs.insert(backend_outputs.end(), outputs.begin(), outputs.end());
FetchBackendOutputByFrontOutput(switch_output, call_nodes, switch_nodes, results);
}
} else if (IsCallNode(front_output)) {
// Output is a call.
const auto &call_outputs = FetchOutputByCallNode(front_output, call_nodes, switch_nodes);
for (const auto &call_output : call_outputs) {
const auto outputs = FetchBackendOutputByFrontOutput(call_output, call_nodes, switch_nodes);
backend_outputs.insert(backend_outputs.end(), outputs.begin(), outputs.end());
FetchBackendOutputByFrontOutput(call_output, call_nodes, switch_nodes, results);
}
} else if (AnfAlgo::CheckPrimitiveType(front_output, prim::kPrimMakeTuple)) {
// Output is a make tuple.
@ -1173,8 +1226,7 @@ std::vector<KernelWithIndex> ControlNodeParser::FetchBackendOutputByFrontOutput(
const auto &inputs = cnode->inputs();
for (size_t i = kMakeTupleInputStartPos; i < inputs.size(); ++i) {
const auto outputs = FetchBackendOutputByFrontOutput(inputs[i], call_nodes, switch_nodes);
backend_outputs.insert(backend_outputs.end(), outputs.begin(), outputs.end());
FetchBackendOutputByFrontOutput(inputs[i], call_nodes, switch_nodes, results);
}
} else if (front_output->isa<CNode>()) {
// Output is a kernel.
@ -1182,19 +1234,18 @@ std::vector<KernelWithIndex> ControlNodeParser::FetchBackendOutputByFrontOutput(
if (iter != front_to_backend_kernels_.end()) {
const auto &output_with_index = AnfAlgo::VisitKernelWithReturnType(iter->second.first, 0);
backend_outputs.emplace_back(output_with_index);
(*results).insert(output_with_index);
} else {
MS_LOG(EXCEPTION) << "Cannot find backend node for front kernel:" << AnfAlgo::GetNodeDebugString(front_output);
}
} else {
MS_LOG(EXCEPTION) << "Invalid front node:" << AnfAlgo::GetNodeDebugString(front_output);
}
return backend_outputs;
}
void ControlNodeParser::FetchBackendInputNodebyFrontNode(const AnfNodePtr &real_parameter,
const AnfNodePtr &formal_parameter) {
void ControlNodeParser::FetchBackendInputNodebyFrontNode(
const AnfNodePtr &real_parameter, const AnfNodePtr &formal_parameter,
const FrontToBackendNodeWithContext &front_to_backend_parameters) {
if (real_parameter->isa<Parameter>()) {
// Input node is a parameter from host data source actor.
std::set<AnfNodePtr> invalid_inputs;
@ -1205,8 +1256,8 @@ void ControlNodeParser::FetchBackendInputNodebyFrontNode(const AnfNodePtr &real_
const auto node_with_index = AnfAlgo::VisitKernelWithReturnType(front_input, 0);
if (node_with_index.first->isa<Parameter>()) {
const auto &iter = front_to_backend_parameters_.find(real_parameter);
if (iter == front_to_backend_parameters_.end()) {
const auto &iter = front_to_backend_parameters.find(real_parameter);
if (iter == front_to_backend_parameters.end()) {
MS_LOG(WARNING) << "Cannot find backend node of node:" << AnfAlgo::GetNodeDebugString(node_with_index.first);
continue;
}
@ -1224,7 +1275,7 @@ void ControlNodeParser::FetchBackendInputNodebyFrontNode(const AnfNodePtr &real_
} else if (IsCallNode(real_parameter)) {
const auto func_graphs = FetchFuncGraphbyCallNode(real_parameter);
for (const auto func_graph : func_graphs) {
FetchBackendInputNodebyFrontNode(func_graph->output(), formal_parameter);
FetchBackendInputNodebyFrontNode(func_graph->output(), formal_parameter, front_to_backend_parameters);
}
} else {
// Input node is a cnode.
@ -1237,7 +1288,56 @@ void ControlNodeParser::FetchBackendInputNodebyFrontNode(const AnfNodePtr &real_
}
}
void ControlNodeParser::FetchBackendInputNode() {
// Fill *front_to_backend_parameters with the backend node (and its device context) of front parameters.
// Three sources are merged, in this order:
// 1. The input nodes of every kernel graph whose funcgraph is the root funcgraph, keyed by their front node when that
//    front node is a Parameter. An entry that already exists is not overwritten.
// 2. Every control node parameter, copied from front_to_backend_parameters_. This overwrites an existing entry, and
//    an exception is raised when a control node parameter has no entry in front_to_backend_parameters_.
// 3. Every key of formal_to_real_front_parameters, resolved through FetchBackendNodeByFrontNode. The result is stored
//    only when it is non-null and the key has no entry yet.
//
// graphs and device_contexts are indexed in parallel, so they must have the same size.
// Raises an exception if front_to_backend_parameters is null.
void ControlNodeParser::FetchBackendParameterNode(const std::vector<KernelGraphPtr> &graphs,
                                                  const std::vector<DeviceContext *> &device_contexts,
                                                  const RealToFormalNode &real_to_formal_front_parameters,
                                                  const RealToFormalNode &formal_to_real_front_parameters,
                                                  FrontToBackendNodeWithContext *front_to_backend_parameters) {
  // The output map is dereferenced in every step below, so reject a null pointer up front.
  MS_EXCEPTION_IF_NULL(front_to_backend_parameters);
  auto &backend_parameters = *front_to_backend_parameters;

  // Step 1: parameters of the kernel graphs that belong to the root funcgraph.
  for (size_t i = 0; i < graphs.size(); ++i) {
    const auto &graph = graphs[i];
    const auto &device_context = device_contexts[i];
    if (graph->GetFuncGraph() != root_func_graph_) {
      continue;
    }
    for (const auto &parameter : graph->input_nodes()) {
      auto front_node = graph->GetFrontAnfByBackendAnf(parameter);
      if (front_node != nullptr && front_node->isa<Parameter>() && backend_parameters.count(front_node) == 0) {
        backend_parameters[front_node] = {parameter, device_context};
      }
    }
  }

  // Step 2: parameters of control nodes; each must already be known to front_to_backend_parameters_.
  for (const auto &control_node_parameter : control_node_parameters_) {
    const auto &iter = front_to_backend_parameters_.find(control_node_parameter);
    if (iter == front_to_backend_parameters_.end()) {
      MS_LOG(EXCEPTION) << "Cannot find backend node for control node parameter:"
                        << AnfAlgo::GetNodeDebugString(control_node_parameter);
    }
    backend_parameters[control_node_parameter] = iter->second;
  }

  // Step 3: formal parameters, resolved through the front-to-front parameter relationships.
  for (const auto &front_pair : formal_to_real_front_parameters) {
    std::set<AnfNodePtr> invalid_node;
    const auto &backend_node =
      FetchBackendNodeByFrontNode(front_pair.first, real_to_formal_front_parameters, formal_to_real_front_parameters,
                                  backend_parameters, &invalid_node);
    // The lookup happens before the "already present" test, exactly as before, so call order is unchanged.
    if (backend_node.first != nullptr && backend_parameters.count(front_pair.first) == 0) {
      backend_parameters[front_pair.first] = backend_node;
    }
  }
}
void ControlNodeParser::FetchBackendInputNode(const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts,
const RealToFormalNode &real_to_formal_front_parameters,
const RealToFormalNode &formal_to_real_front_parameters) {
FrontToBackendNodeWithContext front_to_backend_parameters;
FetchBackendParameterNode(graphs, device_contexts, real_to_formal_front_parameters, formal_to_real_front_parameters,
&front_to_backend_parameters);
for (const auto &func_graph_to_parameters : func_graph_to_parameters_) {
const auto &func_graph = func_graph_to_parameters.first;
std::vector<AnfNodePtr> graph_inputs;
@ -1259,14 +1359,15 @@ void ControlNodeParser::FetchBackendInputNode() {
}
for (size_t i = 0; i < parameters.size(); ++i) {
FetchBackendInputNodebyFrontNode(parameters[i], graph_inputs[i]);
FetchBackendInputNodebyFrontNode(parameters[i], graph_inputs[i], front_to_backend_parameters);
}
}
}
for (const auto front_to_backend_parameters : front_to_backend_parameters_) {
formal_to_real_parameters_[front_to_backend_parameters.first].push_back(
{front_to_backend_parameters.second.first, 0});
for (const auto parameter_pair : front_to_backend_parameters) {
formal_to_real_parameters_[parameter_pair.first].push_back({parameter_pair.second.first, 0});
}
for (const auto parameter_pair : front_to_backend_parameters_) {
formal_to_real_parameters_[parameter_pair.first].push_back({parameter_pair.second.first, 0});
}
}
} // namespace runtime

View File

@ -41,6 +41,7 @@ using FrontToBackendNodeWithContext = std::unordered_map<AnfNodePtr, std::pair<A
using FuncGraphToParameter = std::unordered_map<FuncGraphPtr, std::vector<std::vector<AnfNodePtr>>>;
using HostParameterToWeight = std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>>;
using NodeWithDeviceContext = std::vector<std::pair<AnfNodePtr, DeviceContext *>>;
using RealToFormalNode = std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>>;
// Check whether node is a call node, there are two types of call nodes:
// 1. First input of node is a cnode.
@ -89,7 +90,7 @@ class ControlNodeParser {
// Get all possible input nodes of the output node. When the switch actor is the output, it needs to send the node
// to which the device address belongs, so the switch actor needs to get all the possible nodes.
std::vector<KernelWithIndex> FetchBackendInputNodeByFrontNode(const AnfNodePtr &front_output);
std::set<KernelWithIndex> FetchBackendInputNodeByFrontNode(const AnfNodePtr &front_output);
// Get the device context corresponding to the value node.
DeviceContext *GetFrontValueNodeDeviceContext(const AnfNodePtr &value_node);
@ -135,39 +136,50 @@ class ControlNodeParser {
// 2. The parameter from control nodes.
void FetchFrontToBackendParameter(const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts,
const std::vector<AnfNodePtr> &control_nodes);
const std::vector<AnfNodePtr> &control_nodes,
const RealToFormalNode &real_to_formal_front_parameters,
const RealToFormalNode &formal_to_real_front_parameters);
// Get the relationship between the front and backend of the executable kernel in all kernel graphs.
void FetchFrontToBackendKernel(const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts);
// Get inputs of control node which come from the host actor. These inputs generally come from the partial
// nodes and call nodes of the root funcgraph.
std::vector<AnfNodePtr> FetchControlNodeParameter(const std::vector<AnfNodePtr> &control_nodes);
std::vector<AnfNodePtr> FetchControlNodeParameter(const std::vector<AnfNodePtr> &control_nodes,
DeviceContext *device_context);
// Get all the input parameters of funcgraph. The call of funcgraph is realized through the call node,
// and the input of the call node is the input parameter of the corresponding funcgraph.
void FetchFuncGraphToParameter(const std::vector<AnfNodePtr> &control_nodes);
// Get all the front weight parameters related to the weight in the host parameter.
void FetchHostParameterToWeight(const std::vector<AnfNodePtr> &control_nodes);
void FetchHostParameterToWeight(const RealToFormalNode &real_to_formal_front_parameters);
// The relationship between front parameters indicates that the parameter is directly used as the input of the
// funcgraph. There are two situations:
// 1. The parameter is used as the input of the call node,
// 2. The parameter is used as the input of the partial and will be input to the funcgraph of the partial in the
// subsequent call node.
void FetchFrontToFrontParameterMap(const std::vector<AnfNodePtr> &control_nodes,
std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> *front_to_front_parameter);
void FetchFrontToFrontParameter(const std::vector<AnfNodePtr> &control_nodes,
std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> *front_to_front_parameter);
// Get the number of calls to all subgraphs in the whole funcgraph.
void FetchFuncGraphCallNum(const std::vector<AnfNodePtr> &control_nodes);
// Get all the kernel graphs where the input node has a call node.
void FetchCallInputKernelGraph(const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts);
// Get the relationship of all real and formal nodes in the whole funcgraph.
void FetchBackendInputNode(const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts,
const RealToFormalNode &real_to_formal_front_parameters,
const RealToFormalNode &formal_to_real_front_parameters);
// Get the relationship of all real and formal parameters in the whole funcgraph.
void FetchBackendInputNode();
void FetchBackendParameterNode(const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts,
const RealToFormalNode &real_to_formal_front_parameters,
const RealToFormalNode &formal_to_real_front_parameters,
FrontToBackendNodeWithContext *front_to_backend_parameters);
// Get all possible input node of real parameter.
void FetchBackendInputNodebyFrontNode(const AnfNodePtr &real_parameter, const AnfNodePtr &formal_parameter);
void FetchBackendInputNodebyFrontNode(const AnfNodePtr &real_parameter, const AnfNodePtr &formal_parameter,
const FrontToBackendNodeWithContext &front_to_backend_parameters);
// Recursive interface, get all Backend node by front_output.
std::vector<KernelWithIndex> FetchBackendOutputByFrontOutput(const AnfNodePtr &front_output,
std::set<AnfNodePtr> *call_nodes,
std::set<AnfNodePtr> *switch_nodes);
void FetchBackendOutputByFrontOutput(const AnfNodePtr &front_output, std::set<AnfNodePtr> *call_nodes,
std::set<AnfNodePtr> *switch_nodes, std::set<KernelWithIndex> *results);
// The front to backend parameters is used to build and link the host data source actor in the control flow scenario.
FrontToBackendNodeWithContext front_to_backend_parameters_;
@ -195,11 +207,14 @@ class ControlNodeParser {
// host parameter to weights records the weights in the subgraph corresponding to the node in the root funcgraph.
// When initializing the weights, all related weights need to be recorded as the same device tensor.
HostParameterToWeight host_parameter_to_weights_;
// The front value node saves all value nodes that are not in the kernel graph. These nodes are generally the
// input of the control node.
NodeWithDeviceContext front_value_nodes_;
// The front output_node is used to link the output actor in multi-branch output scenario.
std::vector<AnfNodePtr> front_output_nodes_;
// The front parameters save all parameters that are not in the kernel graph. These nodes are generally the
// output of a subgraph, or the switch condition node.
NodeWithDeviceContext front_parameters_;
// Parameters of control node which come from the host actor.
std::vector<AnfNodePtr> control_node_parameters_;
// The number of calls to func_graph.

View File

@ -1107,7 +1107,7 @@ std::vector<GatherActorPtr> GraphScheduler::BuildGatherActor(const GraphCompiler
// Collect the parameters.
std::vector<AnfNodePtr> parameters;
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
if (HasAbstractMonad(inputs[i]) || HasAbstractRef(inputs[i])) {
if (HasAbstractMonad(inputs[i]) || (inputs[i]->isa<Parameter>() && HasAbstractRef(inputs[i]))) {
continue;
}
parameters.emplace_back(inputs[i]);
@ -2115,8 +2115,9 @@ void GraphScheduler::LinkDataArrowByControlNode(const GraphCompilerInfo &graph_c
const auto &backend_node = backend_iter->second.first;
auto iter = from_actor->data_node_position_map_.find(input_node);
if (iter == from_actor->data_node_position_map_.end()) {
MS_LOG(EXCEPTION) << "Cannot find data node in data source actor, node:"
<< AnfAlgo::GetNodeDebugString(backend_node);
MS_LOG(EXCEPTION) << "Cannot find data node in data source actor, backend node:"
<< AnfAlgo::GetNodeDebugString(backend_node)
<< " front node:" << AnfAlgo::GetNodeDebugString(input_node);
}
auto op_arrow = std::make_shared<DataArrow>(iter->second, to_actor->GetAID(), to_index);
@ -2133,7 +2134,7 @@ void GraphScheduler::LinkDataArrowForSwitchActor(const GraphCompilerInfo &graph_
const auto &inputs = actor->input_nodes_;
for (size_t i = 0; i < inputs.size(); ++i) {
auto input = inputs[i];
if (input.first->isa<ValueNode>() || HasAbstractRef(input.first)) {
if (input.first->isa<ValueNode>() || (input.first->isa<Parameter>() && HasAbstractRef(input.first))) {
continue;
}