!18362 Link value node for control flow and organize control node parser.

Merge pull request !18362 from gaoyong10/new_runtime5
i-robot 2021-06-16 09:33:27 +08:00 committed by Gitee
commit b324a49078
10 changed files with 650 additions and 271 deletions

View File

@ -94,8 +94,8 @@ void GatherActor::FetchBackendInputNode(const FuncGraphPtr &func_graph,
if (parameters[i]->isa<Parameter>()) {
// Input node is a parameter from host data source actor.
std::vector<AnfNodePtr> invalid_inputs;
std::vector<AnfNodePtr> front_inputs = ControlNodeParser::FetchInputNodeByParameter(
parameters[i], origin_parameters_order, &invalid_inputs, func_graph_to_parameters);
std::vector<AnfNodePtr> front_inputs =
FetchInputNodeByParameter(parameters[i], origin_parameters_order, &invalid_inputs, func_graph_to_parameters);
for (const auto &front_input : front_inputs) {
const auto node_with_index = AnfAlgo::VisitKernelWithReturnType(front_input, 0);

View File

@ -34,6 +34,8 @@
namespace mindspore {
namespace runtime {
constexpr size_t kReturnInputPos = 1;
// Gather actor is the entrance of the sub funcgraph. Graph inputs are sent to it, and the gather actor forwards them to other actors.
class GatherActor : public OpActor<DeviceTensor> {
public:

View File

@ -15,6 +15,7 @@
*/
#include "runtime/framework/actor/switch_actor.h"
#include "runtime/framework/actor/output_actor.h"
#include "runtime/framework/actor/memory_manager_actor.h"
#include "mindrt/include/async/async.h"
#include "abstract/utils.h"
@ -57,7 +58,6 @@ void SwitchActor::Initialize() {
} else {
InitSwitchLayer();
}
input_datas_num_ = input_nodes_.size();
}
void SwitchActor::InitPartial(const AnfNodePtr &node, const size_t branch_id) {
@ -68,8 +68,18 @@ void SwitchActor::InitPartial(const AnfNodePtr &node, const size_t branch_id) {
// [0] ValueNode<Primitive> kPartial.
// [1] ValueNode<FuncGraphPtr>.
// [2..] Inputs.
auto node_inputs = cnode->inputs();
branch_func_graph_[branch_id] = GetValueNode<FuncGraphPtr>(node_inputs[kPartialFuncGraphPos]);
const auto &node_inputs = cnode->inputs();
if (node_inputs.size() <= kPartialFuncGraphPos) {
MS_LOG(EXCEPTION) << "Invalid Partial node:" << AnfAlgo::GetNodeDebugString(cnode);
}
const auto &func_graph = GetValueNode<FuncGraphPtr>(node_inputs[kPartialFuncGraphPos]);
if (func_graph->output()->isa<ValueNode>()) {
AddInput(func_graph->output(), branch_id);
return;
}
branch_func_graph_[branch_id] = func_graph;
for (size_t j = kPartialInputStartPos; j < node_inputs.size(); ++j) {
AddInput(node_inputs[j], branch_id);
}
@ -93,8 +103,11 @@ void SwitchActor::InitSwitch() {
branch_inputs_pos_.resize(kSwitchPartialNum);
branch_func_graph_.resize(kSwitchPartialNum);
output_branch_arrows_.resize(kSwitchPartialNum);
input_nodes_.push_back(inputs[kSwitchCondPos]);
output_branch_result_arrows_.resize(kSwitchPartialNum);
output_branch_control_arrows_.resize(kSwitchPartialNum);
input_nodes_.push_back(inputs[kSwitchCondPos]);
input_datas_num_++;
// Init the two branches of switch node.
InitPartial(inputs[kSwitchFalseBranchPos], static_cast<size_t>(false));
InitPartial(inputs[kSwitchTrueBranchPos], static_cast<size_t>(true));
@ -118,6 +131,8 @@ void SwitchActor::InitSwitchLayer() {
branch_inputs_pos_.resize(branch_nodes.size() - 1);
branch_func_graph_.resize(branch_nodes.size() - 1);
output_branch_arrows_.resize(branch_nodes.size() - 1);
output_branch_result_arrows_.resize(branch_nodes.size() - 1);
output_branch_control_arrows_.resize(branch_nodes.size() - 1);
// Parse all branches.
for (size_t i = 1; i < branch_nodes.size(); ++i) {
@ -147,14 +162,23 @@ size_t SwitchActor::FetchDataNodePosition(const AnfNodePtr &data_node) const {
void SwitchActor::AddInput(const AnfNodePtr &node, const size_t branch) {
branch_total_inputs_[branch].push_back(node);
if (node->isa<ValueNode>() && (!HasAbstractMonad(node))) {
device_tensor_store_keys_.push_back({input_nodes_.size(), node.get()});
branch_inputs_pos_[branch].push_back(input_nodes_.size());
input_nodes_.push_back(node);
return;
}
// Switch actor only receives parameters; the output of an UpdateState node is a monad (U) and needs to be skipped.
if (IsPersistentDeviceTensor(node) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) {
return;
}
auto iter = find(input_nodes_.begin(), input_nodes_.end(), node);
if (iter == input_nodes_.end()) {
branch_inputs_pos_[branch].push_back(input_nodes_.size());
input_nodes_.push_back(node);
++input_datas_num_;
} else {
branch_inputs_pos_[branch].push_back(iter - input_nodes_.begin());
}
@ -208,8 +232,7 @@ bool SwitchActor::CheckLaunchCondition(OpContext<DeviceTensor> *context) const {
void SwitchActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
input_device_tensors_.resize(input_datas_num_);
input_device_tensors_.resize(input_nodes_.size());
auto data_iter = input_op_datas_.find(context->sequential_num_);
if (data_iter != input_op_datas_.end()) {
for (auto &input_data : data_iter->second) {
@ -218,6 +241,18 @@ void SwitchActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
}
}
data_iter->second.clear();
for (const auto &device_tensor_store_key : device_tensor_store_keys_) {
auto device_tensor =
DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second, device_context_->GetDeviceAddressType());
if (device_tensor == nullptr) {
std::string error_info =
GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->DebugString() +
", device type:" + std::to_string(static_cast<int>(device_context_->GetDeviceAddressType()));
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
input_device_tensors_[device_tensor_store_key.first] = device_tensor;
}
}
void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
@ -237,6 +272,28 @@ void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
data->data_ = input_device_tensors_[data_arrow->from_output_index_];
Async(data_arrow->to_op_id_, &OpActor::RunOpData, data.get(), context);
}
// Send result.
auto &output_branch_result_arrow = output_branch_result_arrows_[index];
for (size_t i = 0; i < output_branch_result_arrow.size(); ++i) {
auto &result_arrow = output_branch_result_arrow[i];
MS_EXCEPTION_IF_NULL(result_arrow);
size_t from_index = result_arrow->from_output_index_;
for (const auto &backend_node : front_to_backend_parameter_[from_index]) {
if (AnfAlgo::GetMutableOutputAddr(backend_node.first, backend_node.second).get() ==
input_device_tensors_[from_index]) {
Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, backend_node.first, backend_node.second,
result_arrow->to_input_index_, context);
break;
}
}
}
// Send output control.
auto source_aid = const_cast<AID *>(&GetAID());
for (auto &output_control : output_branch_control_arrows_[index]) {
Async(output_control, &OpActor::RunOpControl, source_aid, context);
}
}
void SwitchActor::EraseInput(OpContext<DeviceTensor> *context) {
@ -262,5 +319,37 @@ void SwitchActor::SendMemoryFreeReq(OpContext<DeviceTensor> *context) {
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &input_device_tensors_, device_context_, context);
}
void SwitchActor::FetchInputNode(const std::vector<AnfNodePtr> &origin_parameters_order,
const FrontToBackendNodeWithContext &front_to_backend_parameters,
const std::unordered_map<AnfNodePtr, AnfNodePtr> &front_to_backend_kernel) {
front_to_backend_parameter_.resize(input_nodes_.size());
for (size_t i = 0; i < input_nodes_.size(); ++i) {
const auto &input_node = input_nodes_[i];
if (input_node->isa<ValueNode>()) {
front_to_backend_parameter_[i].push_back({input_node, 0});
} else if (input_node->isa<Parameter>()) {
if (front_to_backend_parameters.find(input_node) != front_to_backend_parameters.end()) {
const auto backend_node = front_to_backend_parameters.at(input_node).first;
front_to_backend_parameter_[i].push_back({backend_node, 0});
}
} else if (input_node->isa<CNode>()) {
if (IsCallNode(input_node)) {
const auto func_graphs = FetchFuncGraphbyCallNode(input_node->cast<CNodePtr>());
for (const auto func_graph : func_graphs) {
if (func_graph->output()->isa<ValueNode>()) {
front_to_backend_parameter_[i].push_back({func_graph->output(), 0});
}
}
} else {
const auto &kernel_with_index = AnfAlgo::VisitKernelWithReturnType(input_node, 0);
if (front_to_backend_kernel.find(input_node) != front_to_backend_kernel.end()) {
front_to_backend_parameter_[i].emplace_back(kernel_with_index);
}
}
}
}
}
} // namespace runtime
} // namespace mindspore
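The dispatch logic above can be summarized with a small sketch. This is illustrative only (SwitchSketch and its members are hypothetical, not the SwitchActor API): input 0 carries the condition, and a per-branch position list records where each branch's inputs live in the shared input list, mirroring branch_inputs_pos_.

#include <cstddef>
#include <vector>

struct SwitchSketch {
  std::vector<void *> inputs;                         // Slot 0 holds the condition.
  std::vector<std::vector<size_t>> branch_input_pos;  // Per-branch input positions.

  // Return the inputs of the branch selected by the condition:
  // index 0 is the false branch, index 1 the true branch.
  std::vector<void *> Dispatch(bool cond) const {
    std::vector<void *> out;
    for (size_t pos : branch_input_pos[static_cast<size_t>(cond)]) {
      out.push_back(inputs[pos]);
    }
    return out;
  }
};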

View File

@ -24,6 +24,7 @@
#include <unordered_map>
#include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/device_tensor_store.h"
#include "runtime/framework/control_node_parser.h"
#include "mindrt/include/actor/switch_actor.h"
#include "runtime/hardware/device_context.h"
@ -56,7 +57,8 @@ constexpr size_t kMakeTupleInputStartPos = 1;
// 5. Free Memory
class SwitchActor : public SwitchActorBase<DeviceTensor> {
public:
SwitchActor(const std::string &name, const CNodePtr &node) : SwitchActorBase(name), node_(node) {}
SwitchActor(const std::string &name, DeviceContext *device_context, const CNodePtr &node)
: SwitchActorBase(name), device_context_(device_context), node_(node) {}
~SwitchActor() override = default;
void Init() override;
@ -91,19 +93,35 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
void EraseInput(OpContext<DeviceTensor> *context);
void SendMemoryFreeReq(OpContext<DeviceTensor> *context);
// Collect all the backend inputs of switch actor.
void FetchInputNode(const std::vector<AnfNodePtr> &origin_parameters_order,
const FrontToBackendNodeWithContext &front_to_backend_parameters,
const std::unordered_map<AnfNodePtr, AnfNodePtr> &front_to_backend_kernel);
// All inputs of the switch actor, excluding weights and tensors.
// Used to receive input data; the first input is the condition of the switch.
std::vector<AnfNodePtr> input_nodes_;
// The positions of each branch's inputs in input_nodes_.
std::vector<std::vector<size_t>> branch_inputs_pos_;
// Control arrows of different branches.
std::vector<std::vector<AID>> output_branch_control_arrows_;
// Result arrows of different branches.
std::vector<std::vector<DataArrowPtr>> output_branch_result_arrows_;
// When the output of the switch actor is a value node, the actor needs to send the AnfNode to the output actor,
// so all the nodes that may send the device tensor to the switch actor are recorded.
std::vector<std::vector<KernelWithIndex>> front_to_backend_parameter_;
std::vector<std::vector<AnfNodePtr>> branch_total_inputs_;
std::vector<FuncGraphPtr> branch_func_graph_;
// Pair<index, AnfNode> points to the dependent device tensor store; the AnfNode is the key of the device tensor store.
std::vector<std::pair<size_t, AnfNode *>> device_tensor_store_keys_;
std::vector<DeviceTensor *> input_device_tensors_;
// Save the DeviceContext of input_nodes_, which is used to release the DeviceTensor.
DeviceContext *device_context_;
const DeviceContext *device_context_;
// The id of memory manager actor. Send message to it for alloc and free memory.
const AID memory_manager_aid_;
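A sketch of the device_tensor_store_keys_ mechanism declared above, under simplified, hypothetical types (Tensor, Key, and a global map standing in for DeviceTensorStore): value-node inputs are not received as runtime data; their slot index and key are recorded, and the tensor is fetched from the store when the actor runs.

#include <cstddef>
#include <unordered_map>
#include <utility>
#include <vector>

struct Tensor {};
using Key = const void *;
std::unordered_map<Key, Tensor *> g_store;  // Stand-in for DeviceTensorStore.

// Fill the recorded slots from the store; the remaining slots are filled by
// incoming runtime data, mirroring FetchInputDeviceTensor above.
void FillFromStore(const std::vector<std::pair<size_t, Key>> &keys, std::vector<Tensor *> *slots) {
  for (const auto &key : keys) {
    (*slots)[key.first] = g_store.at(key.second);
  }
}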

View File

@ -16,6 +16,7 @@
#include "runtime/framework/control_node_parser.h"
#include "runtime/framework/actor/switch_actor.h"
#include "ir/tensor.h"
namespace mindspore {
namespace runtime {
@ -43,6 +44,12 @@ bool CheckValidFuncGraphInput(const AnfNodePtr &node) {
return (!IsPersistentDeviceTensor(node)) && (!HasAbstractMonad(node));
}
// Get the funcgraph in partial node.
FuncGraphPtr GetFuncGraphFromPartial(const AnfNodePtr &node) {
const auto &partial_inputs = node->cast<CNodePtr>()->inputs();
return GetValueNode<FuncGraphPtr>(partial_inputs[1]);
}
// Get the corresponding relationship between funcgraph and parameters in the switch node.
void FetchParameterBySwitchNode(const AnfNodePtr &switch_node, FuncGraphToParameter *graph_to_real_parameters) {
const auto &switch_cnode = switch_node->cast<CNodePtr>();
@ -53,7 +60,7 @@ void FetchParameterBySwitchNode(const AnfNodePtr &switch_node, FuncGraphToParame
for (size_t i = kSwitchTrueBranchPos; i < kSwitchInputNum; ++i) {
const auto &partial_node = switch_inputs[i];
const auto &func_graph = ControlNodeParser::GetFuncGraphFromPartial(partial_node);
const auto &func_graph = GetFuncGraphFromPartial(partial_node);
std::vector<AnfNodePtr> parameters;
const auto &partial_inputs = partial_node->cast<CNodePtr>()->inputs();
for (size_t j = kPartialInputStartPos; j < partial_inputs.size(); ++j) {
@ -78,7 +85,7 @@ void FetchParameterBySwitchLayerNode(const AnfNodePtr &switch_layer_node, const
auto tuple_inputs = switch_layer_inputs[kSwitchLayerBranchPos]->cast<CNodePtr>()->inputs();
for (size_t i = kMakeTupleInputStartPos; i < tuple_inputs.size(); ++i) {
if (AnfAlgo::CheckPrimitiveType(tuple_inputs[i], prim::kPrimPartial)) {
const auto &func_graph = ControlNodeParser::GetFuncGraphFromPartial(tuple_inputs[i]);
const auto &func_graph = GetFuncGraphFromPartial(tuple_inputs[i]);
std::vector<AnfNodePtr> parameters;
const auto &partial_inputs = tuple_inputs[i]->cast<CNodePtr>()->inputs();
for (size_t j = kPartialInputStartPos; j < partial_inputs.size(); ++j) {
@ -104,9 +111,46 @@ void FetchParameterBySwitchLayerNode(const AnfNodePtr &switch_layer_node, const
}
}
}
// Create a device tensor for the front node.
// The output format and selected kernel build info are taken from the backend node corresponding to the
// front node to create the device address.
void CreateDeviceTensorForValueNode(const AnfNodePtr &front_node, const AnfNodePtr &backend_node,
const DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(device_context);
const auto &node_value = front_node->cast<ValueNodePtr>()->value();
if (!node_value->isa<tensor::Tensor>()) {
return;
}
size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(backend_node, 0);
TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(backend_node, 0);
if (output_type_id == kTypeUnknown) {
output_type_id = AnfAlgo::GetOutputInferDataType(backend_node, 0);
}
if (front_node->kernel_info() == nullptr) {
front_node->set_kernel_info(std::make_shared<device::KernelInfo>());
}
// Get the select kernel build info.
auto kernel_info = static_cast<device::KernelInfo *>(backend_node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->GetMutableSelectKernelBuildInfo();
MS_EXCEPTION_IF_NULL(build_info);
AnfAlgo::SetSelectKernelBuildInfo(build_info, front_node.get());
// Create device tensor.
std::string output_format = AnfAlgo::GetOutputFormat(backend_node, 0);
device::DeviceAddressPtr address =
device_context->CreateDeviceAddress(nullptr, tensor_size, output_format, output_type_id);
MS_EXCEPTION_IF_NULL(address);
AnfAlgo::SetOutputAddr(address, 0, front_node.get());
}
} // namespace
bool ControlNodeParser::IsCallNode(const AnfNodePtr &node) {
bool IsCallNode(const AnfNodePtr &node) {
if (!node->isa<CNode>()) {
return false;
}
@ -115,12 +159,7 @@ bool ControlNodeParser::IsCallNode(const AnfNodePtr &node) {
return inputs[0]->isa<CNode>() || (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0]));
}
FuncGraphPtr ControlNodeParser::GetFuncGraphFromPartial(const AnfNodePtr &node) {
const auto &partial_inputs = node->cast<CNodePtr>()->inputs();
return GetValueNode<FuncGraphPtr>(partial_inputs[1]);
}
std::vector<FuncGraphPtr> ControlNodeParser::FetchFuncGraphbyCallNode(const CNodePtr &node) {
std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const CNodePtr &node) {
std::vector<FuncGraphPtr> func_graphs;
const auto &call_inputs = node->inputs();
@ -155,6 +194,182 @@ std::vector<FuncGraphPtr> ControlNodeParser::FetchFuncGraphbyCallNode(const CNod
return func_graphs;
}
FuncGraphPtr FetchFuncGraphByNode(const AnfNodePtr &node) {
auto front_node = GetFrontNodeByBackendNode(node);
// If the front node is nullptr, we can check its inputs.
if (front_node == nullptr) {
if (node->isa<CNode>()) {
const auto &cnode = node->cast<CNodePtr>();
const auto &inputs = cnode->inputs();
for (size_t i = 1; i < inputs.size(); ++i) {
const auto &func_graph = FetchFuncGraphByNode(inputs[i]);
if (func_graph != nullptr) {
return func_graph;
}
}
} else {
return nullptr;
}
}
const auto &func_graph = front_node->func_graph();
return func_graph;
}
AnfNodePtr GetFrontNodeByBackendNode(const AnfNodePtr &backend_node) {
if (backend_node->func_graph() == nullptr) {
return nullptr;
}
auto kernel_graph = dynamic_cast<KernelGraph *>(backend_node->func_graph().get());
if (kernel_graph == nullptr) {
return nullptr;
}
return kernel_graph->GetFrontAnfByBackendAnf(backend_node);
}
FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node) {
auto front_node = GetFrontNodeByBackendNode(backend_node);
if (front_node == nullptr) {
return nullptr;
}
return front_node->func_graph();
}
std::vector<AnfNodePtr> FetchInputNodeByParameter(const AnfNodePtr &parameter,
const std::vector<AnfNodePtr> &host_ds_parameters,
std::vector<AnfNodePtr> *invalid_inputs,
const FuncGraphToParameter &graph_to_real_parameters) {
std::vector<AnfNodePtr> input_nodes;
// If the node has been collected, skip it.
if (find((*invalid_inputs).begin(), (*invalid_inputs).end(), parameter) != (*invalid_inputs).end()) {
return input_nodes;
}
// Record the node which has been collected.
(*invalid_inputs).emplace_back(parameter);
// If the parameter node is a parameter of host data source actor, return it.
if (find(host_ds_parameters.begin(), host_ds_parameters.end(), parameter) != host_ds_parameters.end()) {
input_nodes.emplace_back(parameter);
return input_nodes;
}
// Check the parameters which are sent to its funcgraph.
const auto &func_graph = parameter->func_graph();
if (graph_to_real_parameters.find(func_graph) == graph_to_real_parameters.end()) {
return input_nodes;
}
std::vector<AnfNodePtr> self_inputs;
for (const auto &input : func_graph->get_inputs()) {
// Monad inputs need not be sent to the funcgraph.
if (!HasAbstractMonad(input)) {
self_inputs.emplace_back(input);
}
}
size_t pos = find(self_inputs.begin(), self_inputs.end(), parameter) - self_inputs.begin();
for (const auto parameters : graph_to_real_parameters.at(func_graph)) {
const auto input = parameters[pos];
if (input->isa<CNode>()) {
input_nodes.emplace_back(input);
} else if (input->isa<Parameter>()) {
// If the input is a parameter, its inputs need to be found recursively.
auto inputs = FetchInputNodeByParameter(input, host_ds_parameters, invalid_inputs, graph_to_real_parameters);
input_nodes.insert(input_nodes.end(), inputs.begin(), inputs.end());
}
}
return input_nodes;
}
void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts, const FuncGraphPtr &root_graph) {
FetchFrontToBackendParameterMap(graphs, device_contexts, control_nodes);
FetchFuncGraphToParameterMap(control_nodes);
FetchHostParameterToWeightMap(control_nodes);
FetchFrontValueNode(graphs, device_contexts);
// Get the inputs of control nodes which come from the host actor.
control_node_parameters_ = FetchControlNodeParameter(control_nodes);
front_output_nodes_ = FetchAllBranchOutputs(root_graph);
}
void ControlNodeParser::FetchValueNodeInSwitchNode(const AnfNodePtr &switch_node,
std::vector<AnfNodePtr> *value_nodes) {
const auto &cnode = switch_node->cast<CNodePtr>();
const auto &inputs = cnode->inputs();
if (inputs.size() != kSwitchInputNum) {
MS_LOG(EXCEPTION) << "Invalid switch node input num:" << inputs.size();
}
for (const auto &input : inputs) {
if (input->isa<ValueNode>()) {
const auto &node_value = input->cast<ValueNodePtr>()->value();
if (node_value->isa<tensor::Tensor>()) {
(*value_nodes).emplace_back(input);
}
} else if (IsCallNode(input)) {
// If the input is a call node, check whether there is a switch node in its inputs.
const auto &call_node = input->cast<CNodePtr>();
const auto &call_inputs = call_node->inputs();
if (call_inputs.empty() || (!AnfAlgo::CheckPrimitiveType(call_inputs[0], prim::kPrimSwitch))) {
continue;
}
FetchValueNodeInSwitchNode(call_inputs[0], value_nodes);
} else if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimPartial)) {
const auto &partial_node = input->cast<CNodePtr>();
const auto &partial_inputs = partial_node->inputs();
if (partial_inputs.size() <= kPartialFuncGraphPos) {
MS_LOG(EXCEPTION) << "Invalid partial node input num:" << partial_inputs.size();
}
// If the input is a partial node, get the value node in its funcgraph.
const auto &func_graph = GetValueNode<FuncGraphPtr>(partial_inputs[kPartialFuncGraphPos]);
if (func_graph->output()->isa<ValueNode>()) {
(*value_nodes).emplace_back(func_graph->output());
}
}
}
}
void ControlNodeParser::FetchFrontValueNode(const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts) {
for (size_t index = 0; index < graphs.size(); ++index) {
const auto &graph = graphs[index];
MS_EXCEPTION_IF_NULL(graph);
auto execution_order = graph->execution_order();
for (const auto &parameter : graph->input_nodes()) {
const auto &front_node = graph->GetFrontAnfByBackendAnf(parameter);
const auto &internal_node = graph->GetFrontNodeByInternalParameter(parameter);
MS_EXCEPTION_IF_NULL(parameter);
if (IsInternalParameter(parameter, graph)) {
auto front_node_with_index = graph->GetFrontNodeByInternalParameter(parameter);
MS_EXCEPTION_IF_NULL(front_node_with_index.first);
const auto &front_output_with_index =
AnfAlgo::VisitKernelWithReturnType(front_node_with_index.first, front_node_with_index.second, false);
auto front_output_node = front_output_with_index.first;
MS_EXCEPTION_IF_NULL(front_output_node);
if (AnfAlgo::CheckPrimitiveType(front_output_node, prim::kPrimSwitch)) {
std::vector<AnfNodePtr> value_nodes;
FetchValueNodeInSwitchNode(front_output_node, &value_nodes);
for (const auto value_node : value_nodes) {
CreateDeviceTensorForValueNode(value_node, parameter, device_contexts[index]);
front_value_nodes_.push_back({value_node, device_contexts[index]});
}
}
}
}
}
}
AnfNodePtr ControlNodeParser::GetCallNodeInputByPos(const AnfNodePtr &call_node, const FuncGraphPtr &func_graph,
const size_t pos) {
const auto &cnode = call_node->cast<CNodePtr>();
@ -401,8 +616,7 @@ std::vector<AnfNodePtr> ControlNodeParser::FetchAllBranchOutputs(const FuncGraph
void ControlNodeParser::FetchFrontToBackendParameterMap(const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts,
const std::vector<AnfNodePtr> &control_nodes,
FrontToBackendNodeWithContext *front_to_backend_parameter) {
const std::vector<AnfNodePtr> &control_nodes) {
if (graphs.size() != device_contexts.size()) {
MS_LOG(EXCEPTION) << "Graph num is not equal to device context num.";
}
@ -414,8 +628,8 @@ void ControlNodeParser::FetchFrontToBackendParameterMap(const std::vector<Kernel
for (const auto &parameter : graph->parameters()) {
auto front_node = graph->GetFrontAnfByBackendAnf(parameter);
if (front_node != nullptr && front_node->isa<Parameter>() &&
(*front_to_backend_parameter).find(front_node) == (*front_to_backend_parameter).end()) {
(*front_to_backend_parameter)[front_node] = {parameter, device_context};
front_to_backend_parameters_.find(front_node) == front_to_backend_parameters_.end()) {
front_to_backend_parameters_[front_node] = {parameter, device_context};
}
}
}
@ -431,15 +645,14 @@ void ControlNodeParser::FetchFrontToBackendParameterMap(const std::vector<Kernel
for (const auto &front_pair : front_to_front_parameter) {
std::set<AnfNodePtr> invalid_node;
const auto &backend_node = FetchBackendNodeByFrontNode(front_pair.first, front_to_front_parameter,
*front_to_backend_parameter, &invalid_node);
front_to_backend_parameters_, &invalid_node);
if (backend_node.first != nullptr) {
(*front_to_backend_parameter)[front_pair.first] = backend_node;
front_to_backend_parameters_[front_pair.first] = backend_node;
}
}
}
void ControlNodeParser::FetchHostParameterToWeightMap(const std::vector<AnfNodePtr> &control_nodes,
HostParameterToWeight *host_parameter_to_weights) {
void ControlNodeParser::FetchHostParameterToWeightMap(const std::vector<AnfNodePtr> &control_nodes) {
std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> front_to_front_parameter;
FetchFrontToFrontParameterMap(control_nodes, &front_to_front_parameter);
@ -447,54 +660,11 @@ void ControlNodeParser::FetchHostParameterToWeightMap(const std::vector<AnfNodeP
for (const auto &pair : front_to_front_parameter) {
std::vector<AnfNodePtr> dest_nodes;
FetchWeightbyHostParameter(pair.first, &dest_nodes, front_to_front_parameter);
(*host_parameter_to_weights)[pair.first] = dest_nodes;
host_parameter_to_weights_[pair.first] = dest_nodes;
}
}
FuncGraphPtr ControlNodeParser::FetchFuncGraphByNode(const AnfNodePtr &node) {
auto front_node = GetFrontNodeByBackendNode(node);
// If the front node is nullptr, we can check its inputs.
if (front_node == nullptr) {
if (node->isa<CNode>()) {
const auto &cnode = node->cast<CNodePtr>();
const auto &inputs = cnode->inputs();
for (size_t i = 1; i < inputs.size(); ++i) {
const auto &func_graph = FetchFuncGraphByNode(inputs[i]);
if (func_graph != nullptr) {
return func_graph;
}
}
} else {
return nullptr;
}
}
const auto &func_graph = front_node->func_graph();
return func_graph;
}
AnfNodePtr ControlNodeParser::GetFrontNodeByBackendNode(const AnfNodePtr &backend_node) {
if (backend_node->func_graph() == nullptr) {
return nullptr;
}
auto kernel_graph = dynamic_cast<KernelGraph *>(backend_node->func_graph().get());
if (kernel_graph == nullptr) {
return nullptr;
}
return kernel_graph->GetFrontAnfByBackendAnf(backend_node);
}
FuncGraphPtr ControlNodeParser::GetFuncgraphByBackendNode(const AnfNodePtr &backend_node) {
auto front_node = GetFrontNodeByBackendNode(backend_node);
if (front_node == nullptr) {
return nullptr;
}
return front_node->func_graph();
}
void ControlNodeParser::FetchFuncGraphToParameterMap(const std::vector<AnfNodePtr> &control_nodes,
FuncGraphToParameter *graph_to_real_parameters) {
void ControlNodeParser::FetchFuncGraphToParameterMap(const std::vector<AnfNodePtr> &control_nodes) {
for (const auto &control_node : control_nodes) {
const auto &cnode = control_node->cast<CNodePtr>();
const auto &inputs = cnode->inputs();
@ -508,10 +678,10 @@ void ControlNodeParser::FetchFuncGraphToParameterMap(const std::vector<AnfNodePt
if (AnfAlgo::CheckPrimitiveType(switch_cnode, prim::kPrimSwitch)) {
// Switch node.
FetchParameterBySwitchNode(inputs[0], graph_to_real_parameters);
FetchParameterBySwitchNode(inputs[0], &func_graph_to_parameters_);
} else if (AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer)) {
// Switchlayer node.
FetchParameterBySwitchLayerNode(inputs[0], inputs, graph_to_real_parameters);
FetchParameterBySwitchLayerNode(inputs[0], inputs, &func_graph_to_parameters_);
} else {
MS_LOG(EXCEPTION) << "Unable to identify call node" << switch_cnode->DebugString();
}
@ -524,57 +694,9 @@ void ControlNodeParser::FetchFuncGraphToParameterMap(const std::vector<AnfNodePt
parameters.emplace_back(inputs[i]);
}
}
(*graph_to_real_parameters)[func_graph].emplace_back(parameters);
func_graph_to_parameters_[func_graph].emplace_back(parameters);
}
}
}
std::vector<AnfNodePtr> ControlNodeParser::FetchInputNodeByParameter(
const AnfNodePtr &parameter, const std::vector<AnfNodePtr> &host_ds_parameters,
std::vector<AnfNodePtr> *invalid_inputs,
const std::unordered_map<FuncGraphPtr, std::vector<std::vector<AnfNodePtr>>> &graph_to_real_parameters) {
std::vector<AnfNodePtr> input_nodes;
// If the node has been collected, skip it.
if (find((*invalid_inputs).begin(), (*invalid_inputs).end(), parameter) != (*invalid_inputs).end()) {
return input_nodes;
}
// Record the node which has been collected.
(*invalid_inputs).emplace_back(parameter);
// If the parameter node is a parameter of host data source actor, return it.
if (find(host_ds_parameters.begin(), host_ds_parameters.end(), parameter) != host_ds_parameters.end()) {
input_nodes.emplace_back(parameter);
return input_nodes;
}
// Check the parameters which are sent to its funcgraph.
const auto &func_graph = parameter->func_graph();
if (graph_to_real_parameters.find(func_graph) == graph_to_real_parameters.end()) {
return input_nodes;
}
std::vector<AnfNodePtr> self_inputs;
for (const auto &input : func_graph->get_inputs()) {
// Monad inputs need not be sent to the funcgraph.
if (!HasAbstractMonad(input)) {
self_inputs.emplace_back(input);
}
}
size_t pos = find(self_inputs.begin(), self_inputs.end(), parameter) - self_inputs.begin();
for (const auto parameters : graph_to_real_parameters.at(func_graph)) {
const auto input = parameters[pos];
if (input->isa<CNode>()) {
input_nodes.emplace_back(input);
} else if (input->isa<Parameter>()) {
// If the input is a parameter, its inputs need to be found recursively.
auto inputs = FetchInputNodeByParameter(input, host_ds_parameters, invalid_inputs, graph_to_real_parameters);
input_nodes.insert(input_nodes.end(), inputs.begin(), inputs.end());
}
}
return input_nodes;
}
} // namespace runtime
} // namespace mindspore
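The recursive resolution in FetchInputNodeByParameter can be pictured with a simplified sketch (string node names, a flat map, and the name ResolveParameter are assumptions for illustration; the recursion into nested parameters is elided): a formal parameter at position pos resolves to the real argument passed at that position at every recorded call site of its funcgraph.

#include <cstddef>
#include <map>
#include <string>
#include <vector>

using Node = std::string;
// Funcgraph name -> one argument list per call site, as collected by
// FetchFuncGraphToParameterMap above.
using GraphToArgs = std::map<std::string, std::vector<std::vector<Node>>>;

std::vector<Node> ResolveParameter(const std::string &graph, size_t pos, const GraphToArgs &args) {
  std::vector<Node> real_inputs;
  auto iter = args.find(graph);
  if (iter == args.end()) {
    return real_inputs;
  }
  for (const auto &call_args : iter->second) {
    real_inputs.push_back(call_args[pos]);  // Real argument at this call site.
  }
  return real_inputs;
}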

View File

@ -40,71 +40,86 @@ constexpr int kSubBranchStartID = 1;
using FrontToBackendNodeWithContext = std::unordered_map<AnfNodePtr, std::pair<AnfNodePtr, DeviceContext *>>;
using FuncGraphToParameter = std::unordered_map<FuncGraphPtr, std::vector<std::vector<AnfNodePtr>>>;
using HostParameterToWeight = std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>>;
// ControlNodeParser is a series of tool functions used to parse control nodes.
using NodeWithDeviceContext = std::vector<std::pair<AnfNodePtr, DeviceContext *>>;
// Check whether a node is a call node. There are two types of call nodes:
// 1. The first input of the node is a cnode.
// 2. The first input of the node is a funcgraph value node.
bool IsCallNode(const AnfNodePtr &node);
FuncGraphPtr FetchFuncGraphByNode(const AnfNodePtr &node);
// Get front node by backend node.
AnfNodePtr GetFrontNodeByBackendNode(const AnfNodePtr &backend_node);
// Get the funcgraph to which the node belongs.
FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node);
// Find all funcgraphs that the call node will call.
std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const CNodePtr &node);
// Fetch all backend input nodes by parameter for gather actor.
std::vector<AnfNodePtr> FetchInputNodeByParameter(const AnfNodePtr &parameter,
const std::vector<AnfNodePtr> &host_ds_parameters,
std::vector<AnfNodePtr> *invalid_inputs,
const FuncGraphToParameter &graph_to_real_parameters);
// ControlNodeParser is used to parse control nodes and get the edges between nodes.
class ControlNodeParser {
public:
// Check whether a node is a call node. There are two types of call nodes:
// 1. The first input of the node is a cnode.
// 2. The first input of the node is a funcgraph value node.
static bool IsCallNode(const AnfNodePtr &node);
// Parse the control node and put the results of the parsing into member variables.
void Parse(const std::vector<AnfNodePtr> &control_nodes, const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts, const FuncGraphPtr &root_graph);
std::vector<AnfNodePtr> GetControlNodeParameter() { return control_node_parameters_; }
// Get the outputs of the funcgraph. Usually there is only one output node, but in control flow there are
// multiple branch outputs, so there will be multiple output nodes.
std::vector<AnfNodePtr> FetchAllBranchOutputs(const FuncGraphPtr &func_graph);
private:
friend class GraphScheduler;
// Collect all front value nodes. In control flow, when the input of a switch actor is a value node, that value
// node does not enter the kernel graph, so these nodes need to be saved separately, and space is allocated for
// them separately during initialization.
// The interface is initialized by finding the backend node in the kernel graph to which the front node is
// finally sent.
void FetchFrontValueNode(const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts);
// Find all value nodes in the switch recursively.
void FetchValueNodeInSwitchNode(const AnfNodePtr &switch_node, std::vector<AnfNodePtr> *value_nodes);
// Fetch all the relationships between front parameters and backend parameters. The front parameters
// include two parts:
// 1. The parameter from kernel graph.
// 2. The parameter from control nodes.
static void FetchFrontToBackendParameterMap(const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts,
const std::vector<AnfNodePtr> &control_nodes,
FrontToBackendNodeWithContext *front_to_backend_parameter);
void FetchFrontToBackendParameterMap(const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts,
const std::vector<AnfNodePtr> &control_nodes);
// Get the inputs of control nodes which come from the host actor. These inputs generally come from the partial
// nodes and call nodes of the root funcgraph.
static std::vector<AnfNodePtr> FetchControlNodeParameter(const std::vector<AnfNodePtr> &control_nodes);
// Get the outputs of the funcgraph. Usually there is only one output node, but in control flow there are
// multiple branch outputs, so there will be multiple output nodes.
static std::vector<AnfNodePtr> FetchAllBranchOutputs(const FuncGraphPtr &func_graph);
// Find all funcgraphs that the call node will call.
static std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const CNodePtr &node);
// Get the funcgraph to which the node belongs.
static FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node);
static FuncGraphPtr FetchFuncGraphByNode(const AnfNodePtr &node);
// Get front node by backend node.
static AnfNodePtr GetFrontNodeByBackendNode(const AnfNodePtr &backend_node);
// Get the funcgraph in partial node.
static FuncGraphPtr GetFuncGraphFromPartial(const AnfNodePtr &node);
std::vector<AnfNodePtr> FetchControlNodeParameter(const std::vector<AnfNodePtr> &control_nodes);
// Get all the input parameters of a funcgraph. A funcgraph is called through a call node, and the inputs
// of the call node are the input parameters of the corresponding funcgraph.
static void FetchFuncGraphToParameterMap(const std::vector<AnfNodePtr> &control_nodes,
FuncGraphToParameter *graph_to_real_parameters);
void FetchFuncGraphToParameterMap(const std::vector<AnfNodePtr> &control_nodes);
// Get all the front weight parameters related to the weights in the host parameters.
static void FetchHostParameterToWeightMap(const std::vector<AnfNodePtr> &control_nodes,
HostParameterToWeight *host_parameter_to_weights);
void FetchHostParameterToWeightMap(const std::vector<AnfNodePtr> &control_nodes);
// Fetch all backend input nodes by parameter for gather actor.
static std::vector<AnfNodePtr> FetchInputNodeByParameter(const AnfNodePtr &parameter,
const std::vector<AnfNodePtr> &host_ds_parameters,
std::vector<AnfNodePtr> *invalid_inputs,
const FuncGraphToParameter &graph_to_real_parameters);
private:
// Get the input at position pos of the call node to the funcgraph.
AnfNodePtr GetCallNodeInputByPos(const AnfNodePtr &call_node, const FuncGraphPtr &func_graph, const size_t pos);
// Find the output of the funcgraph; if the output is a call node, return the output of the funcgraph
// called by the call node.
static std::vector<AnfNodePtr> FetchFuncGraphOutput(const FuncGraphPtr &func_graph,
std::vector<AnfNodePtr> *call_nodes);
std::vector<AnfNodePtr> FetchFuncGraphOutput(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *call_nodes);
// Find the corresponding backend parameter for the front_node. If the front_node does not have the corresponding
// backend parameter, then recursively find the backend parameters of other front parameters corresponding to the
// front_node.
static std::pair<AnfNodePtr, DeviceContext *> FetchBackendNodeByFrontNode(
std::pair<AnfNodePtr, DeviceContext *> FetchBackendNodeByFrontNode(
const AnfNodePtr &front_node,
const std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> &front_to_front_parameter,
const std::unordered_map<AnfNodePtr, std::pair<AnfNodePtr, DeviceContext *>> &front_to_backend_parameter,
@ -115,10 +130,27 @@ class ControlNodeParser {
// 1. The parameter is used as the input of the call node,
// 2. The parameter is used as the input of the partial and will be input to the funcgraph of the partial in the
// subsequent call node.
static void FetchFrontToFrontParameterMap(
const std::vector<AnfNodePtr> &control_nodes,
std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> *front_to_front_parameter);
void FetchFrontToFrontParameterMap(const std::vector<AnfNodePtr> &control_nodes,
std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> *front_to_front_parameter);
// The front to backend parameters is used to build and link the host data source actor in the control flow scenario.
FrontToBackendNodeWithContext front_to_backend_parameters_;
// The funcgraph to parameters map records the input parameters of each funcgraph and is used to initialize
// the input nodes of the gather actor.
FuncGraphToParameter func_graph_to_parameters_;
// Host parameter to weights records the weights in the subgraph corresponding to the nodes in the root
// funcgraph. When initializing the weights, all related weights need to be recorded as the same device tensor.
HostParameterToWeight host_parameter_to_weights_;
// The front value nodes save all value nodes that are not in the kernel graph. These nodes are generally the
// inputs of control nodes.
NodeWithDeviceContext front_value_nodes_;
// The front output nodes are used to link the output actor in the multi-branch output scenario.
std::vector<AnfNodePtr> front_output_nodes_;
// Parameters of control nodes which come from the host actor.
std::vector<AnfNodePtr> control_node_parameters_;
};
using ControlNodeParserPtr = std::shared_ptr<ControlNodeParser>;
} // namespace runtime
} // namespace mindspore
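A hedged usage sketch of the interface declared above, assuming the surrounding MindSpore headers are available; the function name BuildControlFlowEdges is hypothetical. The scheduler owns one parser per compilation, calls Parse once, and then reads the parsed edge information from it (for example via GetControlNodeParameter).

#include <memory>
#include <vector>

void BuildControlFlowEdges(const std::vector<AnfNodePtr> &control_nodes,
                           const std::vector<KernelGraphPtr> &graphs,
                           const std::vector<DeviceContext *> &device_contexts,
                           const FuncGraphPtr &root_graph) {
  auto parser = std::make_shared<ControlNodeParser>();
  parser->Parse(control_nodes, graphs, device_contexts, root_graph);
  // Inputs of control nodes that come from the host data source actor.
  const auto control_inputs = parser->GetControlNodeParameter();
  (void)control_inputs;  // Used by the scheduler to build data source actors.
}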

View File

@ -496,30 +496,9 @@ void GraphScheduler::PrepareRun(const ActorSet *actor_set, const GraphCompilerIn
}
}
// 3.Fill host tensors for non weighted parameters which belong to control nodes.
std::vector<AnfNodePtr> control_node_parameters =
ControlNodeParser::FetchControlNodeParameter(graph_compiler_info.control_nodes_);
const auto &tensors = input_tensors.back();
const auto &parameters = graph_compiler_info.origin_parameters_order_;
for (size_t j = 0; j < control_node_parameters.size(); ++j) {
const auto &input_node = control_node_parameters[j];
const auto &input_tensor = tensors[j];
MS_EXCEPTION_IF_NULL(input_node);
if (IsPersistentDeviceTensor(input_node)) {
const auto &iter = graph_compiler_info.front_to_backend_parameters_.find(input_node);
if (iter == graph_compiler_info.front_to_backend_parameters_.end()) {
MS_LOG(EXCEPTION) << "Cannot find backend node for weight parameter:"
<< AnfAlgo::GetNodeDebugString(input_node);
}
const auto &node_with_context = iter->second;
PrepareDataForControlWeightNode(node_with_context.first, input_node, input_tensor, node_with_context.second,
graph_compiler_info.host_parameter_to_weights_);
} else if (find(parameters.begin(), parameters.end(), input_node) != parameters.end()) {
MS_EXCEPTION_IF_NULL(host_data_source_actor);
PrepareDataForHostDataSourceActor(host_data_source_actor->data_node_position_map_, input_node, input_tensor,
&host_tensors);
}
}
// 3.Prepare the data which belongs to control node.
PrepareDataForControlNode(graph_compiler_info.control_node_parser_, graph_compiler_info.origin_parameters_order_,
input_tensors.back(), host_data_source_actor->data_node_position_map_, &host_tensors);
// 4.Prepare the data of host tensor queue(non weighted parameters of graph).
if (host_data_source_actor != nullptr) {
@ -529,6 +508,39 @@ void GraphScheduler::PrepareRun(const ActorSet *actor_set, const GraphCompilerIn
}
}
void GraphScheduler::PrepareDataForControlNode(const ControlNodeParserPtr &control_node_parser,
const std::vector<AnfNodePtr> &origin_parameters,
const std::vector<TensorPtr> &tensors,
const std::unordered_map<AnfNodePtr, size_t> &data_node_position_map,
std::vector<TensorPtr> *host_tensors) {
const auto &control_node_parameters = control_node_parser->GetControlNodeParameter();
for (size_t j = 0; j < control_node_parameters.size(); ++j) {
const auto &input_node = control_node_parameters[j];
const auto &input_tensor = tensors[j];
MS_EXCEPTION_IF_NULL(input_node);
if (IsPersistentDeviceTensor(input_node)) {
const auto &front_to_backend_parameters = control_node_parser->front_to_backend_parameters_;
const auto &iter = front_to_backend_parameters.find(input_node);
if (iter == front_to_backend_parameters.end()) {
MS_LOG(EXCEPTION) << "Cannot find backend node for weight parameter:"
<< AnfAlgo::GetNodeDebugString(input_node);
}
const auto &node_with_context = iter->second;
PrepareDataForControlWeightNode(node_with_context.first, input_node, input_tensor, node_with_context.second,
control_node_parser->host_parameter_to_weights_);
} else if (find(origin_parameters.begin(), origin_parameters.end(), input_node) != origin_parameters.end()) {
PrepareDataForHostDataSourceActor(data_node_position_map, input_node, input_tensor, host_tensors);
}
}
for (const auto &value_node_with_context : control_node_parser->front_value_nodes_) {
if (AnfAlgo::OutputAddrExist(value_node_with_context.first, 0)) {
PrepareDataForValueNode(value_node_with_context.first->cast<ValueNodePtr>(), value_node_with_context.second);
}
}
}
bool GraphScheduler::Run(const ActorSet *actor_set, GraphExecutionStrategy strategy,
const std::vector<TensorPtr> *input_tensors) {
MS_EXCEPTION_IF_NULL(actor_set);
@ -773,12 +785,11 @@ std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const Graph
}
}
const auto &front_to_backend_parameter = graph_compiler_info.front_to_backend_parameters_;
const auto &front_to_backend_parameter = graph_compiler_info.control_node_parser_->front_to_backend_parameters_;
// Initialize the parameter in the control node, first get all the front parameters in the control node, then find
// the corresponding backend parameter from the map, and insert it into the host data source actor
std::vector<AnfNodePtr> control_node_parameters =
ControlNodeParser::FetchControlNodeParameter(graph_compiler_info.control_nodes_);
std::vector<AnfNodePtr> control_node_parameters = graph_compiler_info.control_node_parser_->GetControlNodeParameter();
for (const auto parameter : control_node_parameters) {
if (IsPersistentDeviceTensor(parameter)) {
continue;
@ -926,13 +937,23 @@ std::vector<KernelActorPtr> GraphScheduler::BuildNoInputKernelActor(const GraphC
std::vector<SwitchActorPtr> GraphScheduler::BuildSwitchActor(const GraphCompilerInfo &graph_compiler_info) {
std::vector<SwitchActorPtr> switch_actors;
std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_backend_kernel;
for (const auto &pair : front_node_to_actor_) {
front_to_backend_kernel[pair.first] = pair.second->kernel_;
}
for (const auto &control_node : graph_compiler_info.control_nodes_) {
if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch) ||
AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
auto actor_name = control_node->fullname_with_scope();
auto switch_actor = std::make_shared<SwitchActor>(actor_name, control_node->cast<CNodePtr>());
auto switch_actor = std::make_shared<SwitchActor>(actor_name, graph_compiler_info.device_contexts_[0],
control_node->cast<CNodePtr>());
switch_actor->Initialize();
// Fetch all the input nodes of switch actor.
switch_actor->FetchInputNode(graph_compiler_info.origin_parameters_order_,
graph_compiler_info.control_node_parser_->front_to_backend_parameters_,
front_to_backend_kernel);
InsertActor(switch_actor.get());
switch_actors.emplace_back(switch_actor);
}
@ -958,6 +979,13 @@ std::vector<GatherActorPtr> GraphScheduler::BuildGatherActor(const GraphCompiler
continue;
}
const auto &cnode = control_node->cast<CNodePtr>();
const auto inputs = cnode->inputs();
// If the output of the funcgraph is a value node, there is no need to create a gather actor.
if (inputs[kReturnInputPos]->isa<ValueNode>()) {
continue;
}
auto func_graph = control_node->func_graph();
auto actor_name = func_graph->ToString();
std::vector<AnfNodePtr> parameters;
@ -977,8 +1005,9 @@ std::vector<GatherActorPtr> GraphScheduler::BuildGatherActor(const GraphCompiler
auto gather_actor =
std::make_shared<GatherActor>(actor_name, parameters, loop_count_actor->GetAID(), output_actor->GetAID());
gather_actor->FetchBackendInputNode(func_graph, graph_compiler_info.origin_parameters_order_,
graph_compiler_info.front_to_backend_parameters_,
graph_compiler_info.func_graph_to_parameters_, front_to_backend_kernel);
graph_compiler_info.control_node_parser_->front_to_backend_parameters_,
graph_compiler_info.control_node_parser_->func_graph_to_parameters_,
front_to_backend_kernel);
InsertActor(gather_actor.get());
gather_actors.emplace_back(gather_actor);
}
@ -994,7 +1023,7 @@ void GraphScheduler::LinkDataArrow(KernelActor *to_actor, const GraphCompilerInf
MS_EXCEPTION_IF_NULL(graph);
auto from_kernel = from_kernel_with_output_idx.first;
auto front_node = ControlNodeParser::GetFrontNodeByBackendNode(from_kernel);
auto front_node = GetFrontNodeByBackendNode(from_kernel);
if (IsDeviceQueueDSActor(from_kernel)) {
// Link the data arrows of device queue data source actor.
std::string actor_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id());
@ -1002,7 +1031,7 @@ void GraphScheduler::LinkDataArrow(KernelActor *to_actor, const GraphCompilerInf
LinkDataArrowForDeviceDSActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
} else if (front_node != nullptr && IsGatherActor(front_node, actor_name_to_actor_)) {
// Link the data arrows of gather actor.
auto func_graph = ControlNodeParser::GetFuncgraphByBackendNode(from_kernel);
auto func_graph = GetFuncgraphByBackendNode(from_kernel);
if (func_graph == nullptr) {
MS_LOG(EXCEPTION) << "Cannot find funcgraph of node:" << AnfAlgo::GetNodeDebugString(from_kernel);
}
@ -1064,8 +1093,11 @@ void GraphScheduler::LinkDataArrowForInternalParameter(const AnfNodePtr &interna
auto from_kernel_with_output_idx = KernelWithIndex(from_actor->data_kernel_, actor_pair.second);
LinkDataArrowForDeviceDSActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
} else if (IsSwitchActor(front_output_node)) {
const auto &from_actor = dynamic_cast<SwitchActor *>(FetchActor(front_output_node->fullname_with_scope()));
MS_LOG(ERROR) << "Need link to switch actor:" << from_actor->GetAID();
const auto &actor_name = front_output_node->fullname_with_scope();
const auto &actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
LinkDataArrowForSwitchActor(switch_actor, to_actor, to_kernel_with_input_idx.second);
} else if (IsKernelActor(front_output_node)) {
auto from_actor = dynamic_cast<KernelActor *>(actor_pair.first);
auto from_kernel_with_output_idx = KernelWithIndex(from_actor->kernel_, actor_pair.second);
@ -1512,6 +1544,46 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
}
}
void GraphScheduler::LinkOutputResultArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info,
const ActorSet *actor_set) {
const auto &to_actor = actor_set->output_actor_;
const auto &loop_count_actor = actor_set->loop_count_actor_;
const auto &switch_actors = actor_set->switch_actors_;
if (to_actor == nullptr || loop_count_actor == nullptr) {
return;
}
for (const auto &from_actor : switch_actors) {
MS_EXCEPTION_IF_NULL(from_actor);
auto origin_output_with_index = KernelWithIndex(from_actor->node_, 0);
const auto &iter = graph_compiler_info.origin_outputs_order_.find(origin_output_with_index);
if (iter == graph_compiler_info.origin_outputs_order_.end()) {
continue;
}
// If the switch actor is in the output list, the output of the switch actor should be sent to the output
// actor, and a control arrow needs to be linked to the loop count actor.
for (const auto pos : iter->second.second) {
to_actor->device_contexts_[pos] = from_actor->device_context_;
}
for (size_t i = 0; i < from_actor->branch_inputs_pos_.size(); ++i) {
const auto &input_pos = from_actor->branch_inputs_pos_[i];
if (input_pos.empty()) {
MS_LOG(EXCEPTION) << "Invalid input num in switch actor:" << from_actor->GetAID();
}
for (const auto pos : iter->second.second) {
auto op_arrow = std::make_shared<DataArrow>(input_pos[0], to_actor->GetAID(), pos);
from_actor->output_branch_result_arrows_[i].emplace_back(op_arrow);
}
from_actor->output_branch_control_arrows_[i].emplace_back(loop_count_actor->GetAID());
}
loop_count_actor->branch_id_to_input_controls_num_[kMainBranchID]++;
}
}
void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<KernelActor *> &auto_monad_actors) {
const size_t kNeedUpdateDeviceTensorStoreNum = 2;
for (auto &kernel_actor : auto_monad_actors) {
@ -1606,9 +1678,12 @@ void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compi
LinkBranchArrowForGatherActor(graph_compiler_info, actor_set);
LinkControlArrowForGatherActor(graph_compiler_info, actor_set);
LinkControlArrowForGatherActor(&(actor_set->gather_actors_), actor_set->loop_count_actor_.get(),
graph_compiler_info.graphs_);
LinkOutputResultArrowForGatherActor(graph_compiler_info, actor_set);
LinkOutputResultArrowForSwitchActor(graph_compiler_info, actor_set);
}
void GraphScheduler::LinkDataArrowForGatherActor(GatherActor *from_actor, KernelActor *to_actor,
@ -1621,7 +1696,7 @@ void GraphScheduler::LinkDataArrowForGatherActor(GatherActor *from_actor, Kernel
MS_EXCEPTION_IF_NULL(from_kernel);
auto to_input_index = to_kernel_with_input_idx.second;
auto front_node = ControlNodeParser::GetFrontNodeByBackendNode(from_kernel);
auto front_node = GetFrontNodeByBackendNode(from_kernel);
if (front_node == nullptr) {
MS_LOG(EXCEPTION) << "Cannot find front node of node:" << AnfAlgo::GetNodeDebugString(from_kernel);
}
@ -1637,7 +1712,8 @@ void GraphScheduler::LinkDataArrowForGatherActor(GatherActor *from_actor, Kernel
void GraphScheduler::LinkDataArrowByCallInput(const GraphCompilerInfo &graph_compiler_info, const AnfNodePtr &call_node,
OpActor<DeviceTensor> *to_actor, const size_t to_index) {
// Fetch all the funcgraph that call node would call.
std::vector<FuncGraphPtr> func_graphs = ControlNodeParser::FetchFuncGraphbyCallNode(call_node->cast<CNodePtr>());
const auto cnode = call_node->cast<CNodePtr>();
std::vector<FuncGraphPtr> func_graphs = FetchFuncGraphbyCallNode(cnode);
// Collect the output of each funcgraph.
for (const auto &func_graph : func_graphs) {
@ -1677,19 +1753,58 @@ void GraphScheduler::LinkDataArrowByCallInput(const GraphCompilerInfo &graph_com
const size_t pos = iter - gather_actor->data_nodes_.begin();
auto op_arrow = std::make_shared<DataArrow>(pos, to_actor->GetAID(), to_index);
gather_actor->output_data_arrows_.emplace_back(op_arrow);
} else if (output_with_index.first->isa<ValueNode>()) {
// If the output is a value node, then the value node needs to be sent by the switch actor.
const auto &call_inputs = cnode->inputs();
if (AnfAlgo::CheckPrimitiveType(call_inputs[0], prim::kPrimSwitch)) {
const auto &actor_name = call_inputs[0]->fullname_with_scope();
const auto &actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
MS_EXCEPTION_IF_NULL(switch_actor);
// Add output for each branch of switch.
for (size_t i = 0; i < switch_actor->branch_inputs_pos_.size(); ++i) {
if (switch_actor->branch_inputs_pos_[i].empty()) {
MS_LOG(EXCEPTION) << "No input for switch actor:" << actor_name << " branch:" << i;
}
const auto from_index = switch_actor->branch_inputs_pos_[i][0];
auto op_arrow = std::make_shared<DataArrow>(from_index, to_actor->GetAID(), to_index);
switch_actor->output_branch_arrows_[i].emplace_back(op_arrow);
}
} else {
MS_LOG(EXCEPTION) << "Invalid input for call node:" << AnfAlgo::GetNodeDebugString(call_node);
}
} else {
MS_LOG(EXCEPTION) << "Output of func graph is not a parameter or kernel, func graph:" << func_graph->ToString();
MS_LOG(EXCEPTION) << "Output of func graph is not a parameter or kernel, func graph:" << func_graph->ToString()
<< " output node:" << AnfAlgo::GetNodeDebugString(output_with_index.first);
}
}
}
void GraphScheduler::LinkDataArrowForSwitchActor(SwitchActor *from_actor, KernelActor *to_actor,
const size_t to_index) {
MS_EXCEPTION_IF_NULL(from_actor);
for (size_t i = 0; i < from_actor->output_branch_arrows_.size(); ++i) {
if (from_actor->branch_inputs_pos_[i].empty()) {
MS_LOG(EXCEPTION) << "No input for switch actor:" << from_actor->GetAID() << " branch:" << i;
}
const auto from_index = from_actor->branch_inputs_pos_[i][0];
auto op_arrow = std::make_shared<DataArrow>(from_index, to_actor->GetAID(), to_index);
from_actor->output_branch_arrows_[i].emplace_back(op_arrow);
}
to_actor->input_datas_num_++;
}
void GraphScheduler::LinkDataArrowByControlNode(const GraphCompilerInfo &graph_compiler_info,
const AnfNodePtr &input_node, OpActor<DeviceTensor> *to_actor,
const size_t to_index) {
const auto &parameters = graph_compiler_info.origin_parameters_order_;
const auto &front_to_backend_parameter = graph_compiler_info.front_to_backend_parameters_;
const auto &front_to_backend_parameter = graph_compiler_info.control_node_parser_->front_to_backend_parameters_;
if (ControlNodeParser::IsCallNode(input_node)) {
if (IsCallNode(input_node)) {
// The actor input is a call node.
LinkDataArrowByCallInput(graph_compiler_info, input_node, to_actor, to_index);
} else if (IsGatherActor(input_node, actor_name_to_actor_)) {
@ -1745,20 +1860,19 @@ void GraphScheduler::LinkDataArrowForSwitchActor(const GraphCompilerInfo &graph_
const auto &inputs = actor->input_nodes_;
for (size_t i = 0; i < inputs.size(); ++i) {
auto input = inputs[i];
if (input->isa<ValueNode>()) {
continue;
}
LinkDataArrowByControlNode(graph_compiler_info, input, actor, i);
}
// Link switch output.
for (size_t i = 0; i < actor->branch_func_graph_.size(); ++i) {
auto func_graph = actor->branch_func_graph_[i];
if (func_graph == nullptr) {
if (func_graph == nullptr || func_graph->output()->isa<ValueNode>()) {
continue;
}
if (func_graph->output()->isa<ValueNode>()) {
actor->AddInput(func_graph->output(), 0);
}
auto gather_name = func_graph->ToString();
if (actor_name_to_actor_.find(gather_name) == actor_name_to_actor_.end()) {
MS_LOG(EXCEPTION) << "Cannot find gather actor for funcgraph:" << gather_name
@ -1773,13 +1887,15 @@ void GraphScheduler::LinkDataArrowForSwitchActor(const GraphCompilerInfo &graph_
}
}
void GraphScheduler::LinkControlArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info,
const ActorSet *actor_set) {
MS_EXCEPTION_IF_NULL(actor_set);
void GraphScheduler::LinkControlArrowForGatherActor(std::vector<GatherActorPtr> *from_actors, LoopCountActor *to_actor,
const std::vector<KernelGraphPtr> &graphs) {
if (from_actors == nullptr || to_actor == nullptr) {
return;
}
// Link control arrow to kernel actor.
for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
const auto &kernel_graph = graph_compiler_info.graphs_[i];
for (size_t i = 0; i < graphs.size(); ++i) {
const auto &kernel_graph = graphs[i];
MS_EXCEPTION_IF_NULL(kernel_graph);
const auto &func_graph = kernel_graph->GetFuncGraph();
if (func_graph == nullptr) {
@ -1807,12 +1923,11 @@ void GraphScheduler::LinkControlArrowForGatherActor(const GraphCompilerInfo &gra
}
// Link control arrows to the loop count actor.
for (auto &from_actor : actor_set->gather_actors_) {
for (auto &from_actor : *from_actors) {
MS_EXCEPTION_IF_NULL(from_actor);
// If the gather actor has no output, then add the output control to the loop count actor.
if (from_actor->output_data_arrows_.size() == 0 && from_actor->output_control_arrows_.size() == 0) {
const auto &to_actor = actor_set->loop_count_actor_;
auto to_aid = to_actor->GetAID();
from_actor->output_control_arrows_.emplace_back(to_aid);
to_actor->branch_id_to_input_controls_num_[kMainBranchID]++;
@ -1831,7 +1946,7 @@ void GraphScheduler::LinkBranchArrowForGatherActor(const GraphCompilerInfo &grap
const auto &output_actor = actor_set->output_actor_.get();
// If there is only one branch output, set the branch id of the loop count to 0; there is no need to send
// the branch id.
auto outputs = graph_compiler_info.front_output_nodes_;
auto outputs = graph_compiler_info.control_node_parser_->front_output_nodes_;
if (outputs.size() == 1) {
return;
}
@ -1868,7 +1983,7 @@ void GraphScheduler::LinkBranchArrowForGatherActor(const GraphCompilerInfo &grap
MS_EXCEPTION_IF_NULL(kernel_actor);
if ((kernel_actor->output_data_arrows_.size() == 0) && (kernel_actor->output_control_arrows_.size() == 0)) {
MS_EXCEPTION_IF_NULL(kernel_actor->kernel_);
const auto &sub_func_graph = ControlNodeParser::FetchFuncGraphByNode(kernel_actor->kernel_);
const auto &sub_func_graph = FetchFuncGraphByNode(kernel_actor->kernel_);
if (sub_func_graph == nullptr) {
MS_LOG(EXCEPTION) << "Cannot get funcgraph from kernel:" << kernel_actor->kernel_->fullname_with_scope();
}
@ -1896,6 +2011,18 @@ void GraphScheduler::LinkBranchArrowForGatherActor(const GraphCompilerInfo &grap
gather_actor->branch_id_ = branch_id;
loop_count_actor->branch_id_to_input_controls_num_[branch_id] = graph_to_control_num[sub_func_graph];
}
// If the switch actor is linked to the output actor, it also links a control arrow to the loop count actor,
// and this must be recorded in the expected input control count of that branch.
for (const auto &from_actor : actor_set->switch_actors_) {
MS_EXCEPTION_IF_NULL(from_actor);
auto origin_output_with_index = KernelWithIndex(from_actor->node_, 0);
const auto &iter = graph_compiler_info.origin_outputs_order_.find(origin_output_with_index);
if (iter == graph_compiler_info.origin_outputs_order_.end()) {
continue;
}
loop_count_actor->branch_id_to_input_controls_num_[iter->second.first]++;
}
}
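A compact sketch of the counting logic just added, with simplified key types. It assumes, matching the iter->second.first access above, that origin_outputs_order maps a (node, output index) pair to (branch id, output positions):

#include <cstddef>
#include <map>
#include <utility>
#include <vector>

using MockNodeId = int;
using MockKernelWithIndex = std::pair<MockNodeId, size_t>;
using MockBranchInfo = std::pair<int, std::vector<size_t>>;  // (branch id, output positions)

void CountSwitchControlArrows(const std::vector<MockNodeId> &switch_nodes,
                              const std::map<MockKernelWithIndex, MockBranchInfo> &origin_outputs_order,
                              std::map<int, size_t> *branch_id_to_input_controls_num) {
  for (const auto &node : switch_nodes) {
    const auto iter = origin_outputs_order.find({node, 0});
    if (iter == origin_outputs_order.end()) {
      continue;  // This switch does not feed the output actor, so nothing to record.
    }
    // One more control arrow will reach the loop count actor on this branch.
    ++(*branch_id_to_input_controls_num)[iter->second.first];
  }
}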
void GraphScheduler::LinkOutputResultArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info,
@ -2061,6 +2188,15 @@ void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler
}
}
}
// In control flow, there may be some value nodes that are not in any kernel graph and need to be placed
// in the device tensor store separately.
for (const auto &value_node : graph_compiler_info.control_node_parser_->front_value_nodes_) {
MS_EXCEPTION_IF_NULL(value_node.first);
auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node.first, 0, false);
DeviceTensorStore::GetInstance().Insert(value_node.first.get(), device_tensor);
UpdateRefCount(device_tensor.get(), true);
}
}
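A self-contained model of the persistence step above. It assumes, as UpdateRefCount(..., true) suggests, that a persisted tensor is pinned for the whole run; the store and tensor types are stand-ins for DeviceTensorStore and DeviceTensor:

#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>

struct MockDeviceTensor {
  size_t original_ref_count = 0;
};
using MockDeviceTensorPtr = std::shared_ptr<MockDeviceTensor>;

class MockDeviceTensorStore {
 public:
  static MockDeviceTensorStore &GetInstance() {
    static MockDeviceTensorStore instance;
    return instance;
  }
  // Keyed by the raw front-node pointer, mirroring Insert(value_node.first.get(), ...).
  void Insert(const void *front_node, const MockDeviceTensorPtr &tensor) { store_[front_node] = tensor; }

 private:
  std::map<const void *, MockDeviceTensorPtr> store_;
};

// Pin the tensor so the memory manager never frees it between steps.
void PinDeviceTensor(MockDeviceTensor *tensor) { tensor->original_ref_count = SIZE_MAX; }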
HostTensorQueue *GraphScheduler::FetchHostQueue(const ActorInfo &actor_info) const {

View File

@ -67,33 +67,24 @@ enum class GraphExecutionStrategy {
// The tensors mask is used to distinguish each input tensor's type.
// The input tensor is used to link graphs in the dynamic build scenario.
// The control node is used to link graphs in the control flow scenario.
// The control node parser is used to parse the edge info in control nodes.
// The origin parameters order is used to correspond to the input args.
// The origin outputs order is used to correspond to the output args.
// The front to backend parameters is used to build and link the host data source actor in the control flow scenario.
// The funcgraph to parameters map records the input parameters of each funcgraph and is used to initialize
// the input nodes of the gather actor.
// The host parameter to weights map records the weights in the subgraph corresponding to the node in the root funcgraph.
// When initializing the weights, all related weights need to be recorded as the same device tensor.
// The front output node is used to link the output actor in the multi-branch output scenario.
struct GraphCompilerInfo {
GraphCompilerInfo(
const std::vector<KernelGraphPtr> &graphs, const std::vector<DeviceContext *> &device_contexts,
const std::vector<std::vector<int64_t> *> &tensors_mask, const std::vector<std::vector<TensorPtr> *> &input_tensors,
const std::vector<AnfNodePtr> &control_nodes, const std::vector<AnfNodePtr> &origin_parameters_order,
const KernelMapPosition &origin_outputs_order, const FrontToBackendNodeWithContext &front_to_backend_parameters,
const FuncGraphToParameter &func_graph_to_parameters, const HostParameterToWeight &host_parameter_to_weights,
const std::vector<AnfNodePtr> &front_output_nodes, const size_t outputs_num, const std::string &name)
GraphCompilerInfo(const std::vector<KernelGraphPtr> &graphs, const std::vector<DeviceContext *> &device_contexts,
const std::vector<std::vector<int64_t> *> &tensors_mask,
const std::vector<std::vector<TensorPtr> *> &input_tensors,
const std::vector<AnfNodePtr> &control_nodes,
const std::vector<AnfNodePtr> &origin_parameters_order, const ControlNodeParserPtr &parser,
const KernelMapPosition &origin_outputs_order, const size_t outputs_num, const std::string &name)
: graphs_(graphs),
device_contexts_(device_contexts),
tensors_mask_(tensors_mask),
input_tensors_(input_tensors),
control_nodes_(control_nodes),
control_node_parser_(parser),
origin_parameters_order_(origin_parameters_order),
origin_outputs_order_(origin_outputs_order),
front_to_backend_parameters_(front_to_backend_parameters),
func_graph_to_parameters_(func_graph_to_parameters),
host_parameter_to_weights_(host_parameter_to_weights),
front_output_nodes_(front_output_nodes),
outputs_num_(outputs_num),
name_(name) {}
std::vector<KernelGraphPtr> graphs_;
@ -101,12 +92,9 @@ struct GraphCompilerInfo {
std::vector<std::vector<int64_t> *> tensors_mask_;
std::vector<std::vector<TensorPtr> *> input_tensors_;
std::vector<AnfNodePtr> control_nodes_;
ControlNodeParserPtr control_node_parser_;
std::vector<AnfNodePtr> origin_parameters_order_;
KernelMapPosition origin_outputs_order_;
FrontToBackendNodeWithContext front_to_backend_parameters_;
FuncGraphToParameter func_graph_to_parameters_;
HostParameterToWeight host_parameter_to_weights_;
std::vector<AnfNodePtr> front_output_nodes_;
size_t outputs_num_;
std::string name_;
};
@ -245,17 +233,27 @@ class GraphScheduler {
// connected.
void LinkDataArrowByCallInput(const GraphCompilerInfo &graph_compiler_info, const AnfNodePtr &call_node,
OpActor<DeviceTensor> *to_actor, const size_t to_index);
void LinkControlArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
void LinkDataArrowForSwitchActor(SwitchActor *from_actor, KernelActor *to_actor, const size_t to_index);
void LinkControlArrowForGatherActor(std::vector<GatherActorPtr> *from_actors, LoopCountActor *to_actor,
const std::vector<KernelGraphPtr> &graphs);
// In control flow, there may be multi-branch outputs, in which case the gather actor needs to
// send the branch id to the loop count actor.
void LinkBranchArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
void LinkOutputResultArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
void LinkOutputResultArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
// The processing of linking actors dynamically.
// Analyze the necessary input data of the current actor, then generate and cache the op arrows
// between the current actor and the previous actor; this method executes before calling Schedule.
void PrepareForDynamiclyLink(ActorSet *actor_set, const CNodePtr &kernel, const AID &aid,
const std::vector<TensorPtr> *input_tensors);
void PrepareDataForControlNode(const ControlNodeParserPtr &control_node_parser,
const std::vector<AnfNodePtr> &origin_parameters,
const std::vector<TensorPtr> &tensors,
const std::unordered_map<AnfNodePtr, size_t> &data_node_position_map,
std::vector<TensorPtr> *host_tensors);
// Link to the previous actor dynamically, and send it a message to add the
// new DataArrow and send the output data back; this method must execute after calling Schedule.
void LinkDataArrowForKernelActorDynamicly(const ActorSet *actor_set);
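A toy model of the before/after-Schedule contract stated in the two comments above, with std::function standing in for actor messages; purely illustrative, not the framework's mechanism:

#include <functional>
#include <utility>
#include <vector>

struct TwoPhaseLinker {
  std::vector<std::function<void()>> cached_arrows;
  bool scheduled = false;

  // Before Schedule: only record what to link, the actors are not live yet.
  void Prepare(std::function<void()> arrow) { cached_arrows.push_back(std::move(arrow)); }

  void Schedule() { scheduled = true; }

  // After Schedule: the actors can now receive the add-DataArrow messages.
  void LinkDynamically() {
    if (!scheduled) {
      return;
    }
    for (auto &arrow : cached_arrows) {
      arrow();
    }
    cached_arrows.clear();
  }
};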

View File

@ -454,8 +454,7 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args,
std::vector<tensor::TensorPtr> input_tensor;
// Get the inputs of control nodes which come from the host actor.
std::vector<AnfNodePtr> control_node_parameters =
ControlNodeParser::FetchControlNodeParameter(graph_compiler_info.control_nodes_);
const auto &control_node_parameters = graph_compiler_info.control_node_parser_->GetControlNodeParameter();
for (const auto &parameter : control_node_parameters) {
PushTensor(args, origin_parameters, parameter, &input_tensor);
}
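PushTensor itself is not shown in this diff; the following self-contained sketch shows the position matching it plausibly performs (the signature and the empty-placeholder behavior are assumptions, not the real helper):

#include <algorithm>
#include <cstddef>
#include <vector>

template <typename Node, typename Tensor>
void PushTensorSketch(const std::vector<Tensor> &args, const std::vector<Node> &origin_parameters,
                      const Node &parameter, std::vector<Tensor> *input_tensors) {
  const auto iter = std::find(origin_parameters.begin(), origin_parameters.end(), parameter);
  if (iter == origin_parameters.end()) {
    // Keep positions aligned even when the parameter is not one of the graph inputs.
    input_tensors->push_back(Tensor());
    return;
  }
  const auto position = static_cast<size_t>(iter - origin_parameters.begin());
  input_tensors->push_back(args[position]);
}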
@ -555,10 +554,13 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(con
name.append("_").append(std::to_string(graph_id_to_context.first));
}
auto parser = std::make_shared<ControlNodeParser>();
parser->Parse(control_nodes_, graphs, device_contexts, root_graph);
// Get all the outputs. In control flow, there may be multiple branch outputs.
runtime::KernelMapPosition outputs_order;
size_t outputs_num = 0;
const auto &all_branch_output = ControlNodeParser::FetchAllBranchOutputs(root_graph);
const auto &all_branch_output = parser->FetchAllBranchOutputs(root_graph);
for (int j = 0; j < SizeToInt(all_branch_output.size()); ++j) {
// In general, there is only one output branch, and the branch id is 0 at this time. In the control flow,
// there are multi-branch output scenarios. Different branches may have different weight nodes. When output
@ -578,28 +580,10 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(con
}
}
// Fetch all the relationships between front parameters and backend parameters which will be used to
// build and link actors.
FrontToBackendNodeWithContext front_to_backend_parameter;
ControlNodeParser::FetchFrontToBackendParameterMap(graphs, device_contexts, control_nodes_,
&front_to_backend_parameter);
// The funcgraph to parameters map records the input parameters of each funcgraph and is used to initialize
// the input nodes of the gather actor.
FuncGraphToParameter func_graph_to_parameters;
ControlNodeParser::FetchFuncGraphToParameterMap(control_nodes_, &func_graph_to_parameters);
// The host parameter to weights map records the weights in the subgraph corresponding to the node in the root funcgraph.
// When initializing the weights, all related weights need to be recorded as the same device tensor.
HostParameterToWeight host_parameter_to_weights;
ControlNodeParser::FetchHostParameterToWeightMap(control_nodes_, &host_parameter_to_weights);
std::vector<std::vector<int64_t> *> tensors_mask;
std::vector<std::vector<tensor::TensorPtr> *> input_tensors;
return std::make_unique<GraphCompilerInfo>(graphs, device_contexts, tensors_mask, input_tensors, control_nodes_,
root_graph->parameters(), outputs_order, front_to_backend_parameter,
func_graph_to_parameters, host_parameter_to_weights, all_branch_output,
outputs_num, name);
root_graph->parameters(), parser, outputs_order, outputs_num, name);
}
std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(
@ -629,10 +613,10 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(
std::vector<std::vector<int64_t> *> tensors_mask_list(1, const_cast<std::vector<int64_t> *>(tensors_mask));
std::vector<std::vector<TensorPtr> *> input_tensors_list(1,
const_cast<std::vector<tensor::TensorPtr> *>(input_tensors));
return std::make_unique<GraphCompilerInfo>(
graphs, device_contexts, tensors_mask_list, input_tensors_list, std::vector<AnfNodePtr>(),
std::vector<AnfNodePtr>(), outputs_order, FrontToBackendNodeWithContext(), FuncGraphToParameter(),
HostParameterToWeight(), std::vector<AnfNodePtr>(), outputs_order.size(), actor_info);
auto parser = std::make_shared<ControlNodeParser>();
return std::make_unique<GraphCompilerInfo>(graphs, device_contexts, tensors_mask_list, input_tensors_list,
std::vector<AnfNodePtr>(), std::vector<AnfNodePtr>(), parser,
outputs_order, outputs_order.size(), actor_info);
}
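This overload covers graphs without control flow, yet it still constructs a parser: an unparsed ControlNodeParser acts as a null object, so downstream code can dereference control_node_parser_ unconditionally. An illustrative snippet, relying only on the member access already used in PersistDeviceTensor above:

auto parser = std::make_shared<ControlNodeParser>();
// Never parsed, so its collections stay empty and loops such as the one over
// front_value_nodes_ in PersistDeviceTensor simply do nothing.
for (const auto &value_node : parser->front_value_nodes_) {
  (void)value_node;  // unreachable for an empty parser
}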
void MindRTBackend::RunGraph(const ActorInfo &actor_info, OpRunInfo *op_run_info,

View File

@ -41,9 +41,7 @@ using ActorInfo = runtime::ActorInfo;
using GraphCompiler = runtime::GraphCompiler;
using GraphCompilerInfo = runtime::GraphCompilerInfo;
using ControlNodeParser = runtime::ControlNodeParser;
using FrontToBackendNodeWithContext = runtime::FrontToBackendNodeWithContext;
using FuncGraphToParameter = runtime::FuncGraphToParameter;
using HostParameterToWeight = runtime::HostParameterToWeight;
using ControlNodeParserPtr = runtime::ControlNodeParserPtr;
enum SwitchCondStatus {
kCondOk = 0,