forked from mindspore-Ecosystem/mindspore
!17095 Add build control flow actor.
From: @gaoyong10 Reviewed-by: Signed-off-by:
commit 303d4857e8
@@ -25,15 +25,13 @@
#include <algorithm>
#include "runtime/framework/device_tensor_store.h"
#include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/control_node_parser.h"
#include "runtime/hardware/device_context.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/tensor.h"

namespace mindspore {
namespace runtime {
using mindspore::device::DeviceContext;

using FrontToBackendNodeWithContext = std::unordered_map<AnfNodePtr, std::pair<AnfNodePtr, DeviceContext *>>;

// Gather actor is the entrance of a sub funcgraph. Graph inputs are sent to it and forwarded to the other actors
// by the gather actor.
class GatherActor : public OpActor<DeviceTensor> {
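To make the gather-actor role above concrete, here is a minimal, self-contained sketch of the gather-and-forward behavior in plain standard C++; ToyGatherActor and its members are illustrative stand-ins, not MindSpore's OpActor API.

// Minimal, self-contained sketch of gather-and-forward; ToyGatherActor and
// its members are illustrative stand-ins, not MindSpore's OpActor API.
#include <cstddef>
#include <functional>
#include <iostream>
#include <vector>

struct ToyGatherActor {
  std::size_t expected_inputs;                               // formal parameter count of the sub funcgraph
  std::vector<int> gathered;                                 // graph inputs collected so far
  std::function<void(const std::vector<int> &)> downstream;  // actors behind the funcgraph entrance

  // Buffer each arriving graph input; once all formal parameters are present,
  // forward the complete set downstream in one step.
  void RunInput(int value) {
    gathered.push_back(value);
    if (gathered.size() == expected_inputs) {
      downstream(gathered);
      gathered.clear();
    }
  }
};

int main() {
  ToyGatherActor actor{2, {}, [](const std::vector<int> &inputs) {
                         std::cout << "forwarding " << inputs.size() << " inputs\n";
                       }};
  actor.RunInput(1);
  actor.RunInput(2);  // second input completes the set and triggers forwarding
}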
@@ -22,18 +22,6 @@

namespace mindspore {
namespace runtime {
constexpr size_t kSwitchInputNum = 4;
constexpr size_t kSwitchCondPos = 1;
constexpr size_t kSwitchPartialNum = 2;
constexpr size_t kSwitchLayerCondPos = 1;
constexpr size_t kSwitchLayerBranchPos = 2;
constexpr size_t kSwitchLayerInputNum = 3;
constexpr size_t kMaxSwitchCondSize = 8;
constexpr size_t kSwitchTrueBranchPos = 2;
constexpr size_t kSwitchFalseBranchPos = 3;
constexpr size_t kPartialFuncGraphPos = 1;
constexpr size_t kPartialInputStartPos = 2;

void SwitchActor::Init() {
  // Init output data.
  output_data_.resize(output_branch_arrows_.size());
@@ -30,6 +30,21 @@
namespace mindspore {
namespace runtime {
using mindspore::device::DeviceContext;
using mindspore::session::KernelWithIndex;

constexpr size_t kSwitchInputNum = 4;
constexpr size_t kSwitchCondPos = 1;
constexpr size_t kSwitchPartialNum = 2;
constexpr size_t kSwitchLayerCondPos = 1;
constexpr size_t kSwitchLayerBranchPos = 2;
constexpr size_t kSwitchLayerInputNum = 3;
constexpr size_t kMaxSwitchCondSize = 8;
constexpr size_t kSwitchTrueBranchPos = 2;
constexpr size_t kSwitchFalseBranchPos = 3;
constexpr size_t kPartialFuncGraphPos = 1;
constexpr size_t kPartialInputStartPos = 2;
constexpr size_t kCallInputStartPos = 1;
constexpr size_t kMakeTupleInputStartPos = 1;

// Switch actor is used to execute the branch according to the input condition.
// Switch and SwitchLayer nodes will be converted to switch actors.
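The constants above encode the fixed input layouts of the control nodes: a Switch cnode is {Primitive, cond, true_branch, false_branch} and a SwitchLayer cnode is {Primitive, index, MakeTuple(branches)}. A hedged sketch of the branch selection that layout implies, in plain standard C++ with no MindSpore types; SelectBranch and the kToy* constants are hypothetical, not part of this patch:

// Hedged sketch of branch selection implied by the control-node input layout.
#include <cstddef>
#include <stdexcept>

constexpr std::size_t kToySwitchTrueBranchPos = 2;   // branch taken when cond is true
constexpr std::size_t kToySwitchFalseBranchPos = 3;  // branch taken when cond is false

// Returns the input position (Switch) or tuple index (SwitchLayer) of the
// branch that should run.
std::size_t SelectBranch(bool is_switch_layer, bool cond, std::size_t layer_index, std::size_t branch_num) {
  if (!is_switch_layer) {
    // Switch: {Primitive, cond, true_branch, false_branch}
    return cond ? kToySwitchTrueBranchPos : kToySwitchFalseBranchPos;
  }
  // SwitchLayer: {Primitive, index, MakeTuple(branch_0, ..., branch_n-1)}
  if (layer_index >= branch_num) {
    throw std::out_of_range("switch_layer index out of range");
  }
  return layer_index;
}

For example, SelectBranch(false, true, 0, 0) yields the true-branch input position, while SelectBranch(true, false, 2, 4) yields index 2 of the branch tuple.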
@@ -0,0 +1,275 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/framework/control_node_parser.h"
#include "runtime/framework/actor/switch_actor.h"

namespace mindspore {
namespace runtime {

bool ControlNodeParser::IsCallNode(const AnfNodePtr &node) {
  if (!node->isa<CNode>()) {
    return false;
  }
  const auto &cnode = node->cast<CNodePtr>();
  const auto &inputs = cnode->inputs();
  return inputs[0]->isa<CNode>() || (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0]));
}

FuncGraphPtr ControlNodeParser::GetFuncGraphFromPartial(const AnfNodePtr &node) {
  const auto &partial_inputs = node->cast<CNodePtr>()->inputs();
  return GetValueNode<FuncGraphPtr>(partial_inputs[1]);
}

std::vector<FuncGraphPtr> ControlNodeParser::FetchFuncGraphbyCallNode(const CNodePtr &node) {
  std::vector<FuncGraphPtr> func_graphs;
  const auto &call_inputs = node->inputs();

  if (call_inputs[0]->isa<CNode>()) {
    const auto &cnode = call_inputs[0]->cast<CNodePtr>();
    const auto &cnode_inputs = cnode->inputs();
    if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
      for (size_t i = kSwitchTrueBranchPos; i < cnode_inputs.size(); ++i) {
        if (IsPrimitiveCNode(cnode_inputs[i], prim::kPrimPartial)) {
          func_graphs.emplace_back(GetFuncGraphFromPartial(cnode_inputs[i]));
        }
      }
    } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer) &&
               AnfAlgo::CheckPrimitiveType(cnode_inputs[kSwitchLayerBranchPos], prim::kPrimMakeTuple)) {
      const auto &tuple_inputs = cnode_inputs[kSwitchLayerBranchPos]->cast<CNodePtr>()->inputs();

      for (size_t i = 1; i < tuple_inputs.size(); ++i) {
        if (AnfAlgo::CheckPrimitiveType(tuple_inputs[i], prim::kPrimPartial)) {
          func_graphs.emplace_back(GetFuncGraphFromPartial(tuple_inputs[i]));
        } else if (IsValueNode<FuncGraph>(tuple_inputs[i])) {
          func_graphs.emplace_back(GetValueNode<FuncGraphPtr>(tuple_inputs[i]));
        }
      }
    } else {
      MS_LOG(EXCEPTION) << "Unable to identify call node: " << node->DebugString();
    }
  } else if (call_inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(call_inputs[0])) {
    func_graphs.emplace_back(GetValueNode<FuncGraphPtr>(call_inputs[0]));
  } else {
    MS_LOG(EXCEPTION) << "Unable to identify call node: " << node->DebugString();
  }
  return func_graphs;
}

std::vector<AnfNodePtr> ControlNodeParser::FetchFuncGraphOutput(const FuncGraphPtr &func_graph,
                                                                std::vector<AnfNodePtr> *call_nodes) {
  std::vector<AnfNodePtr> outputs;
  const auto &output = func_graph->output();
  const auto &real_output = AnfAlgo::VisitKernelWithReturnType(output, 0);
  if (find((*call_nodes).begin(), (*call_nodes).end(), real_output.first) != (*call_nodes).end()) {
    return outputs;
  }
  if (!IsCallNode(real_output.first)) {
    outputs.push_back(real_output.first);
    return outputs;
  }

  (*call_nodes).push_back(real_output.first);
  const auto &call_cnode = real_output.first->cast<CNodePtr>();
  std::vector<FuncGraphPtr> func_graphs = FetchFuncGraphbyCallNode(call_cnode);
  for (const auto &graph : func_graphs) {
    auto single_outputs = FetchFuncGraphOutput(graph, call_nodes);
    outputs.insert(outputs.end(), single_outputs.begin(), single_outputs.end());
  }
  return outputs;
}

std::pair<AnfNodePtr, DeviceContext *> ControlNodeParser::FetchBackendNodeByFrontNode(
  const AnfNodePtr &front_node, const std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> &front_to_front_parameter,
  const std::unordered_map<AnfNodePtr, std::pair<AnfNodePtr, DeviceContext *>> &front_to_backend_parameter,
  std::set<AnfNodePtr> *invalid_node) {
  // Check whether the front_node has already been searched.
  if ((*invalid_node).find(front_node) != (*invalid_node).end()) {
    return std::pair<AnfNodePtr, DeviceContext *>();
  }
  (*invalid_node).insert(front_node);

  const auto front_to_backend_iter = front_to_backend_parameter.find(front_node);
  if (front_to_backend_iter != front_to_backend_parameter.end()) {
    return front_to_backend_iter->second;
  }

  const auto &front_to_front_iter = front_to_front_parameter.find(front_node);
  if (front_to_front_iter == front_to_front_parameter.end()) {
    return std::pair<AnfNodePtr, DeviceContext *>();
  }
  for (const auto &next_node : front_to_front_iter->second) {
    auto backend_node =
      FetchBackendNodeByFrontNode(next_node, front_to_front_parameter, front_to_backend_parameter, invalid_node);
    if (backend_node.first != nullptr) {
      return backend_node;
    }
  }
  return std::pair<AnfNodePtr, DeviceContext *>();
}

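The recursion above is easier to follow on a toy instance. A standalone rerun with strings in place of AnfNodePtr, assuming front parameter "a" only flows into "b" and only "b" has a backend node:

// Toy rerun of FetchBackendNodeByFrontNode's resolution, strings for nodes.
#include <iostream>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

std::string Resolve(const std::string &front,
                    const std::unordered_map<std::string, std::vector<std::string>> &front_to_front,
                    const std::unordered_map<std::string, std::string> &front_to_backend,
                    std::set<std::string> *visited) {
  if (!visited->insert(front).second) {
    return "";  // already searched: break cycles, mirroring invalid_node
  }
  auto direct = front_to_backend.find(front);
  if (direct != front_to_backend.end()) {
    return direct->second;  // the front node itself has a backend parameter
  }
  auto next = front_to_front.find(front);
  if (next == front_to_front.end()) {
    return "";
  }
  for (const auto &candidate : next->second) {
    auto backend = Resolve(candidate, front_to_front, front_to_backend, visited);
    if (!backend.empty()) {
      return backend;  // first successful chain wins, as in the real code
    }
  }
  return "";
}

int main() {
  std::set<std::string> visited;
  std::cout << Resolve("a", {{"a", {"b"}}}, {{"b", "backend_b"}}, &visited) << "\n";  // prints backend_b
}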
void ControlNodeParser::FetchFrontToFrontParameterMap(
  const std::vector<AnfNodePtr> &control_nodes,
  std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> *front_to_front_parameter) {
  // Function used to collect the inputs of a call node.
  const auto &call_input_parse = [front_to_front_parameter](const std::vector<AnfNodePtr> &parameters,
                                                            const std::vector<AnfNodePtr> &call_inputs,
                                                            const size_t call_input_start_pos) {
    for (size_t i = 0; i < call_inputs.size(); ++i) {
      if (call_inputs[i]->isa<Parameter>()) {
        (*front_to_front_parameter)[call_inputs[i]].push_back(parameters[i + call_input_start_pos]);
      }
    }
  };

  // Function used to collect the inputs of a partial node.
  const auto &partial_input_parse = [call_input_parse, front_to_front_parameter](
                                      const AnfNodePtr &partial_node, const std::vector<AnfNodePtr> &call_inputs) {
    const auto &cnode = partial_node->cast<CNodePtr>();
    const auto &inputs = cnode->inputs();
    const auto &func_graph = GetValueNode<FuncGraphPtr>(inputs[kPartialFuncGraphPos]);
    const auto &parameters = func_graph->parameters();
    for (size_t i = kPartialInputStartPos; i < inputs.size(); ++i) {
      if (inputs[i]->isa<Parameter>()) {
        (*front_to_front_parameter)[inputs[i]].push_back(parameters[i - kPartialInputStartPos]);
      }
    }
    call_input_parse(parameters, call_inputs, inputs.size() - kPartialInputStartPos);
  };

  // Function used to collect the inputs of a switch node.
  const auto &switch_input_parse = [&](const AnfNodePtr &switch_node, const std::vector<AnfNodePtr> &call_inputs) {
    CNodePtr cnode = switch_node->cast<CNodePtr>();
    const auto &switch_inputs = cnode->inputs();
    if (AnfAlgo::CheckPrimitiveType(switch_node, prim::kPrimSwitch)) {
      // Parse the switch node. The switch node has two partial node inputs.
      if (AnfAlgo::CheckPrimitiveType(switch_inputs[kSwitchTrueBranchPos], prim::kPrimPartial)) {
        partial_input_parse(switch_inputs[kSwitchTrueBranchPos], call_inputs);
        partial_input_parse(switch_inputs[kSwitchFalseBranchPos], call_inputs);
      }
    } else {
      // Parse the switchlayer node. The switchlayer node has a maketuple node input, which is a tuple of funcgraphs.
      // call_inputs will be the inputs of these funcgraphs.
      const auto &tuple_node = switch_inputs[kSwitchLayerBranchPos]->cast<CNodePtr>();
      const auto &tuple_inputs = tuple_node->inputs();
      for (const auto &input : tuple_inputs) {
        if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimPartial)) {
          partial_input_parse(input, call_inputs);
        } else {
          auto func_graph = GetValueNode<FuncGraphPtr>(input);
          call_input_parse(func_graph->parameters(), call_inputs, 0);
        }
      }
    }
  };

  for (const auto &node : control_nodes) {
    CNodePtr cnode = node->cast<CNodePtr>();
    const auto &inputs = cnode->inputs();
    if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) {
      // Call node whose first input is a funcgraph value node.
      const auto &func_graph = GetValueNode<FuncGraphPtr>(inputs[0]);
      const auto &parameters = func_graph->parameters();
      for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
        if (inputs[i]->isa<Parameter>()) {
          (*front_to_front_parameter)[inputs[i]].push_back(parameters[i - kCallInputStartPos]);
        }
      }
    } else if (inputs[0]->isa<CNode>()) {
      // Call node whose first input is a switch or switchlayer node.
      if ((!AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitch)) &&
          (!AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer))) {
        MS_LOG(EXCEPTION) << "First input node of call node is not switch, node:"
                          << AnfAlgo::GetNodeDebugString(inputs[0]);
      }
      std::vector<AnfNodePtr> call_inputs;
      call_inputs.assign(inputs.begin() + kCallInputStartPos, inputs.end());
      switch_input_parse(inputs[0], call_inputs);
    }
  }
}

std::vector<AnfNodePtr> ControlNodeParser::FetchControlNodeParameter(const std::vector<AnfNodePtr> &control_nodes) {
  std::vector<AnfNodePtr> parameters;

  for (const auto &control_node : control_nodes) {
    CNodePtr cnode = control_node->cast<CNodePtr>();
    const auto &inputs = cnode->inputs();
    if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
      break;
    } else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimPartial)) {
      for (size_t i = kPartialInputStartPos; i < inputs.size(); ++i) {
        if (inputs[i]->isa<Parameter>()) {
          parameters.emplace_back(inputs[i]);
        }
      }
    } else if (cnode->input(0)->isa<CNode>() || IsValueNode<FuncGraph>(cnode->input(0))) {
      for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
        if (inputs[i]->isa<Parameter>()) {
          parameters.emplace_back(inputs[i]);
        }
      }
    }
  }

  return parameters;
}

std::vector<AnfNodePtr> ControlNodeParser::FetchAllBranchOutputs(const FuncGraphPtr &func_graph) {
  std::vector<AnfNodePtr> call_nodes;
  return FetchFuncGraphOutput(func_graph, &call_nodes);
}

void ControlNodeParser::FetchFrontToBackendParameterMap(const std::vector<KernelGraphPtr> &graphs,
                                                        const std::vector<DeviceContext *> &device_contexts,
                                                        const std::vector<AnfNodePtr> &control_nodes,
                                                        FrontToBackendNodeWithContext *front_to_backend_parameter) {
  if (graphs.size() != device_contexts.size()) {
    MS_LOG(EXCEPTION) << "Graph num is not equal to device context num.";
  }

  // Fetch the mapping relationship between front parameters and backend parameters in the kernel graphs.
  for (size_t i = 0; i < graphs.size(); ++i) {
    const auto &graph = graphs[i];
    auto device_context = device_contexts[i];
    for (const auto &parameter : graph->parameters()) {
      auto front_node = graph->GetFrontAnfByBackendAnf(parameter);
      if (front_node != nullptr && front_node->isa<Parameter>() &&
          (*front_to_backend_parameter).find(front_node) == (*front_to_backend_parameter).end()) {
        (*front_to_backend_parameter)[front_node] = {parameter, device_context};
      }
    }
  }

  // Fetch the mapping relationship between front parameters and backend parameters in the control nodes. First
  // fetch the mapping relationship of the front parameters. When the input of the call node or the partial node
  // is a parameter node, it means that the parameter is directly transmitted. If a parameter does not have a
  // corresponding backend node, then recursively find whether the front parameter corresponding to the parameter
  // has one.
  std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> front_to_front_parameter;
  FetchFrontToFrontParameterMap(control_nodes, &front_to_front_parameter);

  for (const auto &front_pair : front_to_front_parameter) {
    std::set<AnfNodePtr> invalid_node;
    const auto &backend_node = FetchBackendNodeByFrontNode(front_pair.first, front_to_front_parameter,
                                                           *front_to_backend_parameter, &invalid_node);
    if (backend_node.first != nullptr) {
      (*front_to_backend_parameter)[front_pair.first] = backend_node;
    }
  }
}
} // namespace runtime
} // namespace mindspore
@@ -20,7 +20,7 @@
#include <vector>
#include <string>
#include <memory>
#include <tuple>
#include <set>
#include <utility>
#include <unordered_map>
#include <algorithm>
@@ -32,100 +32,62 @@ namespace runtime {
using mindspore::device::DeviceContext;
using mindspore::session::KernelWithIndex;

// The meaning of the switch node output tuple: 1. switch node 2. output branch id 3. output index.
using SwitchNodeOutput = std::tuple<AnfNodePtr, size_t, size_t>;
// The output arrow info: 1. from node index 2. to node 3. to node index.
using NodeOutputInfo = std::tuple<size_t, AnfNodePtr, size_t>;
// External inputs of a kernel graph: the key is the front node of an input,
// and the value vector holds pairs of from node and to node.
using KernelGraphExternInput = std::unordered_map<AnfNodePtr, std::vector<std::pair<KernelWithIndex, KernelWithIndex>>>;
using FrontToBackendNodeWithContext = std::unordered_map<AnfNodePtr, std::pair<AnfNodePtr, DeviceContext *>>;

struct PairHash {
  template <class T1, class T2>
  std::size_t operator()(const std::pair<T1, T2> &p) const {
    auto h1 = std::hash<T1>{}(p.first);
    auto h2 = std::hash<T2>{}(p.second);
    return h1 ^ h2;
  }
};

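A note on PairHash: XOR is a valid but symmetric combiner, so hash(a, b) == hash(b, a) and pairs of equal values hash to zero; the unordered_map stays correct either way, only collision rates differ. A usage sketch plus a boost-style mixing alternative (CombinedHash is a hypothetical alternative, not part of the patch):

// Usage sketch for a pair-keyed unordered_map; CombinedHash is hypothetical.
#include <cstddef>
#include <functional>
#include <unordered_map>
#include <utility>

struct CombinedHash {
  template <class T1, class T2>
  std::size_t operator()(const std::pair<T1, T2> &p) const {
    std::size_t seed = std::hash<T1>{}(p.first);
    // Boost-style mixing: asymmetric, so (a, b) and (b, a) rarely collide.
    seed ^= std::hash<T2>{}(p.second) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    return seed;
  }
};

// e.g. std::unordered_map<std::pair<int, std::size_t>, int, CombinedHash> outputs;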
// Get all possible outputs of a funcgraph. Search recursively through the input of the funcgraph's return node.
// If the input is a call node, enter all the funcgraphs it calls, until the inputs of non-call nodes are found,
// and return all of those output nodes.
std::vector<AnfNodePtr> GetAllBranchOutputs(const FuncGraphPtr &func_graph);

// ControlNodeParser is used to parse control nodes and get the edges between nodes. A call node is used to
// implement the call relationship between funcgraphs: the actual parameters are connected to the call node,
// the call node calls the corresponding funcgraph, and the actual parameters are then sent to the next nodes
// according to the relationship between the actual parameters and the formal parameters.
// From the function of the call node, the structure of an edge can be split into two parts:
// the relationship between output nodes and formal parameters, and the relationship between formal parameters
// and input nodes. They are then connected to become the final edge.
// Therefore, the analysis is mainly divided into 2 steps:
// 1. Get all input and output relationships with formal parameters;
// 2. Connect all inputs and outputs into edges.
// ControlNodeParser is a series of tool functions used to parse control nodes.
class ControlNodeParser {
 public:
  ControlNodeParser() = default;
  ~ControlNodeParser() = default;
  ControlNodeParser(const ControlNodeParser &) = delete;
  ControlNodeParser &operator=(const ControlNodeParser &) = delete;
  // Fetch all the relationships between front parameters and backend parameters. The front parameters
  // include two parts:
  // 1. The parameters from the kernel graphs.
  // 2. The parameters from the control nodes.
  static void FetchFrontToBackendParameterMap(const std::vector<KernelGraphPtr> &graphs,
                                              const std::vector<DeviceContext *> &device_contexts,
                                              const std::vector<AnfNodePtr> &control_nodes,
                                              FrontToBackendNodeWithContext *front_to_backend_parameter);

  // Analyze the relationship between switch and kernel nodes.
  // Parameter kernel_graph_input_ indicates that in a multi-graph case, the parameters of a subgraph
  // should be passed in when it is called by the main graph, rather than sent directly by the input, so they
  // need to be connected when parsing the control nodes.
  // The result of parsing is the edges between nodes, which are stored in member variables.
  void Parse(const std::vector<AnfNodePtr> &control_nodes, const KernelGraphExternInput &kernel_graph_input_);
  // Get the inputs of control nodes which come from the host actor. These inputs generally come from the partial
  // nodes and call nodes of the root funcgraph.
  static std::vector<AnfNodePtr> FetchControlNodeParameter(const std::vector<AnfNodePtr> &control_nodes);

  // Get the outputs of a funcgraph. Usually there is only one output node; in control flow there are
  // multiple branch outputs, so there will be multiple output nodes.
  static std::vector<AnfNodePtr> FetchAllBranchOutputs(const FuncGraphPtr &func_graph);

 private:
  friend class GraphScheduler;
  // Check whether a node is a call node. There are two types of call nodes:
  // 1. The first input of the node is a cnode.
  // 2. The first input of the node is a funcgraph value node.
  static bool IsCallNode(const AnfNodePtr &node);

  void ParseCall(const AnfNodePtr &node);
  void ParseSwitch(const AnfNodePtr &node, const std::vector<AnfNodePtr> &inputs_on_call);
  void ParseSwitchLayer(const AnfNodePtr &node, const std::vector<AnfNodePtr> &inputs_on_call);
  void ParsePartial(const AnfNodePtr &node, const std::vector<AnfNodePtr> &switch_inputs, const size_t branch_id,
                    const std::vector<AnfNodePtr> &inputs_on_call);
  void ParseInput(const AnfNodePtr &from_node, const AnfNodePtr &to_node, size_t to_index);
  // Parse an input which is a call node. This means that we need to find the output of the funcgraph called by
  // the call node as the input of to_node.
  void ParseCallInput(const CNodePtr &from_node, const AnfNodePtr &to_node, size_t to_index);
  // Get the funcgraph in a partial node.
  static FuncGraphPtr GetFuncGraphFromPartial(const AnfNodePtr &node);

  // Get all inputs of a switch node. inputs_on_call are the inputs of the call node to which the switch
  // node is connected.
  std::vector<AnfNodePtr> GetSwitchInput(const AnfNodePtr &node, const std::vector<AnfNodePtr> &inputs_on_call);
  // Find all funcgraphs that the call node may call.
  static std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const CNodePtr &node);

  // Connect the inputs and outputs of the call nodes to get the final edges.
  void LinkInputAndOutput();
  // Link each formal parameter to its final actual parameter in the member variable parameter_to_arguments_.
  // For example, if we have a map like {{a, b}, {b, c}, {c, d}}, finally we will get {{a, d}, {b, d}, {c, d}}
  // (see the standalone sketch after this excerpt).
  void LinkParameterAndArgument();
  // Recursively find all inputs corresponding to a node.
  void GetOutputNode(const AnfNodePtr &node, std::vector<KernelWithIndex> *inputs);
  // Find the output of a funcgraph; if the output is a call node, return the output of the funcgraph
  // called by that call node.
  static std::vector<AnfNodePtr> FetchFuncGraphOutput(const FuncGraphPtr &func_graph,
                                                      std::vector<AnfNodePtr> *call_nodes);

  // Relationship between formal parameters and actual parameters.
  std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> actual_to_formal_parameters_;
  std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> formal_to_actual_parameters_;
  // Find the corresponding backend parameter for a front_node. If the front_node does not have a corresponding
  // backend parameter, then recursively find the backend parameters of the other front parameters corresponding
  // to the front_node.
  static std::pair<AnfNodePtr, DeviceContext *> FetchBackendNodeByFrontNode(
    const AnfNodePtr &front_node,
    const std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> &front_to_front_parameter,
    const std::unordered_map<AnfNodePtr, std::pair<AnfNodePtr, DeviceContext *>> &front_to_backend_parameter,
    std::set<AnfNodePtr> *invalid_node);

  // In control nodes, an edge is a structure like kernel output --> formal parameter --> kernel input.
  // In the parsing process, the edge is divided into two parts, input and output:
  // input represents the relationship between formal parameters and kernel inputs,
  // output represents the relationship between kernel outputs and formal parameters.
  // In order to merge input and output into edges, both of them are stored in maps keyed by the parameter.
  // All inputs.
  std::unordered_map<AnfNodePtr, std::vector<KernelWithIndex>> parameter_to_input_;
  // Three kinds of output.
  // Output of switch nodes.
  std::unordered_map<AnfNodePtr, std::vector<SwitchNodeOutput>> parameter_to_switch_out_;
  // Output of kernel nodes.
  std::unordered_map<AnfNodePtr, std::vector<KernelWithIndex>> parameter_to_kernel_out_;
  // Parameters in the root funcgraph.
  std::vector<AnfNodePtr> parameters_;

  // Final edges.
  std::unordered_map<AnfNodePtr, std::vector<NodeOutputInfo>> kernel_outputs_;
  std::unordered_map<std::pair<AnfNodePtr, size_t>, std::vector<NodeOutputInfo>, PairHash> switch_outputs_;
  std::unordered_map<AnfNodePtr, std::vector<KernelWithIndex>> parameter_out_;
  // The relationship between front parameters indicates that a parameter is directly used as the input of a
  // funcgraph. There are two situations:
  // 1. The parameter is used as an input of the call node.
  // 2. The parameter is used as an input of a partial node and will be an input to the funcgraph of the partial
  //    in the subsequent call node.
  static void FetchFrontToFrontParameterMap(
    const std::vector<AnfNodePtr> &control_nodes,
    std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> *front_to_front_parameter);
};
} // namespace runtime
} // namespace mindspore
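The {{a, b}, {b, c}, {c, d}} example in LinkParameterAndArgument's comment can be reproduced standalone. A sketch of the chain resolution in plain standard C++, assuming the chains are acyclic as in the example:

// Follow each parameter chain until a node that is not itself a key:
// {{a,b},{b,c},{c,d}} resolves to {{a,d},{b,d},{c,d}}.
#include <iostream>
#include <string>
#include <unordered_map>

std::unordered_map<std::string, std::string> ResolveChains(
    const std::unordered_map<std::string, std::string> &direct) {
  std::unordered_map<std::string, std::string> final_map;
  for (const auto &entry : direct) {
    std::string target = entry.second;
    while (direct.count(target) != 0) {  // hop until the final actual parameter
      target = direct.at(target);
    }
    final_map[entry.first] = target;
  }
  return final_map;
}

int main() {
  for (const auto &entry : ResolveChains({{"a", "b"}, {"b", "c"}, {"c", "d"}})) {
    std::cout << entry.first << " -> " << entry.second << "\n";  // all map to d
  }
}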
@@ -301,6 +301,20 @@ TensorPtr FetchInputTensor(const GraphCompilerInfo &graph_compiler_info, size_t
  }
  return nullptr;
}

void PrepareDataForHostDataSourceActor(const std::unordered_map<AnfNodePtr, size_t> &data_node_position_map,
                                       const AnfNodePtr &node, const TensorPtr &tensor,
                                       std::vector<TensorPtr> *host_tensors) {
  if (std::dynamic_pointer_cast<DeviceTensor>(tensor->device_address()) != nullptr) {
    return;
  }

  // Fill the host tensors for non weighted parameters.
  const auto &iter = data_node_position_map.find(node);
  if (iter != data_node_position_map.end()) {
    (*host_tensors)[iter->second] = tensor;
  }
}
} // namespace

GraphScheduler::~GraphScheduler() {
@@ -390,6 +404,10 @@ void GraphScheduler::Schedule(const ActorSet *actor_set) {
    MS_EXCEPTION_IF_NULL(switch_actor);
    actors.emplace_back(static_cast<ActorReference>(switch_actor));
  }
  for (auto &gather_actor : actor_set->gather_actors_) {
    MS_EXCEPTION_IF_NULL(gather_actor);
    actors.emplace_back(static_cast<ActorReference>(gather_actor));
  }
  for (auto &copy_actor : actor_set->copy_actors_) {
    MS_EXCEPTION_IF_NULL(copy_actor);
    actors.emplace_back(static_cast<ActorReference>(copy_actor));
@@ -444,11 +462,8 @@ void GraphScheduler::PrepareRun(const ActorSet *actor_set, const GraphCompilerIn
      PrepareDataForWeightNode(input_node, input_tensor, device_context, graph);
    } else if (IsHostQueueDSActor(input_node, graph, input_tensor)) {
      MS_EXCEPTION_IF_NULL(host_data_source_actor);
-      // Fill the host tensors for non weighted parameters.
-      const auto &iter = host_data_source_actor->data_node_position_map_.find(input_node);
-      if (iter != host_data_source_actor->data_node_position_map_.end()) {
-        host_tensors[iter->second] = input_tensor;
-      }
+      PrepareDataForHostDataSourceActor(host_data_source_actor->data_node_position_map_, input_node, input_tensor,
+                                        &host_tensors);
    }
  }

@@ -461,7 +476,24 @@ void GraphScheduler::PrepareRun(const ActorSet *actor_set, const GraphCompilerIn
    }
  }

-  // 3.Prepare the data of host tensor queue(non weighted parameters of graph).
+  // 3.Fill host tensors for non weighted parameters which belong to control nodes.
+  std::vector<AnfNodePtr> control_node_parameters =
+    ControlNodeParser::FetchControlNodeParameter(graph_compiler_info.control_nodes_);
+  const auto &tensors = input_tensors.back();
+  const auto &parameters = graph_compiler_info.origin_parameters_order_;
+  for (size_t j = 0; j < control_node_parameters.size(); ++j) {
+    const auto &input_node = control_node_parameters[j];
+    const auto &input_tensor = tensors[j];
+    MS_EXCEPTION_IF_NULL(input_node);
+    if (IsPersistentDeviceTensor(input_node)) {
+    } else if (find(parameters.begin(), parameters.end(), input_node) != parameters.end()) {
+      MS_EXCEPTION_IF_NULL(host_data_source_actor);
+      PrepareDataForHostDataSourceActor(host_data_source_actor->data_node_position_map_, input_node, input_tensor,
+                                        &host_tensors);
+    }
+  }
+
+  // 4.Prepare the data of host tensor queue(non weighted parameters of graph).
  if (host_data_source_actor != nullptr) {
    const auto &host_tensor_queue = FetchHostQueue(actor_set->name_);
    MS_EXCEPTION_IF_NULL(host_tensor_queue);
@@ -531,6 +563,8 @@ ActorSetPtr GraphScheduler::Build(const GraphCompilerInfo &graph_compiler_info,
  actor_set->kernel_actors_ = BuildKernelActor(graph_compiler_info);
  actor_set->loop_count_actor_ = BuildLoopCountActor(graph_compiler_info, strategy);
  actor_set->output_actor_ = BuildOutputActor(graph_compiler_info, strategy);
  actor_set->switch_actors_ = BuildSwitchActor(graph_compiler_info);
  actor_set->gather_actors_ = BuildGatherActor(graph_compiler_info);

  return actor_set;
}
@@ -705,6 +739,38 @@ std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const Graph
      device_queue_ds_actor->kernel_info_ = static_cast<device::KernelInfo *>((*iter)->kernel_info());
    }
  }

  const auto &front_to_backend_parameter = graph_compiler_info.front_to_backend_parameters_;

  // Initialize the parameters in the control nodes: first get all the front parameters in the control nodes, then
  // find the corresponding backend parameter from the map, and insert it into the host data source actor.
  std::vector<AnfNodePtr> control_node_parameters =
    ControlNodeParser::FetchControlNodeParameter(graph_compiler_info.control_nodes_);
  for (const auto parameter : control_node_parameters) {
    auto backend_iter = front_to_backend_parameter.find(parameter);
    if (backend_iter == front_to_backend_parameter.end()) {
      MS_LOG(EXCEPTION) << "Cannot find backend node for front node:" << AnfAlgo::GetNodeDebugString(parameter);
    }

    if (host_queue_ds_actor == nullptr) {
      auto actor_name = graph_compiler_info.name_ + "_HostDSActor";
      MS_LOG(INFO) << "Create host queue data source actor: " << actor_name;
      host_queue_ds_actor = std::make_shared<HostQueueDataSourceActor>(actor_name, 1, memory_manager_aid_, host_queue);
      InsertActor(host_queue_ds_actor.get());
      data_source_actors.emplace_back(host_queue_ds_actor);
    }

    const auto &backend_node = backend_iter->second.first;
    auto iter = find(host_queue_ds_actor->data_nodes_.begin(), host_queue_ds_actor->data_nodes_.end(), backend_node);
    if (iter != host_queue_ds_actor->data_nodes_.end()) {
      host_queue_ds_actor->data_node_position_map_.emplace(parameter, iter - host_queue_ds_actor->data_nodes_.begin());
    } else {
      host_queue_ds_actor->data_node_position_map_.emplace(parameter, host_queue_ds_actor->data_nodes_.size());
      host_queue_ds_actor->data_nodes_.emplace_back(backend_iter->second.first);
      host_queue_ds_actor->device_contexts_.emplace_back(backend_iter->second.second);
    }
  }
  return data_source_actors;
}

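The slot bookkeeping above (reuse the position of an already-registered backend node, otherwise append a new slot) can be shown in isolation; all names below are illustrative stand-ins, not the actor's real members:

// Sketch of the data-node position bookkeeping: two front parameters that
// resolve to the same backend node share one slot.
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

struct PositionBook {
  std::vector<std::string> data_nodes;                     // backend data nodes, in slot order
  std::unordered_map<std::string, std::size_t> positions;  // front parameter -> slot

  void Register(const std::string &front, const std::string &backend) {
    for (std::size_t i = 0; i < data_nodes.size(); ++i) {
      if (data_nodes[i] == backend) {
        positions.emplace(front, i);  // share the slot of the existing backend node
        return;
      }
    }
    positions.emplace(front, data_nodes.size());  // new backend node: new slot
    data_nodes.push_back(backend);
  }
};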
@@ -773,6 +839,55 @@ std::vector<KernelActorPtr> GraphScheduler::BuildNoInputKernelActor(const ActorS
  return no_input_kernel_actors;
}

std::vector<SwitchActorPtr> GraphScheduler::BuildSwitchActor(const GraphCompilerInfo &graph_compiler_info) {
  std::vector<SwitchActorPtr> switch_actors;

  for (const auto &control_node : graph_compiler_info.control_nodes_) {
    if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch) ||
        AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
      auto actor_name = control_node->fullname_with_scope();
      auto switch_actor = std::make_shared<SwitchActor>(actor_name, control_node->cast<CNodePtr>());
      switch_actor->Initialize();
      InsertActor(switch_actor.get());
      switch_actors.emplace_back(switch_actor);
    }
  }
  return switch_actors;
}

std::vector<GatherActorPtr> GraphScheduler::BuildGatherActor(const GraphCompilerInfo &graph_compiler_info) {
  std::vector<GatherActorPtr> gather_actors;

  bool is_main_return = true;
  // Each funcgraph has a return node; get the funcgraph from the return node and create a gather actor for it.
  for (const auto &control_node : graph_compiler_info.control_nodes_) {
    // The root funcgraph does not need a gather actor.
    if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
      if (is_main_return) {
        is_main_return = false;
        continue;
      }

      auto func_graph = control_node->func_graph();
      auto actor_name = func_graph->ToString();
      std::vector<AnfNodePtr> parameters;
      std::copy_if(
        func_graph->parameters().begin(), func_graph->parameters().end(), std::back_inserter(parameters),
        [](const AnfNodePtr &parameter) { return !AnfAlgo::IsParameterWeight(parameter->cast<ParameterPtr>()); });

      auto loop_count_actor_name = graph_compiler_info.name_ + "_LoopCountActor";
      auto actor = FetchActor(loop_count_actor_name);
      if (actor == nullptr) {
        MS_LOG(EXCEPTION) << "Cannot find loop count actor by name:" << loop_count_actor_name;
      }
      auto gather_actor = std::make_shared<GatherActor>(actor_name, parameters, actor->GetAID());
      InsertActor(gather_actor.get());
      gather_actors.emplace_back(gather_actor);
    }
  }
  return gather_actors;
}

void GraphScheduler::LinkDataArrow(KernelActor *to_actor, const ActorSet *actor_set, const KernelGraphPtr &graph,
                                   KernelWithIndex from_kernel_with_output_idx,
                                   KernelWithIndex to_kernel_with_input_idx, const TensorPtr &tensor) {
@@ -30,6 +30,7 @@
#include "runtime/framework/actor/kernel_actor.h"
#include "runtime/framework/actor/output_actor.h"
#include "runtime/framework/actor/switch_actor.h"
#include "runtime/framework/actor/gather_actor.h"
#include "runtime/framework/actor/copy_actor.h"
#include "runtime/hardware/device_context.h"
#include "backend/session/kernel_graph.h"
@@ -38,6 +39,7 @@
namespace mindspore {
namespace runtime {
using mindspore::device::DeviceContext;
using mindspore::session::KernelGraph;
using mindspore::session::KernelWithIndex;
using KernelMapPosition = std::map<KernelWithIndex, size_t, session::KernelWithIndexCmp>;
using ActorInfo = std::string;
@@ -62,13 +64,15 @@ enum class GraphExecutionStrategy {
// The control node is used to link graphs in the control flow scenario.
// The origin parameters order is used to correspond to the input args.
// The origin outputs order is used to correspond to the output args.
// The front to backend parameters map is used to build and link the host data source actor in the control flow
// scenario.
// The front output_node is used to link the output actor in the multi-branch output scenario.
struct GraphCompilerInfo {
-  GraphCompilerInfo(const std::vector<KernelGraphPtr> &graphs, const std::vector<DeviceContext *> &device_contexts,
-                    const std::vector<std::vector<int64_t> *> &tensors_mask,
-                    const std::vector<std::vector<TensorPtr> *> &input_tensors,
-                    const std::vector<AnfNodePtr> &control_nodes,
-                    const std::vector<AnfNodePtr> &origin_parameters_order,
-                    const KernelMapPosition &origin_outputs_order, const std::string &name)
+  GraphCompilerInfo(
+    const std::vector<KernelGraphPtr> &graphs, const std::vector<DeviceContext *> &device_contexts,
+    const std::vector<std::vector<int64_t> *> &tensors_mask, const std::vector<std::vector<TensorPtr> *> &input_tensors,
+    const std::vector<AnfNodePtr> &control_nodes, const std::vector<AnfNodePtr> &origin_parameters_order,
+    const KernelMapPosition &origin_outputs_order, const FrontToBackendNodeWithContext &front_to_backend_parameters,
+    const std::vector<AnfNodePtr> &front_output_nodes, const size_t outputs_num, const std::string &name)
      : graphs_(graphs),
        device_contexts_(device_contexts),
        tensors_mask_(tensors_mask),
@@ -76,6 +80,9 @@ struct GraphCompilerInfo {
        control_nodes_(control_nodes),
        origin_parameters_order_(origin_parameters_order),
        origin_outputs_order_(origin_outputs_order),
        front_to_backend_parameters_(front_to_backend_parameters),
        front_output_nodes_(front_output_nodes),
        outputs_num_(outputs_num),
        name_(name) {}
  std::vector<KernelGraphPtr> graphs_;
  std::vector<DeviceContext *> device_contexts_;
@@ -84,6 +91,9 @@ struct GraphCompilerInfo {
  std::vector<AnfNodePtr> control_nodes_;
  std::vector<AnfNodePtr> origin_parameters_order_;
  KernelMapPosition origin_outputs_order_;
  FrontToBackendNodeWithContext front_to_backend_parameters_;
  std::vector<AnfNodePtr> front_output_nodes_;
  size_t outputs_num_;
  std::string name_;
};

@@ -92,6 +102,9 @@ struct GraphCompilerInfo {
// The data source actor is used to obtain data, process it into device tensors, and send them to the kernel actors.
// The kernel actor is used to receive the device tensors and launch the kernel. Note especially the no-input
// kernel actor: it has no input device tensor and needs to be triggered externally.
// The switch actor is used to run different branches in the control flow scenario.
// The gather actor is used to collect the inputs of the graph and send the branch id to the loop count actor in the
// multi-branch output scenario.
// The copy actor is used to convert device tensors between different device kernels.
// The loop count actor is used to receive the control of the tail kernel actor to represent the end of one step
// and decide whether to loop execution by loop count.
@@ -103,6 +116,7 @@ struct ActorSet {
  // No input kernel actors need be triggered specifically.
  std::vector<KernelActorPtr> no_input_kernel_actors_;
  std::vector<SwitchActorPtr> switch_actors_;
  std::vector<GatherActorPtr> gather_actors_;
  std::vector<CopyActorPtr> copy_actors_;
  LoopCountActorPtr loop_count_actor_{nullptr};
  OutputActorPtr output_actor_{nullptr};
@@ -160,6 +174,8 @@ class GraphScheduler {
  LoopCountActorPtr BuildLoopCountActor(const GraphCompilerInfo &graph_compiler_info, GraphExecutionStrategy strategy);
  OutputActorPtr BuildOutputActor(const GraphCompilerInfo &graph_compiler_info, GraphExecutionStrategy strategy);
  std::vector<KernelActorPtr> BuildNoInputKernelActor(const ActorSet *actor_set);
  std::vector<SwitchActorPtr> BuildSwitchActor(const GraphCompilerInfo &graph_compiler_info);
  std::vector<GatherActorPtr> BuildGatherActor(const GraphCompilerInfo &graph_compiler_info);

  // Cache the information of graph output node to actor between “build” and “link”, for linking between the tail of
  // previous graph and the head of next graph.
@@ -151,6 +151,18 @@ void PushInputTensor(const BaseRef &arg, std::vector<tensor::TensorPtr> *inputs)
    MS_LOG(WARNING) << "Invalid input type.";
  }
}

// Insert the tensor related to front_node into input_tensor.
void PushTensor(const VectorRef &args, const std::vector<AnfNodePtr> &parameters, const AnfNodePtr &front_node,
                std::vector<tensor::TensorPtr> *input_tensor) {
  const auto &iter = std::find(parameters.begin(), parameters.end(), front_node);
  if (iter == parameters.end()) {
    (*input_tensor).emplace_back(nullptr);
    return;
  }
  auto position = iter - parameters.begin();
  PushInputTensor(args[position], input_tensor);
}
} // namespace

VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) {
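A worked example of the PushTensor alignment, with plain values in place of VectorRef/AnfNodePtr: a front node found among the parameters pulls the arg at the matching position, an unknown one leaves a placeholder (mirroring emplace_back(nullptr)):

// Toy model of PushTensor's parameter/arg alignment.
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

void PushTensorToy(const std::vector<int> &args, const std::vector<std::string> &parameters,
                   const std::string &front_node, std::vector<int> *input_tensor) {
  const auto &iter = std::find(parameters.begin(), parameters.end(), front_node);
  if (iter == parameters.end()) {
    input_tensor->push_back(-1);  // placeholder, like emplace_back(nullptr)
    return;
  }
  input_tensor->push_back(args[iter - parameters.begin()]);
}

int main() {
  std::vector<int> input_tensor;
  PushTensorToy({10, 20}, {"p0", "p1"}, "p1", &input_tensor);  // pulls 20
  PushTensorToy({10, 20}, {"p0", "p1"}, "px", &input_tensor);  // placeholder -1
  std::cout << input_tensor[0] << " " << input_tensor[1] << "\n";
}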
@@ -397,22 +409,28 @@ VectorRef MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &
  const auto &origin_parameters = graph_compiler_info.origin_parameters_order_;

-  // Transform args to input tensors.
+  // Input tensors of the graph.
  std::vector<std::vector<tensor::TensorPtr>> input_tensors;
  for (const auto &kernel_graph : graph_compiler_info.graphs_) {
    std::vector<tensor::TensorPtr> input_tensor;
    for (const auto &input_node : kernel_graph->input_nodes()) {
      const auto &front_node = kernel_graph->GetFrontAnfByBackendAnf(input_node);
-      const auto &iter = std::find(origin_parameters.begin(), origin_parameters.end(), front_node);
-      if (iter == origin_parameters.end()) {
-        input_tensor.emplace_back(nullptr);
-        continue;
-      }
-      auto position = IntToSize(std::distance(origin_parameters.begin(), iter));
-      PushInputTensor(args[position], &input_tensor);
+      PushTensor(args, origin_parameters, front_node, &input_tensor);
    }
    input_tensors.emplace_back(input_tensor);
  }

+  // Input tensors of the control nodes.
+  std::vector<tensor::TensorPtr> input_tensor;
+
+  // Get inputs of control nodes which come from the host actor.
+  std::vector<AnfNodePtr> control_node_parameters =
+    ControlNodeParser::FetchControlNodeParameter(graph_compiler_info.control_nodes_);
+  for (const auto &parameter : control_node_parameters) {
+    PushTensor(args, origin_parameters, parameter, &input_tensor);
+  }
+  input_tensors.emplace_back(input_tensor);
+
  // Run in the pynative mode.
  VectorRef outputs;
  auto ms_context = MsContext::GetInstance();
@@ -465,19 +483,33 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(con
    name.append("_").append(std::to_string(graph_id_to_context.first));
  }

+  // Get all the outputs. In control flow, there may be multiple branch outputs.
  runtime::KernelMapPosition outputs_order;
-  size_t position = 0;
-  const auto &outputs = AnfAlgo::GetAllOutput(root_graph->output(), {prim::kPrimTupleGetItem});
-  for (const auto &output : outputs) {
-    const auto &output_with_index = AnfAlgo::VisitKernelWithReturnType(output, 0, true);
-    MS_EXCEPTION_IF_NULL(output_with_index.first);
-    outputs_order.emplace(output_with_index, position++);
-  }
+  size_t outputs_num = 0;
+  const auto &all_branch_output = ControlNodeParser::FetchAllBranchOutputs(root_graph);
+  for (const auto &branch_output : all_branch_output) {
+    size_t position = 0;
+    const auto &outputs = AnfAlgo::GetAllOutput(branch_output, {prim::kPrimTupleGetItem});
+    outputs_num = outputs.size();
+
+    for (const auto &output : outputs) {
+      const auto &output_with_index = AnfAlgo::VisitKernelWithReturnType(output, 0, false);
+      MS_EXCEPTION_IF_NULL(output_with_index.first);
+      outputs_order.emplace(output_with_index, position++);
+    }
+  }

+  // Fetch all the relationships between front parameters and backend parameters which will be used to
+  // build and link actors.
+  FrontToBackendNodeWithContext front_to_backend_parameter;
+  ControlNodeParser::FetchFrontToBackendParameterMap(graphs, device_contexts, control_nodes_,
+                                                     &front_to_backend_parameter);
+
  std::vector<std::vector<int64_t> *> tensors_mask;
  std::vector<std::vector<tensor::TensorPtr> *> input_tensors;
  return std::make_unique<GraphCompilerInfo>(graphs, device_contexts, tensors_mask, input_tensors, control_nodes_,
-                                             root_graph->parameters(), outputs_order, name);
+                                             root_graph->parameters(), outputs_order, front_to_backend_parameter,
+                                             all_branch_output, outputs_num, name);
}

std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(
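Note that position restarts at 0 for each branch, so the outputs of different branches map onto the same output slots; that is what lines multi-branch outputs up with outputs_num. A toy model of that ordering, with a string-keyed std::map standing in for KernelMapPosition:

// Branch outputs share output slots because position resets per branch.
#include <cstddef>
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // Two branch outputs, as FetchAllBranchOutputs might return for a two-branch funcgraph.
  std::vector<std::vector<std::string>> branch_outputs = {{"true_out0", "true_out1"},
                                                          {"false_out0", "false_out1"}};
  std::map<std::string, std::size_t> outputs_order;  // stand-in for KernelMapPosition
  for (const auto &outputs : branch_outputs) {
    std::size_t position = 0;  // resets per branch, so branches share slots
    for (const auto &output : outputs) {
      outputs_order.emplace(output, position++);
    }
  }
  for (const auto &entry : outputs_order) {
    std::cout << entry.first << " -> slot " << entry.second << "\n";
  }
}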
@@ -504,9 +536,10 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(
  std::vector<std::vector<int64_t> *> tensors_mask_list(1, const_cast<std::vector<int64_t> *>(tensors_mask));
  std::vector<std::vector<TensorPtr> *> input_tensors_list(1,
                                                           const_cast<std::vector<tensor::TensorPtr> *>(input_tensors));

  return std::make_unique<GraphCompilerInfo>(graphs, device_contexts, tensors_mask_list, input_tensors_list,
-                                             std::vector<AnfNodePtr>(), std::vector<AnfNodePtr>(), outputs_order, name);
+                                             std::vector<AnfNodePtr>(), std::vector<AnfNodePtr>(), outputs_order,
+                                             FrontToBackendNodeWithContext(), std::vector<AnfNodePtr>(),
+                                             outputs_order.size(), name);
}

VectorRef MindRTBackend::RunGraph(const ActorInfo &actor_info, const std::vector<int64_t> *tensors_mask,
@@ -38,7 +38,8 @@ using OpRunInfo = session::OpRunInfo;
using DeviceContext = device::DeviceContext;
using ActorInfo = runtime::ActorInfo;
using GraphCompilerInfo = runtime::GraphCompilerInfo;

using ControlNodeParser = runtime::ControlNodeParser;
using FrontToBackendNodeWithContext = runtime::FrontToBackendNodeWithContext;
enum SwitchCondStatus {
  kCondOk = 0,
  kCondAlreadyRun,