!17095 Add build control flow actor.

From: @gaoyong10
Reviewed-by: 
Signed-off-by:
mindspore-ci-bot 2021-05-30 05:58:38 +08:00 committed by Gitee
commit 303d4857e8
9 changed files with 531 additions and 128 deletions

View File

@@ -25,15 +25,13 @@
#include <algorithm>
#include "runtime/framework/device_tensor_store.h"
#include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/control_node_parser.h"
#include "runtime/hardware/device_context.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/tensor.h"
namespace mindspore {
namespace runtime {
using mindspore::device::DeviceContext;
using FrontToBackendNodeWithContext = std::unordered_map<AnfNodePtr, std::pair<AnfNodePtr, DeviceContext *>>;
// Gather actor is the entrance of a sub funcgraph. Graph inputs are sent to it and forwarded to other actors by the gather actor.
class GatherActor : public OpActor<DeviceTensor> {
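
A toy sketch of the gather pattern described above, in plain C++: `int` stands in for DeviceTensor and a callback stands in for the actor message send, so none of these names are the real OpActor API. The actor buffers each arriving graph input and forwards the whole batch downstream only once every slot has been filled.

#include <algorithm>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical illustration only: buffers inputs until all expected
// slots have arrived, then forwards the batch in one shot.
class ToyGatherActor {
 public:
  ToyGatherActor(size_t input_num, std::function<void(const std::vector<int> &)> send)
      : inputs_(input_num), arrived_(input_num, false), send_(std::move(send)) {}

  void RunOpData(size_t index, int data) {
    inputs_[index] = data;
    arrived_[index] = true;
    if (std::all_of(arrived_.begin(), arrived_.end(), [](bool b) { return b; })) {
      send_(inputs_);  // entrance of the sub funcgraph: forward the inputs to other actors
      std::fill(arrived_.begin(), arrived_.end(), false);  // reset for the next step
    }
  }

 private:
  std::vector<int> inputs_;    // one slot per graph input
  std::vector<bool> arrived_;  // which slots are filled
  std::function<void(const std::vector<int> &)> send_;
};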

View File

@@ -22,18 +22,6 @@
namespace mindspore {
namespace runtime {
constexpr size_t kSwitchInputNum = 4;
constexpr size_t kSwitchCondPos = 1;
constexpr size_t kSwitchPartialNum = 2;
constexpr size_t kSwitchLayerCondPos = 1;
constexpr size_t kSwitchLayerBranchPos = 2;
constexpr size_t kSwitchLayerInputNum = 3;
constexpr size_t kMaxSwitchCondSize = 8;
constexpr size_t kSwitchTrueBranchPos = 2;
constexpr size_t kSwitchFalseBranchPos = 3;
constexpr size_t kPartialFuncGraphPos = 1;
constexpr size_t kPartialInputStartPos = 2;
void SwitchActor::Init() {
// Init output data.
output_data_.resize(output_branch_arrows_.size());

View File

@@ -30,6 +30,21 @@
namespace mindspore {
namespace runtime {
using mindspore::device::DeviceContext;
using mindspore::session::KernelWithIndex;
constexpr size_t kSwitchInputNum = 4;
constexpr size_t kSwitchCondPos = 1;
constexpr size_t kSwitchPartialNum = 2;
constexpr size_t kSwitchLayerCondPos = 1;
constexpr size_t kSwitchLayerBranchPos = 2;
constexpr size_t kSwitchLayerInputNum = 3;
constexpr size_t kMaxSwitchCondSize = 8;
constexpr size_t kSwitchTrueBranchPos = 2;
constexpr size_t kSwitchFalseBranchPos = 3;
constexpr size_t kPartialFuncGraphPos = 1;
constexpr size_t kPartialInputStartPos = 2;
constexpr size_t kCallInputStartPos = 1;
constexpr size_t kMakeTupleInputStartPos = 1;
// Switch actor is used to execute the branch according to the input condition.
// Switch and SwitchLayer nodes will be converted to switch actors.
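
The constants above encode the operand layout of Switch and SwitchLayer CNodes: the condition sits at position 1 and the branches follow it. Below is a hypothetical helper showing how a switch actor could turn a condition into the branch position to execute; it illustrates the layout only and is not the actual SwitchActor logic.

#include <cstddef>
#include <cstdint>
#include <stdexcept>

// Positions mirrored from the constants above.
constexpr size_t kTruePos = 2;   // kSwitchTrueBranchPos
constexpr size_t kFalsePos = 3;  // kSwitchFalseBranchPos

// Switch(prim, cond, true_branch, false_branch): a boolean condition
// picks between exactly two branches.
size_t SelectSwitchBranch(bool cond) { return cond ? kTruePos : kFalsePos; }

// SwitchLayer(prim, index, make_tuple(branch0, branch1, ...)): an
// integer index selects one funcgraph from the branch tuple.
size_t SelectSwitchLayerBranch(int64_t index, size_t branch_num) {
  if (index < 0 || static_cast<size_t>(index) >= branch_num) {
    throw std::out_of_range("switch_layer index out of range");
  }
  return static_cast<size_t>(index);
}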

View File

@@ -0,0 +1,275 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/framework/control_node_parser.h"
#include "runtime/framework/actor/switch_actor.h"
namespace mindspore {
namespace runtime {
bool ControlNodeParser::IsCallNode(const AnfNodePtr &node) {
if (!node->isa<CNode>()) {
return false;
}
const auto &cnode = node->cast<CNodePtr>();
const auto &inputs = cnode->inputs();
return inputs[0]->isa<CNode>() || (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0]));
}
FuncGraphPtr ControlNodeParser::GetFuncGraphFromPartial(const AnfNodePtr &node) {
const auto &partial_inputs = node->cast<CNodePtr>()->inputs();
return GetValueNode<FuncGraphPtr>(partial_inputs[1]);
}
std::vector<FuncGraphPtr> ControlNodeParser::FetchFuncGraphbyCallNode(const CNodePtr &node) {
std::vector<FuncGraphPtr> func_graphs;
const auto &call_inputs = node->inputs();
if (call_inputs[0]->isa<CNode>()) {
const auto &cnode = call_inputs[0]->cast<CNodePtr>();
const auto &cnode_inputs = cnode->inputs();
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
for (size_t i = kSwitchTrueBranchPos; i < cnode_inputs.size(); ++i) {
if (IsPrimitiveCNode(cnode_inputs[i], prim::kPrimPartial)) {
func_graphs.emplace_back(GetFuncGraphFromPartial(cnode_inputs[i]));
}
}
} else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer) &&
AnfAlgo::CheckPrimitiveType(cnode_inputs[kSwitchLayerBranchPos], prim::kPrimMakeTuple)) {
const auto &tuple_inputs = cnode_inputs[kSwitchLayerBranchPos]->cast<CNodePtr>()->inputs();
for (size_t i = 1; i < tuple_inputs.size(); ++i) {
if (AnfAlgo::CheckPrimitiveType(tuple_inputs[i], prim::kPrimPartial)) {
func_graphs.emplace_back(GetFuncGraphFromPartial(tuple_inputs[i]));
} else if (IsValueNode<FuncGraph>(tuple_inputs[i])) {
func_graphs.emplace_back(GetValueNode<FuncGraphPtr>(tuple_inputs[i]));
}
}
} else {
MS_LOG(EXCEPTION) << "Unable to identify call node" << node->DebugString();
}
} else if (call_inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(call_inputs[0])) {
func_graphs.emplace_back(GetValueNode<FuncGraphPtr>(call_inputs[0]));
} else {
MS_LOG(EXCEPTION) << "Unable to identify call node" << node->DebugString();
}
return func_graphs;
}
std::vector<AnfNodePtr> ControlNodeParser::FetchFuncGraphOutput(const FuncGraphPtr &func_graph,
std::vector<AnfNodePtr> *call_nodes) {
std::vector<AnfNodePtr> outputs;
const auto &output = func_graph->output();
const auto &real_output = AnfAlgo::VisitKernelWithReturnType(output, 0);
if (find((*call_nodes).begin(), (*call_nodes).end(), real_output.first) != (*call_nodes).end()) {
return outputs;
}
if (!IsCallNode(real_output.first)) {
outputs.push_back(real_output.first);
return outputs;
}
(*call_nodes).push_back(real_output.first);
const auto &call_cnode = real_output.first->cast<CNodePtr>();
std::vector<FuncGraphPtr> func_graphs = FetchFuncGraphbyCallNode(call_cnode);
for (const auto &graph : func_graphs) {
auto single_outputs = FetchFuncGraphOutput(graph, call_nodes);
outputs.insert(outputs.end(), single_outputs.begin(), single_outputs.end());
}
return outputs;
}
std::pair<AnfNodePtr, DeviceContext *> ControlNodeParser::FetchBackendNodeByFrontNode(
const AnfNodePtr &front_node, const std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> &front_to_front_parameter,
const std::unordered_map<AnfNodePtr, std::pair<AnfNodePtr, DeviceContext *>> &front_to_backend_parameter,
std::set<AnfNodePtr> *invalid_node) {
// Check whether the front_node has been looked for.
if ((*invalid_node).find(front_node) != (*invalid_node).end()) {
return std::pair<AnfNodePtr, DeviceContext *>();
}
(*invalid_node).insert(front_node);
const auto front_to_backend_iter = front_to_backend_parameter.find(front_node);
if (front_to_backend_iter != front_to_backend_parameter.end()) {
return front_to_backend_iter->second;
}
const auto &front_to_front_iter = front_to_front_parameter.find(front_node);
if (front_to_front_iter == front_to_front_parameter.end()) {
return std::pair<AnfNodePtr, DeviceContext *>();
}
for (const auto &next_node : front_to_front_iter->second) {
auto backend_node =
FetchBackendNodeByFrontNode(next_node, front_to_front_parameter, front_to_backend_parameter, invalid_node);
if (backend_node.first != nullptr) {
return backend_node;
}
}
return std::pair<AnfNodePtr, DeviceContext *>();
}
void ControlNodeParser::FetchFrontToFrontParameterMap(
const std::vector<AnfNodePtr> &control_nodes,
std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> *front_to_front_parameter) {
// Function used to collect the input of call node.
const auto &call_input_parse = [front_to_front_parameter](const std::vector<AnfNodePtr> &parameters,
const std::vector<AnfNodePtr> &call_inputs,
const size_t call_input_start_pos) {
for (size_t i = 0; i < call_inputs.size(); ++i) {
if (call_inputs[i]->isa<Parameter>()) {
(*front_to_front_parameter)[call_inputs[i]].push_back(parameters[i + call_input_start_pos]);
}
}
};
// Function used to collect the input of partial node.
const auto &partial_input_parse = [call_input_parse, front_to_front_parameter](
const AnfNodePtr &partial_node, const std::vector<AnfNodePtr> &call_inputs) {
const auto &cnode = partial_node->cast<CNodePtr>();
const auto &inputs = cnode->inputs();
const auto &func_graph = GetValueNode<FuncGraphPtr>(inputs[kPartialFuncGraphPos]);
const auto &parameters = func_graph->parameters();
for (size_t i = kPartialInputStartPos; i < inputs.size(); ++i) {
if (inputs[i]->isa<Parameter>()) {
(*front_to_front_parameter)[inputs[i]].push_back(parameters[i - kPartialInputStartPos]);
}
}
call_input_parse(parameters, call_inputs, inputs.size() - kPartialInputStartPos);
};
// Function used to collect the input of switch node.
const auto &switch_input_parse = [&](const AnfNodePtr &switch_node, const std::vector<AnfNodePtr> &call_inputs) {
CNodePtr cnode = switch_node->cast<CNodePtr>();
const auto &switch_inputs = cnode->inputs();
if (AnfAlgo::CheckPrimitiveType(switch_node, prim::kPrimSwitch)) {
// Parse the switch node. The switch node has two partial node inputs.
if (AnfAlgo::CheckPrimitiveType(switch_inputs[kSwitchTrueBranchPos], prim::kPrimPartial)) {
partial_input_parse(switch_inputs[kSwitchTrueBranchPos], call_inputs);
partial_input_parse(switch_inputs[kSwitchFalseBranchPos], call_inputs);
}
} else {
// Parse the switchlayer node. The switchlayer node has a maketuple node input, which is a tuple of funcgraphs.
// call_inputs will be the input of these funcgraphs.
const auto &tuple_node = switch_inputs[kSwitchLayerBranchPos]->cast<CNodePtr>();
const auto &tuple_inputs = tuple_node->inputs();
for (const auto &input : tuple_inputs) {
if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimPartial)) {
partial_input_parse(input, call_inputs);
} else {
auto func_graph = GetValueNode<FuncGraphPtr>(input);
call_input_parse(func_graph->parameters(), call_inputs, 0);
}
}
}
};
for (const auto &node : control_nodes) {
CNodePtr cnode = node->cast<CNodePtr>();
const auto &inputs = cnode->inputs();
if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) {
// Call node whose first input is a value node of a funcgraph.
const auto &func_graph = GetValueNode<FuncGraphPtr>(inputs[0]);
const auto &parameters = func_graph->parameters();
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
if (inputs[i]->isa<Parameter>()) {
(*front_to_front_parameter)[inputs[i]].push_back(parameters[i - kCallInputStartPos]);
}
}
} else if (inputs[0]->isa<CNode>()) {
// Call node whose first input is a switch or switchlayer node.
if ((!AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitch)) &&
(!AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer))) {
MS_LOG(EXCEPTION) << "First input node of call node is not switch, node:"
<< AnfAlgo::GetNodeDebugString(inputs[0]);
}
std::vector<AnfNodePtr> call_inputs;
call_inputs.assign(inputs.begin() + kCallInputStartPos, inputs.end());
switch_input_parse(inputs[0], call_inputs);
}
}
}
std::vector<AnfNodePtr> ControlNodeParser::FetchControlNodeParameter(const std::vector<AnfNodePtr> &control_nodes) {
std::vector<AnfNodePtr> parameters;
for (const auto &control_node : control_nodes) {
CNodePtr cnode = control_node->cast<CNodePtr>();
const auto &inputs = cnode->inputs();
if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
break;
} else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimPartial)) {
for (size_t i = kPartialInputStartPos; i < inputs.size(); ++i) {
if (inputs[i]->isa<Parameter>()) {
parameters.emplace_back(inputs[i]);
}
}
} else if (cnode->input(0)->isa<CNode>() || IsValueNode<FuncGraph>(cnode->input(0))) {
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
if (inputs[i]->isa<Parameter>()) {
parameters.emplace_back(inputs[i]);
}
}
}
}
return parameters;
}
std::vector<AnfNodePtr> ControlNodeParser::FetchAllBranchOutputs(const FuncGraphPtr &func_graph) {
std::vector<AnfNodePtr> call_nodes;
return FetchFuncGraphOutput(func_graph, &call_nodes);
}
void ControlNodeParser::FetchFrontToBackendParameterMap(const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts,
const std::vector<AnfNodePtr> &control_nodes,
FrontToBackendNodeWithContext *front_to_backend_parameter) {
if (graphs.size() != device_contexts.size()) {
MS_LOG(EXCEPTION) << "Graph num is not equal to device context num.";
}
// Fetch the mapping relationship between front parameters and backend parameters in the kernel graphs.
for (size_t i = 0; i < graphs.size(); ++i) {
const auto &graph = graphs[i];
auto device_context = device_contexts[i];
for (const auto &parameter : graph->parameters()) {
auto front_node = graph->GetFrontAnfByBackendAnf(parameter);
if (front_node != nullptr && front_node->isa<Parameter>() &&
(*front_to_backend_parameter).find(front_node) == (*front_to_backend_parameter).end()) {
(*front_to_backend_parameter)[front_node] = {parameter, device_context};
}
}
}
// Fetch the mapping relationship between front parameters and backend parameters in the control nodes. First
// fetch the mapping relationships between the front parameters themselves. When the input of a call node or a
// partial node is a parameter node, it means that the parameter is transmitted directly. If a parameter has no
// corresponding backend node, recursively check whether any front parameter corresponding to it has one. A
// standalone sketch of this recursive resolution follows this file.
std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> front_to_front_parameter;
FetchFrontToFrontParameterMap(control_nodes, &front_to_front_parameter);
for (const auto &front_pair : front_to_front_parameter) {
std::set<AnfNodePtr> invalid_node;
const auto &backend_node = FetchBackendNodeByFrontNode(front_pair.first, front_to_front_parameter,
*front_to_backend_parameter, &invalid_node);
if (backend_node.first != nullptr) {
(*front_to_backend_parameter)[front_pair.first] = backend_node;
}
}
}
} // namespace runtime
} // namespace mindspore
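
A minimal standalone sketch of the recursive resolution performed by FetchBackendNodeByFrontNode above, with strings in place of AnfNodePtr and the device context dropped: chase front-to-front parameter links until some parameter in the chain owns a backend node, using a visited set to break cycles.

#include <iostream>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

std::string Resolve(const std::string &front,
                    const std::unordered_map<std::string, std::vector<std::string>> &front_to_front,
                    const std::unordered_map<std::string, std::string> &front_to_backend,
                    std::set<std::string> *visited) {
  if (visited->count(front) > 0) {
    return "";  // already looked for: avoid infinite recursion on cyclic links
  }
  visited->insert(front);
  auto direct = front_to_backend.find(front);
  if (direct != front_to_backend.end()) {
    return direct->second;  // this front parameter already has a backend node
  }
  auto next = front_to_front.find(front);
  if (next == front_to_front.end()) {
    return "";  // dead end: no backend node reachable from here
  }
  for (const auto &candidate : next->second) {
    auto backend = Resolve(candidate, front_to_front, front_to_backend, visited);
    if (!backend.empty()) {
      return backend;
    }
  }
  return "";
}

int main() {
  // a -> b -> c, and only c has a backend node.
  std::unordered_map<std::string, std::vector<std::string>> front_to_front{{"a", {"b"}}, {"b", {"c"}}};
  std::unordered_map<std::string, std::string> front_to_backend{{"c", "c_backend"}};
  std::set<std::string> visited;
  std::cout << Resolve("a", front_to_front, front_to_backend, &visited) << std::endl;  // prints c_backend
}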

View File

@@ -20,7 +20,7 @@
#include <vector>
#include <string>
#include <memory>
#include <tuple>
#include <set>
#include <utility>
#include <unordered_map>
#include <algorithm>
@@ -32,100 +32,62 @@ namespace runtime {
using mindspore::device::DeviceContext;
using mindspore::session::KernelWithIndex;
// The meaning of switch node output tuple: 1. switch node 2. output branch id 3. output index.
using SwitchNodeOutput = std::tuple<AnfNodePtr, size_t, size_t>;
// The output arrow info: 1. from node index 2. to node 3. to node index.
using NodeOutputInfo = std::tuple<size_t, AnfNodePtr, size_t>;
// External inputs of a kernel graph: the key is the front node of the input,
// and the value vector holds pairs of from node and to node.
using KernelGraphExternInput = std::unordered_map<AnfNodePtr, std::vector<std::pair<KernelWithIndex, KernelWithIndex>>>;
using FrontToBackendNodeWithContext = std::unordered_map<AnfNodePtr, std::pair<AnfNodePtr, DeviceContext *>>;
struct PairHash {
template <class T1, class T2>
std::size_t operator()(const std::pair<T1, T2> &p) const {
auto h1 = std::hash<T1>{}(p.first);
auto h2 = std::hash<T2>{}(p.second);
return h1 ^ h2;
}
};
// Get all possible outputs of the funcgraph. Search recursively through the input of the return node of the
// funcgraph. If the input is a call node, enter all the funcgraphs it calls until a non-call input is found,
// and return all of the output nodes.
std::vector<AnfNodePtr> GetAllBranchOutputs(const FuncGraphPtr &func_graph);
// ControlNodeParser is used to parse control nodes, and get the edges between nodes. Call node is used to
// implement the call relationship between funcgraphs, the actual parameters are connected to the call node,
// and the call node then calls the corresponding funcgraph, sends the actual parameters to next nodes
// according to the relationship between the actual parameters and formal parameters.
// From the function of the call node, the structure of the edge can be split into two parts:
// the relationship between the output nodes and the formal parameters, and the relationship between formal
// parameters and input nodes. These are then connected to become the final edge.
// Therefore, the analysis is mainly divided into 2 steps:
// 1. Get all input and output relationship with formal parameters;
// 2. Connect all input and output to edges.
// ControlNodeParser is a series of tool functions used to parse control nodes.
class ControlNodeParser {
public:
ControlNodeParser() = default;
~ControlNodeParser() = default;
ControlNodeParser(const ControlNodeParser &) = delete;
ControlNodeParser &operator=(const ControlNodeParser &) = delete;
// Fetch all the relationships between front parameters and backend parameters. The front parameters
// include two parts:
// 1. The parameter from kernel graph.
// 2. The parameter from control nodes.
static void FetchFrontToBackendParameterMap(const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts,
const std::vector<AnfNodePtr> &control_nodes,
FrontToBackendNodeWithContext *front_to_backend_parameter);
// Analyze the relationship between switch and kernel nodes.
// Parameter kernel_graph_input_ indicates that in a multi-graph case, the parameters of the subgraph
// should be passed in when called by the main graph, rather than sent directly by the input, so they
// need to be connected when parsing the control nodes.
// The result of parsing is the edges between nodes, which are stored in member variables.
void Parse(const std::vector<AnfNodePtr> &control_nodes, const KernelGraphExternInput &kernel_graph_input_);
// Get inputs of control node which come from the host actor. These inputs generally come from the partial
// nodes and call nodes of the root funcgraph.
static std::vector<AnfNodePtr> FetchControlNodeParameter(const std::vector<AnfNodePtr> &control_nodes);
// Get the outputs of the funcgraph. Usually there is only one output node; in control flow, where there are
// multiple branch outputs, there will be multiple output nodes.
static std::vector<AnfNodePtr> FetchAllBranchOutputs(const FuncGraphPtr &func_graph);
private:
friend class GraphScheduler;
// Check whether the node is a call node. There are two types of call nodes:
// 1. First input of node is a cnode.
// 2. First input of node is a funcgraph value node.
static bool IsCallNode(const AnfNodePtr &node);
void ParseCall(const AnfNodePtr &node);
void ParseSwitch(const AnfNodePtr &node, const std::vector<AnfNodePtr> &inputs_on_call);
void ParseSwitchLayer(const AnfNodePtr &node, const std::vector<AnfNodePtr> &inputs_on_call);
void ParsePartial(const AnfNodePtr &node, const std::vector<AnfNodePtr> &switch_inputs, const size_t branch_id,
const std::vector<AnfNodePtr> &inputs_on_call);
void ParseInput(const AnfNodePtr &from_node, const AnfNodePtr &to_node, size_t to_index);
// Parse an input which is a call node. This means that we need to use the output of the funcgraph called by
// the call node as the input of to_node.
void ParseCallInput(const CNodePtr &from_node, const AnfNodePtr &to_node, size_t to_index);
// Get the funcgraph in partial node.
static FuncGraphPtr GetFuncGraphFromPartial(const AnfNodePtr &node);
// Get all inputs of the switch node. inputs_on_call holds the inputs of the call node to which the switch
// node is connected.
std::vector<AnfNodePtr> GetSwitchInput(const AnfNodePtr &node, const std::vector<AnfNodePtr> &inputs_on_call);
// Find all funcgraphs that the call node will call.
static std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const CNodePtr &node);
// Connect the input and output of the call node to get the final edge.
void LinkInputAndOutput();
// Link the formal parameter to its final actual parameter in the member variable parameter_to_arguments_.
// For example, if we have a map like {{a, b}, {b, c}, {c, d}}, finally we will get {{a, d}, {b, d}, {c, d}}.
// A standalone sketch of this flattening follows this file.
void LinkParameterAndArgument();
// Recursively find all inputs corresponding to node.
void GetOutputNode(const AnfNodePtr &node, std::vector<KernelWithIndex> *inputs);
// Find the output of the funcgraph; if the output is a call node, return the output of the funcgraph
// called by the call node.
static std::vector<AnfNodePtr> FetchFuncGraphOutput(const FuncGraphPtr &func_graph,
std::vector<AnfNodePtr> *call_nodes);
// Relationship between formal parameters and actual parameters.
std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> actual_to_formal_parameters_;
std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> formal_to_actual_parameters_;
// Find the corresponding backend parameter for the front_node. If the front_node does not have the corresponding
// backend parameter, then recursively find the backend parameters of other front parameters corresponding to the
// front_node.
static std::pair<AnfNodePtr, DeviceContext *> FetchBackendNodeByFrontNode(
const AnfNodePtr &front_node,
const std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> &front_to_front_parameter,
const std::unordered_map<AnfNodePtr, std::pair<AnfNodePtr, DeviceContext *>> &front_to_backend_parameter,
std::set<AnfNodePtr> *invalid_node);
// In control nodes, edge is a structure like kernel output --> formal parameter --> kernel input.
// In the parsing process, the edge is divided into two parts, input and output:
// input represents the relationship between formal parameters and kernel input,
// output represents the relationship between kernel output and formal parameters.
// In order to merge input and output into edge, both of them are stored in map and use parameter as the key.
// All inputs.
std::unordered_map<AnfNodePtr, std::vector<KernelWithIndex>> parameter_to_input_;
// Three kinds of output.
// output of switch node.
std::unordered_map<AnfNodePtr, std::vector<SwitchNodeOutput>> parameter_to_switch_out_;
// output of kernel node.
std::unordered_map<AnfNodePtr, std::vector<KernelWithIndex>> parameter_to_kernel_out_;
// parameters in root funcgraph.
std::vector<AnfNodePtr> parameters_;
// Final edges.
std::unordered_map<AnfNodePtr, std::vector<NodeOutputInfo>> kernel_outputs_;
std::unordered_map<std::pair<AnfNodePtr, size_t>, std::vector<NodeOutputInfo>, PairHash> switch_outputs_;
std::unordered_map<AnfNodePtr, std::vector<KernelWithIndex>> parameter_out_;
// The relationship between front parameters indicates that the parameter is directly used as the input of the
// funcgraph. There are two situations:
// 1. The parameter is used as the input of the call node.
// 2. The parameter is used as the input of the partial and will be passed to the funcgraph of the partial in
// the subsequent call node.
static void FetchFrontToFrontParameterMap(
const std::vector<AnfNodePtr> &control_nodes,
std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> *front_to_front_parameter);
};
} // namespace runtime
} // namespace mindspore
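
A standalone sketch of the flattening described in the LinkParameterAndArgument comment above, with chars in place of AnfNodePtr and assuming the links contain no cycle: follow each chain to its terminal element so every key maps to the final value.

#include <cassert>
#include <unordered_map>

// {{a,b},{b,c},{c,d}} becomes {{a,d},{b,d},{c,d}}.
std::unordered_map<char, char> Flatten(const std::unordered_map<char, char> &links) {
  std::unordered_map<char, char> result;
  for (const auto &kv : links) {
    char cur = kv.second;
    auto it = links.find(cur);
    while (it != links.end()) {  // assumes acyclic links
      cur = it->second;
      it = links.find(cur);
    }
    result[kv.first] = cur;
  }
  return result;
}

int main() {
  auto flat = Flatten({{'a', 'b'}, {'b', 'c'}, {'c', 'd'}});
  assert(flat.at('a') == 'd' && flat.at('b') == 'd' && flat.at('c') == 'd');
}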

View File

@@ -301,6 +301,20 @@ TensorPtr FetchInputTensor(const GraphCompilerInfo &graph_compiler_info, size_t
}
return nullptr;
}
void PrepareDataForHostDataSourceActor(const std::unordered_map<AnfNodePtr, size_t> &data_node_position_map,
const AnfNodePtr &node, const TensorPtr &tensor,
std::vector<TensorPtr> *host_tensors) {
if (std::dynamic_pointer_cast<DeviceTensor>(tensor->device_address()) != nullptr) {
return;
}
// Fill the host tensors for non weighted parameters.
const auto &iter = data_node_position_map.find(node);
if (iter != data_node_position_map.end()) {
(*host_tensors)[iter->second] = tensor;
}
}
} // namespace
GraphScheduler::~GraphScheduler() {
@@ -390,6 +404,10 @@ void GraphScheduler::Schedule(const ActorSet *actor_set) {
MS_EXCEPTION_IF_NULL(switch_actor);
actors.emplace_back(static_cast<ActorReference>(switch_actor));
}
for (auto &gather_actor : actor_set->gather_actors_) {
MS_EXCEPTION_IF_NULL(gather_actor);
actors.emplace_back(static_cast<ActorReference>(gather_actor));
}
for (auto &copy_actor : actor_set->copy_actors_) {
MS_EXCEPTION_IF_NULL(copy_actor);
actors.emplace_back(static_cast<ActorReference>(copy_actor));
@@ -444,11 +462,8 @@ void GraphScheduler::PrepareRun(const ActorSet *actor_set, const GraphCompilerIn
PrepareDataForWeightNode(input_node, input_tensor, device_context, graph);
} else if (IsHostQueueDSActor(input_node, graph, input_tensor)) {
MS_EXCEPTION_IF_NULL(host_data_source_actor);
// Fill the host tensors for non weighted parameters.
const auto &iter = host_data_source_actor->data_node_position_map_.find(input_node);
if (iter != host_data_source_actor->data_node_position_map_.end()) {
host_tensors[iter->second] = input_tensor;
}
PrepareDataForHostDataSourceActor(host_data_source_actor->data_node_position_map_, input_node, input_tensor,
&host_tensors);
}
}
@@ -461,7 +476,24 @@ void GraphScheduler::PrepareRun(const ActorSet *actor_set, const GraphCompilerIn
}
}
// 3.Prepare the data of host tensor queue(non weighted parameters of graph).
// 3.Fill host tensors for non weighted parameters which belong to control nodes.
std::vector<AnfNodePtr> control_node_parameters =
ControlNodeParser::FetchControlNodeParameter(graph_compiler_info.control_nodes_);
const auto &tensors = input_tensors.back();
const auto &parameters = graph_compiler_info.origin_parameters_order_;
for (size_t j = 0; j < control_node_parameters.size(); ++j) {
const auto &input_node = control_node_parameters[j];
const auto &input_tensor = tensors[j];
MS_EXCEPTION_IF_NULL(input_node);
if (IsPersistentDeviceTensor(input_node)) {
} else if (find(parameters.begin(), parameters.end(), input_node) != parameters.end()) {
MS_EXCEPTION_IF_NULL(host_data_source_actor);
PrepareDataForHostDataSourceActor(host_data_source_actor->data_node_position_map_, input_node, input_tensor,
&host_tensors);
}
}
// 4.Prepare the data of host tensor queue(non weighted parameters of graph).
if (host_data_source_actor != nullptr) {
const auto &host_tensor_queue = FetchHostQueue(actor_set->name_);
MS_EXCEPTION_IF_NULL(host_tensor_queue);
@@ -531,6 +563,8 @@ ActorSetPtr GraphScheduler::Build(const GraphCompilerInfo &graph_compiler_info,
actor_set->kernel_actors_ = BuildKernelActor(graph_compiler_info);
actor_set->loop_count_actor_ = BuildLoopCountActor(graph_compiler_info, strategy);
actor_set->output_actor_ = BuildOutputActor(graph_compiler_info, strategy);
actor_set->switch_actors_ = BuildSwitchActor(graph_compiler_info);
actor_set->gather_actors_ = BuildGatherActor(graph_compiler_info);
return actor_set;
}
@@ -705,6 +739,38 @@ std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const Graph
device_queue_ds_actor->kernel_info_ = static_cast<device::KernelInfo *>((*iter)->kernel_info());
}
}
const auto &front_to_backend_parameter = graph_compiler_info.front_to_backend_parameters_;
// Initialize the parameters in the control nodes: first get all the front parameters in the control nodes, then
// find the corresponding backend parameter from the map, and insert it into the host data source actor.
std::vector<AnfNodePtr> control_node_parameters =
ControlNodeParser::FetchControlNodeParameter(graph_compiler_info.control_nodes_);
for (const auto &parameter : control_node_parameters) {
auto backend_iter = front_to_backend_parameter.find(parameter);
if (backend_iter == front_to_backend_parameter.end()) {
MS_LOG(EXCEPTION) << "Cannot find backend node for front node:" << AnfAlgo::GetNodeDebugString(parameter);
}
if (host_queue_ds_actor == nullptr) {
auto actor_name = graph_compiler_info.name_ + "_HostDSActor";
MS_LOG(INFO) << "Create host queue data source actor: " << actor_name;
host_queue_ds_actor = std::make_shared<HostQueueDataSourceActor>(actor_name, 1, memory_manager_aid_, host_queue);
InsertActor(host_queue_ds_actor.get());
data_source_actors.emplace_back(host_queue_ds_actor);
}
const auto &backend_node = backend_iter->second.first;
auto iter = find(host_queue_ds_actor->data_nodes_.begin(), host_queue_ds_actor->data_nodes_.end(), backend_node);
if (iter != host_queue_ds_actor->data_nodes_.end()) {
host_queue_ds_actor->data_node_position_map_.emplace(parameter, iter - host_queue_ds_actor->data_nodes_.begin());
} else {
host_queue_ds_actor->data_node_position_map_.emplace(parameter, host_queue_ds_actor->data_nodes_.size());
host_queue_ds_actor->data_nodes_.emplace_back(backend_iter->second.first);
host_queue_ds_actor->device_contexts_.emplace_back(backend_iter->second.second);
}
}
return data_source_actors;
}
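
BuildDataSourceActor above registers each control-node parameter into the host data source actor with a find-or-append scheme, so that front parameters sharing one backend node also share one queue slot. A toy version of that bookkeeping, with strings standing in for the node pointers and the device contexts omitted.

#include <algorithm>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

struct ToyHostDataSource {
  std::vector<std::string> data_nodes;                   // queue slots, one per distinct backend node
  std::unordered_map<std::string, size_t> position_map;  // front node -> slot index

  void Register(const std::string &front, const std::string &backend) {
    auto it = std::find(data_nodes.begin(), data_nodes.end(), backend);
    if (it != data_nodes.end()) {
      // Backend node already queued: reuse its slot.
      position_map.emplace(front, static_cast<size_t>(it - data_nodes.begin()));
    } else {
      // New backend node: append a fresh slot.
      position_map.emplace(front, data_nodes.size());
      data_nodes.emplace_back(backend);
    }
  }
};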
@@ -773,6 +839,55 @@ std::vector<KernelActorPtr> GraphScheduler::BuildNoInputKernelActor(const ActorS
return no_input_kernel_actors;
}
std::vector<SwitchActorPtr> GraphScheduler::BuildSwitchActor(const GraphCompilerInfo &graph_compiler_info) {
std::vector<SwitchActorPtr> switch_actors;
for (const auto &control_node : graph_compiler_info.control_nodes_) {
if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch) ||
AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
auto actor_name = control_node->fullname_with_scope();
auto switch_actor = std::make_shared<SwitchActor>(actor_name, control_node->cast<CNodePtr>());
switch_actor->Initialize();
InsertActor(switch_actor.get());
switch_actors.emplace_back(switch_actor);
}
}
return switch_actors;
}
std::vector<GatherActorPtr> GraphScheduler::BuildGatherActor(const GraphCompilerInfo &graph_compiler_info) {
std::vector<GatherActorPtr> gather_actors;
bool is_main_return = true;
// Each funcgraph has a return node, get the funcgraph from the return node, and create a gather actor.
for (const auto &control_node : graph_compiler_info.control_nodes_) {
// Root funcgraph does not need to create a gather actor.
if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
if (is_main_return) {
is_main_return = false;
continue;
}
auto func_graph = control_node->func_graph();
auto actor_name = func_graph->ToString();
std::vector<AnfNodePtr> parameters;
std::copy_if(
func_graph->parameters().begin(), func_graph->parameters().end(), std::back_inserter(parameters),
[](const AnfNodePtr &parameter) { return !AnfAlgo::IsParameterWeight(parameter->cast<ParameterPtr>()); });
auto loop_count_actor_name = graph_compiler_info.name_ + "_LoopCountActor";
auto actor = FetchActor(loop_count_actor_name);
if (actor == nullptr) {
MS_LOG(EXCEPTION) << "Cannot find loop count actor by name:" << loop_count_actor_name;
}
auto gather_actor = std::make_shared<GatherActor>(actor_name, parameters, actor->GetAID());
InsertActor(gather_actor.get());
gather_actors.emplace_back(gather_actor);
}
}
return gather_actors;
}
void GraphScheduler::LinkDataArrow(KernelActor *to_actor, const ActorSet *actor_set, const KernelGraphPtr &graph,
KernelWithIndex from_kernel_with_output_idx,
KernelWithIndex to_kernel_with_input_idx, const TensorPtr &tensor) {

View File

@@ -30,6 +30,7 @@
#include "runtime/framework/actor/kernel_actor.h"
#include "runtime/framework/actor/output_actor.h"
#include "runtime/framework/actor/switch_actor.h"
#include "runtime/framework/actor/gather_actor.h"
#include "runtime/framework/actor/copy_actor.h"
#include "runtime/hardware/device_context.h"
#include "backend/session/kernel_graph.h"
@@ -38,6 +39,7 @@
namespace mindspore {
namespace runtime {
using mindspore::device::DeviceContext;
using mindspore::session::KernelGraph;
using mindspore::session::KernelWithIndex;
using KernelMapPosition = std::map<KernelWithIndex, size_t, session::KernelWithIndexCmp>;
using ActorInfo = std::string;
@@ -62,13 +64,15 @@ enum class GraphExecutionStrategy {
// The control node is used to link graphs in the control flow scenario.
// The origin parameters order is used to correspond to the input args.
// The origin outputs order is used to correspond to the output args.
// The front-to-backend parameters map is used to build and link the host data source actor in the control flow scenario.
// The front output nodes are used to link the output actor in the multi-branch output scenario.
struct GraphCompilerInfo {
GraphCompilerInfo(const std::vector<KernelGraphPtr> &graphs, const std::vector<DeviceContext *> &device_contexts,
const std::vector<std::vector<int64_t> *> &tensors_mask,
const std::vector<std::vector<TensorPtr> *> &input_tensors,
const std::vector<AnfNodePtr> &control_nodes,
const std::vector<AnfNodePtr> &origin_parameters_order,
const KernelMapPosition &origin_outputs_order, const std::string &name)
GraphCompilerInfo(
const std::vector<KernelGraphPtr> &graphs, const std::vector<DeviceContext *> &device_contexts,
const std::vector<std::vector<int64_t> *> &tensors_mask, const std::vector<std::vector<TensorPtr> *> &input_tensors,
const std::vector<AnfNodePtr> &control_nodes, const std::vector<AnfNodePtr> &origin_parameters_order,
const KernelMapPosition &origin_outputs_order, const FrontToBackendNodeWithContext &front_to_backend_parameters,
const std::vector<AnfNodePtr> &front_output_nodes, const size_t outputs_num, const std::string &name)
: graphs_(graphs),
device_contexts_(device_contexts),
tensors_mask_(tensors_mask),
@@ -76,6 +80,9 @@ struct GraphCompilerInfo {
control_nodes_(control_nodes),
origin_parameters_order_(origin_parameters_order),
origin_outputs_order_(origin_outputs_order),
front_to_backend_parameters_(front_to_backend_parameters),
front_output_nodes_(front_output_nodes),
outputs_num_(outputs_num),
name_(name) {}
std::vector<KernelGraphPtr> graphs_;
std::vector<DeviceContext *> device_contexts_;
@@ -84,6 +91,9 @@ struct GraphCompilerInfo {
std::vector<AnfNodePtr> control_nodes_;
std::vector<AnfNodePtr> origin_parameters_order_;
KernelMapPosition origin_outputs_order_;
FrontToBackendNodeWithContext front_to_backend_parameters_;
std::vector<AnfNodePtr> front_output_nodes_;
size_t outputs_num_;
std::string name_;
};
@@ -92,6 +102,9 @@ struct GraphCompilerInfo {
// The data source actor is used to obtain data and process them into device tensors, and send them to kernel actor.
// The kernel actor is used to receive the device tensors to launch the kernel. Note specifically the no-input
// kernel actor: it has no input device tensor and needs to be triggered externally.
// The switch actor is used to run different branches in the control flow scenario.
// The gather actor is used to collect the inputs of the graph and send the branch id to the loop count actor in
// the multi-branch output scenario.
// The copy actor is used to convert device tensors between different device kernels.
// The loop count actor is used to receive control from the tail kernel actor, marking the end of one step,
// and to decide whether to loop execution according to the loop count.
@@ -103,6 +116,7 @@ struct ActorSet {
// No input kernel actors need be triggered specifically.
std::vector<KernelActorPtr> no_input_kernel_actors_;
std::vector<SwitchActorPtr> switch_actors_;
std::vector<GatherActorPtr> gather_actors_;
std::vector<CopyActorPtr> copy_actors_;
LoopCountActorPtr loop_count_actor_{nullptr};
OutputActorPtr output_actor_{nullptr};
@@ -160,6 +174,8 @@ class GraphScheduler {
LoopCountActorPtr BuildLoopCountActor(const GraphCompilerInfo &graph_compiler_info, GraphExecutionStrategy strategy);
OutputActorPtr BuildOutputActor(const GraphCompilerInfo &graph_compiler_info, GraphExecutionStrategy strategy);
std::vector<KernelActorPtr> BuildNoInputKernelActor(const ActorSet *actor_set);
std::vector<SwitchActorPtr> BuildSwitchActor(const GraphCompilerInfo &graph_compiler_info);
std::vector<GatherActorPtr> BuildGatherActor(const GraphCompilerInfo &graph_compiler_info);
// Cache the information of graph output node to actor between “build” and “link”, for linking between the tail of
// previous graph and the head of next graph.

View File

@@ -151,6 +151,18 @@ void PushInputTensor(const BaseRef &arg, std::vector<tensor::TensorPtr> *inputs)
MS_LOG(WARNING) << "Invalid input type.";
}
}
// Insert the tensor related to front_node into input_tensor.
void PushTensor(const VectorRef &args, const std::vector<AnfNodePtr> &parameters, const AnfNodePtr &front_node,
std::vector<tensor::TensorPtr> *input_tensor) {
const auto &iter = std::find(parameters.begin(), parameters.end(), front_node);
if (iter == parameters.end()) {
(*input_tensor).emplace_back(nullptr);
return;
}
auto position = iter - parameters.begin();
PushInputTensor(args[position], input_tensor);
}
} // namespace
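
PushTensor above maps a front node to its position in the origin parameter list so the matching runtime arg can be pushed, filling non-parameter slots with a null placeholder. A simplified rendering of that lookup with ints in place of tensors; the real version forwards the arg to PushInputTensor, and the name below is hypothetical.

#include <algorithm>
#include <cstddef>
#include <string>
#include <vector>

// -1 plays the role of the nullptr placeholder in the real code.
void PushToyTensor(const std::vector<int> &args, const std::vector<std::string> &parameters,
                   const std::string &front_node, std::vector<int> *input_tensor) {
  auto iter = std::find(parameters.begin(), parameters.end(), front_node);
  if (iter == parameters.end()) {
    input_tensor->push_back(-1);  // not an origin parameter: push a placeholder
    return;
  }
  auto position = static_cast<size_t>(iter - parameters.begin());
  input_tensor->push_back(args[position]);  // fetch the arg at the parameter's position
}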
VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) {
@@ -397,22 +409,28 @@ VectorRef MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &
const auto &origin_parameters = graph_compiler_info.origin_parameters_order_;
// Transform args to input tensors.
// Input tensors of the graph.
std::vector<std::vector<tensor::TensorPtr>> input_tensors;
for (const auto &kernel_graph : graph_compiler_info.graphs_) {
std::vector<tensor::TensorPtr> input_tensor;
for (const auto &input_node : kernel_graph->input_nodes()) {
const auto &front_node = kernel_graph->GetFrontAnfByBackendAnf(input_node);
const auto &iter = std::find(origin_parameters.begin(), origin_parameters.end(), front_node);
if (iter == origin_parameters.end()) {
input_tensor.emplace_back(nullptr);
continue;
}
auto position = IntToSize(std::distance(origin_parameters.begin(), iter));
PushInputTensor(args[position], &input_tensor);
PushTensor(args, origin_parameters, front_node, &input_tensor);
}
input_tensors.emplace_back(input_tensor);
}
// Input tensors of the control node.
std::vector<tensor::TensorPtr> input_tensor;
// Get inputs of control node which come from the host actor.
std::vector<AnfNodePtr> control_node_parameters =
ControlNodeParser::FetchControlNodeParameter(graph_compiler_info.control_nodes_);
for (const auto &parameter : control_node_parameters) {
PushTensor(args, origin_parameters, parameter, &input_tensor);
}
input_tensors.emplace_back(input_tensor);
// Run in the pynative mode.
VectorRef outputs;
auto ms_context = MsContext::GetInstance();
@@ -465,19 +483,33 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(con
name.append("_").append(std::to_string(graph_id_to_context.first));
}
// Get all the outputs. In control flow, there may be multiple branch outputs.
runtime::KernelMapPosition outputs_order;
size_t position = 0;
const auto &outputs = AnfAlgo::GetAllOutput(root_graph->output(), {prim::kPrimTupleGetItem});
for (const auto &output : outputs) {
const auto &output_with_index = AnfAlgo::VisitKernelWithReturnType(output, 0, true);
MS_EXCEPTION_IF_NULL(output_with_index.first);
outputs_order.emplace(output_with_index, position++);
size_t outputs_num = 0;
const auto &all_branch_output = ControlNodeParser::FetchAllBranchOutputs(root_graph);
for (const auto &branch_output : all_branch_output) {
size_t position = 0;
const auto &outputs = AnfAlgo::GetAllOutput(branch_output, {prim::kPrimTupleGetItem});
outputs_num = outputs.size();
for (const auto &output : outputs) {
const auto &output_with_index = AnfAlgo::VisitKernelWithReturnType(output, 0, false);
MS_EXCEPTION_IF_NULL(output_with_index.first);
outputs_order.emplace(output_with_index, position++);
}
}
// Fetch all the relationships between front parameters and backend parameters which will be used to
// build and link actors.
FrontToBackendNodeWithContext front_to_backend_parameter;
ControlNodeParser::FetchFrontToBackendParameterMap(graphs, device_contexts, control_nodes_,
&front_to_backend_parameter);
std::vector<std::vector<int64_t> *> tensors_mask;
std::vector<std::vector<tensor::TensorPtr> *> input_tensors;
return std::make_unique<GraphCompilerInfo>(graphs, device_contexts, tensors_mask, input_tensors, control_nodes_,
root_graph->parameters(), outputs_order, name);
root_graph->parameters(), outputs_order, front_to_backend_parameter,
all_branch_output, outputs_num, name);
}
std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(
@@ -504,9 +536,10 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(
std::vector<std::vector<int64_t> *> tensors_mask_list(1, const_cast<std::vector<int64_t> *>(tensors_mask));
std::vector<std::vector<TensorPtr> *> input_tensors_list(1,
const_cast<std::vector<tensor::TensorPtr> *>(input_tensors));
return std::make_unique<GraphCompilerInfo>(graphs, device_contexts, tensors_mask_list, input_tensors_list,
std::vector<AnfNodePtr>(), std::vector<AnfNodePtr>(), outputs_order, name);
std::vector<AnfNodePtr>(), std::vector<AnfNodePtr>(), outputs_order,
FrontToBackendNodeWithContext(), std::vector<AnfNodePtr>(),
outputs_order.size(), name);
}
VectorRef MindRTBackend::RunGraph(const ActorInfo &actor_info, const std::vector<int64_t> *tensors_mask,

View File

@@ -38,7 +38,8 @@ using OpRunInfo = session::OpRunInfo;
using DeviceContext = device::DeviceContext;
using ActorInfo = runtime::ActorInfo;
using GraphCompilerInfo = runtime::GraphCompilerInfo;
using ControlNodeParser = runtime::ControlNodeParser;
using FrontToBackendNodeWithContext = runtime::FrontToBackendNodeWithContext;
enum SwitchCondStatus {
kCondOk = 0,
kCondAlreadyRun,