Fix ocean and sponge net.
This commit is contained in:
parent
0022d07d6e
commit
03021379d4
|
@ -339,6 +339,7 @@ class KernelGraph : public FuncGraph {
|
|||
|
||||
bool is_all_nop_node() const { return is_all_nop_node_; }
|
||||
void set_is_all_nop_node(bool is_all_nop_node) { is_all_nop_node_ = is_all_nop_node; }
|
||||
std::map<AnfWithOutIndex, AnfWithOutIndex> graph_output_map() { return graph_output_to_front_node_map_; }
|
||||
|
||||
private:
|
||||
// remove value node form graph
|
||||
|
|
|
@ -114,7 +114,7 @@ void GatherActor::SendOutput(OpContext<DeviceTensor> *context) const {
|
|||
size_t from_index = result_arrow->from_output_index_;
|
||||
const auto &front_node = data_nodes_[from_index];
|
||||
for (const auto &backend_node : front_to_backend_parameter_.at(front_node)) {
|
||||
if (AnfAlgo::GetMutableOutputAddr(backend_node.first, backend_node.second).get() ==
|
||||
if (AnfAlgo::GetMutableOutputAddr(backend_node.first, backend_node.second, false).get() ==
|
||||
input_device_tensors_[from_index]) {
|
||||
Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, backend_node.first, backend_node.second,
|
||||
result_arrow->to_input_index_, context);
|
||||
|
|
|
@ -397,16 +397,34 @@ void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
|
|||
for (size_t i = 0; i < output_branch_result_arrow.size(); ++i) {
|
||||
auto &result_arrow = output_branch_result_arrow[i];
|
||||
MS_EXCEPTION_IF_NULL(result_arrow);
|
||||
size_t from_index = result_arrow->from_output_index_;
|
||||
if (result_arrow->from_output_index_ >= SizeToInt(branch_inputs_pos_[index].size())) {
|
||||
MS_LOG(EXCEPTION) << "Invalid from index in switch actor, from index:" << result_arrow->from_output_index_
|
||||
<< " total:" << branch_inputs_pos_[index].size() << " actor:" << GetAID();
|
||||
}
|
||||
size_t from_index = branch_inputs_pos_[index][result_arrow->from_output_index_];
|
||||
|
||||
MS_LOG(DEBUG) << "Switch actor:" << GetAID() << " send result addr:" << input_device_tensors_[from_index];
|
||||
bool is_send = false;
|
||||
for (const auto &backend_node : backend_parameters_[from_index]) {
|
||||
if (AnfAlgo::GetMutableOutputAddr(backend_node.first, backend_node.second).get() ==
|
||||
input_device_tensors_[from_index]) {
|
||||
Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, backend_node.first, backend_node.second,
|
||||
for (size_t j = 0; j < AnfAlgo::GetOutputTensorNum(backend_node.first); ++j) {
|
||||
if (AnfAlgo::OutputAddrExist(backend_node.first, j, false) &&
|
||||
AnfAlgo::GetMutableOutputAddr(backend_node.first, j, false).get() == input_device_tensors_[from_index]) {
|
||||
Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, backend_node.first, j,
|
||||
result_arrow->to_input_index_, context);
|
||||
is_send = true;
|
||||
MS_LOG(DEBUG) << "Switch actor:" << GetAID() << " send result addr:" << input_device_tensors_[from_index]
|
||||
<< " succeed";
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!is_send) {
|
||||
MS_LOG(EXCEPTION) << "Failed to get backend node of switch actor output, actor:" << GetAID()
|
||||
<< " branch:" << index << " index:" << result_arrow->from_output_index_ << " output pos"
|
||||
<< branch_inputs_pos_[index][result_arrow->from_output_index_] << " output index"
|
||||
<< result_arrow->to_input_index_;
|
||||
}
|
||||
}
|
||||
|
||||
// Send output control.
|
||||
auto source_aid = const_cast<AID *>(&GetAID());
|
||||
|
|
|
@ -166,6 +166,7 @@ void CreateDeviceTensorForValueNode(const AnfNodePtr &front_node, const AnfNodeP
|
|||
device::DeviceAddressPtr address =
|
||||
device_context->CreateDeviceAddress(nullptr, tensor_size, output_format, output_type_id);
|
||||
MS_EXCEPTION_IF_NULL(address);
|
||||
MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(front_node) << " addr:" << address;
|
||||
AnfAlgo::SetOutputAddr(address, 0, front_node.get());
|
||||
}
|
||||
|
||||
|
@ -190,6 +191,7 @@ void CreateDeviceTensorForFrontParameter(const AnfNodePtr &node, const DeviceCon
|
|||
// Create device tensor.
|
||||
device::DeviceAddressPtr address = device_context->CreateDeviceAddress(nullptr, size, kOpFormat_DEFAULT, type_id);
|
||||
MS_EXCEPTION_IF_NULL(address);
|
||||
MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(node) << " addr:" << address;
|
||||
AnfAlgo::SetOutputAddr(address, 0, node.get());
|
||||
}
|
||||
|
||||
|
@ -1217,11 +1219,25 @@ void ControlNodeParser::FetchFrontToBackendKernel(const std::vector<KernelGraphP
|
|||
if (IsKernelActor(kernel) && (!IsSkippedKernelActor(kernel))) {
|
||||
auto front_node = graph->GetFrontAnfByBackendAnf(kernel);
|
||||
if (front_node != nullptr) {
|
||||
front_to_backend_kernels_[front_node] = {kernel, device_context};
|
||||
for (size_t j = 0; j < AnfAlgo::GetOutputTensorNum(kernel); ++j) {
|
||||
front_to_backend_kernels_[{front_node, j}] = {{kernel, j}, device_context};
|
||||
MS_LOG(DEBUG) << "Add front to backend kernel, front:" << AnfAlgo::GetNodeDebugString(front_node)
|
||||
<< "index:" << j << " addr:" << front_node
|
||||
<< " second:" << AnfAlgo::GetNodeDebugString(kernel) << "index:" << j << " addr:" << kernel;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto graph_output_map = graph->graph_output_map();
|
||||
for (const auto &output_pair : graph_output_map) {
|
||||
front_to_backend_kernels_[output_pair.second] = {output_pair.first, device_context};
|
||||
MS_LOG(DEBUG) << "Add front to backend kernel, front:" << AnfAlgo::GetNodeDebugString(output_pair.second.first)
|
||||
<< "index:" << output_pair.second.second << " addr:" << output_pair.second.first
|
||||
<< " second:" << AnfAlgo::GetNodeDebugString(output_pair.first.first)
|
||||
<< "index:" << output_pair.first.second << " addr:" << output_pair.first.first;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_output,
|
||||
|
@ -1230,6 +1246,12 @@ void ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_
|
|||
std::set<KernelWithIndex> *results) {
|
||||
if (front_output->isa<ValueNode>()) {
|
||||
(*results).insert({front_output, 0});
|
||||
const auto &iter = formal_to_real_parameters_.find(front_output);
|
||||
if (iter != formal_to_real_parameters_.end()) {
|
||||
for (const auto &node : iter->second) {
|
||||
(*results).insert(node);
|
||||
}
|
||||
}
|
||||
} else if (front_output->isa<Parameter>()) {
|
||||
// Output is a parameter.
|
||||
const auto iter = formal_to_real_parameters_.find(front_output);
|
||||
|
@ -1265,11 +1287,10 @@ void ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_
|
|||
}
|
||||
} else if (front_output->isa<CNode>()) {
|
||||
// Output is a kernel.
|
||||
const auto iter = front_to_backend_kernels_.find(front_output);
|
||||
const auto iter = front_to_backend_kernels_.find(AnfAlgo::VisitKernelWithReturnType(front_output, 0));
|
||||
|
||||
if (iter != front_to_backend_kernels_.end()) {
|
||||
const auto &output_with_index = AnfAlgo::VisitKernelWithReturnType(iter->second.first, 0);
|
||||
(*results).insert(output_with_index);
|
||||
(*results).insert(iter->second.first);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Cannot find backend node for front kernel:" << AnfAlgo::GetNodeDebugString(front_output);
|
||||
}
|
||||
|
@ -1298,11 +1319,11 @@ void ControlNodeParser::FetchBackendInputNodebyFrontNode(
|
|||
}
|
||||
formal_to_real_parameters_[formal_parameter].push_back({iter->second.first, 0});
|
||||
} else {
|
||||
const auto iter = front_to_backend_kernels_.find(node_with_index.first);
|
||||
const auto iter = front_to_backend_kernels_.find(node_with_index);
|
||||
if (iter == front_to_backend_kernels_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Cannot find actor of front node:" << AnfAlgo::GetNodeDebugString(node_with_index.first);
|
||||
}
|
||||
formal_to_real_parameters_[formal_parameter].push_back({iter->second.first, node_with_index.second});
|
||||
formal_to_real_parameters_[formal_parameter].emplace_back(iter->second.first);
|
||||
}
|
||||
}
|
||||
} else if (real_parameter->isa<ValueNode>()) {
|
||||
|
@ -1315,11 +1336,11 @@ void ControlNodeParser::FetchBackendInputNodebyFrontNode(
|
|||
} else {
|
||||
// Input node is a cnode.
|
||||
const auto node_with_index = AnfAlgo::VisitKernelWithReturnType(real_parameter, 0);
|
||||
const auto iter = front_to_backend_kernels_.find(node_with_index.first);
|
||||
const auto iter = front_to_backend_kernels_.find(node_with_index);
|
||||
if (iter == front_to_backend_kernels_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Cannot find backend node of node:" << AnfAlgo::GetNodeDebugString(node_with_index.first);
|
||||
}
|
||||
formal_to_real_parameters_[formal_parameter].push_back({iter->second.first, node_with_index.second});
|
||||
formal_to_real_parameters_[formal_parameter].emplace_back(iter->second.first);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1373,6 +1394,17 @@ void ControlNodeParser::FetchBackendInputNode(const std::vector<KernelGraphPtr>
|
|||
FetchBackendParameterNode(graphs, device_contexts, real_to_formal_front_parameters, formal_to_real_front_parameters,
|
||||
&front_to_backend_parameters);
|
||||
|
||||
for (size_t i = 0; i < graphs.size(); ++i) {
|
||||
const auto &graph = graphs[i];
|
||||
for (const auto &value_node : graph->graph_value_nodes()) {
|
||||
auto front_node = graph->GetFrontAnfByBackendAnf(value_node);
|
||||
|
||||
if (front_node != nullptr) {
|
||||
formal_to_real_parameters_[front_node].push_back({value_node, 0});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto &func_graph_to_parameters : func_graph_to_parameters_) {
|
||||
const auto &func_graph = func_graph_to_parameters.first;
|
||||
std::vector<AnfNodePtr> graph_inputs;
|
||||
|
@ -1418,9 +1450,10 @@ void ControlNodeParser::FetchAutoMonadNode(const std::vector<AnfNodePtr> &contro
|
|||
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
|
||||
if (AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimUpdateState)) {
|
||||
const auto &node = FetchSourceNodeByAutoMonad(inputs[i]);
|
||||
const auto &iter = front_to_backend_kernels_.find(node);
|
||||
const auto &iter = front_to_backend_kernels_.find(AnfAlgo::VisitKernelWithReturnType(node, 0));
|
||||
if (iter != front_to_backend_kernels_.end()) {
|
||||
kernel_to_call_nodes_[iter->second.first] = control_node;
|
||||
kernel_to_call_nodes_[iter->second.first.first] = control_node;
|
||||
MS_LOG(DEBUG) << "Add auto monad control arrow for node:" << AnfAlgo::GetNodeDebugString(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <string>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <utility>
|
||||
#include <unordered_map>
|
||||
#include <algorithm>
|
||||
|
@ -38,6 +39,7 @@ constexpr int kMainBranchID = 0;
|
|||
constexpr int kSubBranchStartID = 1;
|
||||
|
||||
using FrontToBackendNodeWithContext = std::unordered_map<AnfNodePtr, std::pair<AnfNodePtr, DeviceContext *>>;
|
||||
using FrontToBackendKernelWithContext = std::map<KernelWithIndex, std::pair<KernelWithIndex, DeviceContext *>>;
|
||||
using FuncGraphToParameter = std::unordered_map<FuncGraphPtr, std::vector<std::vector<AnfNodePtr>>>;
|
||||
using HostParameterToWeight = std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>>;
|
||||
using NodeWithDeviceContext = std::vector<std::pair<AnfNodePtr, DeviceContext *>>;
|
||||
|
@ -117,6 +119,10 @@ class ControlNodeParser {
|
|||
// Get the backend node corresponding to the weight node in the subgraph.
|
||||
AnfNodePtr FetchBackendNodebyWeightNode(const AnfNodePtr &node);
|
||||
|
||||
KernelWithIndex GetBackendKernelByFrontKernel(const KernelWithIndex &front_node_with_index) {
|
||||
return front_to_backend_kernels_[front_node_with_index].first;
|
||||
}
|
||||
|
||||
private:
|
||||
friend class GraphScheduler;
|
||||
|
||||
|
@ -193,7 +199,7 @@ class ControlNodeParser {
|
|||
std::unordered_map<AnfNodePtr, std::vector<KernelWithIndex>> formal_to_real_parameters_;
|
||||
|
||||
// Relationship between the front and backend of the executable kernel in all kernel graphs.
|
||||
FrontToBackendNodeWithContext front_to_backend_kernels_;
|
||||
FrontToBackendKernelWithContext front_to_backend_kernels_;
|
||||
|
||||
// The funcgraph to parameters map records the input parameters of funcgraph and is used to initialize
|
||||
// the input node of gather.
|
||||
|
|
|
@ -91,6 +91,7 @@ void CreateParameterDeviceAddress(const DeviceContext *device_context, const Ker
|
|||
size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
|
||||
auto device_address = device_context->CreateDeviceAddress(nullptr, tensor_size,
|
||||
AnfAlgo::GetOutputFormat(item, index), output_type_id);
|
||||
MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(item) << " addr:" << device_address;
|
||||
AnfAlgo::SetOutputAddr(device_address, index, item.get());
|
||||
}
|
||||
}
|
||||
|
@ -131,6 +132,7 @@ void CreateDeviceAddressForTensorValue(const DeviceContext *device_context, cons
|
|||
|
||||
device::DeviceAddressPtr address =
|
||||
device_context->CreateDeviceAddress(nullptr, tensor_size, output_format, output_type_id);
|
||||
MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(value_node) << " addr:" << address;
|
||||
MS_EXCEPTION_IF_NULL(address);
|
||||
AnfAlgo::SetOutputAddr(address, output_idx++, value_node.get());
|
||||
}
|
||||
|
@ -154,6 +156,7 @@ void CreateValueNodeDeviceAddress(const DeviceContext *device_context, const Ker
|
|||
size_t tensor_size = value.size();
|
||||
auto address = device_context->CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
|
||||
MS_EXCEPTION_IF_NULL(address);
|
||||
MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(value_node) << " addr:" << address;
|
||||
|
||||
AnfAlgo::SetOutputAddr(address, 0, value_node.get());
|
||||
}
|
||||
|
@ -176,6 +179,7 @@ void CreateKernelOutputDeviceAddress(const DeviceContext *device_context, const
|
|||
std::string output_format = AnfAlgo::GetOutputFormat(kernel, i);
|
||||
auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
|
||||
auto device_address = device_context->CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
|
||||
MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(kernel) << " addr:" << device_address;
|
||||
AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
|
||||
}
|
||||
}
|
||||
|
@ -191,6 +195,7 @@ void CreateKernelWorkspaceDeviceAddress(const DeviceContext *device_context, con
|
|||
auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
|
||||
for (size_t i = 0; i < workspace_sizes.size(); ++i) {
|
||||
auto device_address = device_context->CreateDeviceAddress(nullptr, workspace_sizes[i], "", kTypeUnknown);
|
||||
MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(kernel) << " addr:" << device_address;
|
||||
AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -613,8 +613,8 @@ void GraphScheduler::PrepareRun(const ActorSet *actor_set, const GraphCompilerIn
|
|||
}
|
||||
|
||||
// 3.Prepare the data which belongs to control node.
|
||||
PrepareDataForControlNode(graph_compiler_info.control_node_parser_, graph_compiler_info.origin_parameters_order_,
|
||||
input_tensors.back(), host_data_source_actor->data_node_position_map_, &host_tensors);
|
||||
PrepareDataForControlNode(host_data_source_actor, graph_compiler_info.control_node_parser_,
|
||||
graph_compiler_info.origin_parameters_order_, input_tensors.back(), &host_tensors);
|
||||
|
||||
// 4.Prepare the data of host tensor queue(non weighted parameters of graph).
|
||||
if (host_data_source_actor != nullptr) {
|
||||
|
@ -670,10 +670,10 @@ void GraphScheduler::PrepareRunOp(const ActorSet *actor_set, const GraphCompiler
|
|||
}
|
||||
}
|
||||
|
||||
void GraphScheduler::PrepareDataForControlNode(const ControlNodeParserPtr &control_node_parser,
|
||||
void GraphScheduler::PrepareDataForControlNode(HostQueueDataSourceActor *host_data_source_actor,
|
||||
const ControlNodeParserPtr &control_node_parser,
|
||||
const std::vector<AnfNodePtr> &origin_parameters,
|
||||
const std::vector<TensorPtr> &tensors,
|
||||
const std::unordered_map<AnfNodePtr, size_t> &data_node_position_map,
|
||||
std::vector<TensorPtr> *host_tensors) {
|
||||
const auto &control_node_parameters = control_node_parser->GetControlNodeParameter();
|
||||
|
||||
|
@ -692,7 +692,17 @@ void GraphScheduler::PrepareDataForControlNode(const ControlNodeParserPtr &contr
|
|||
PrepareDataForControlWeightNode(node_with_context.first, input_node, input_tensor, node_with_context.second,
|
||||
control_node_parser->host_parameter_to_weights_);
|
||||
} else if (find(origin_parameters.begin(), origin_parameters.end(), input_node) != origin_parameters.end()) {
|
||||
PrepareDataForHostDataSourceActor(data_node_position_map, input_node, input_tensor, host_tensors);
|
||||
const auto &iter = host_data_source_actor->data_node_position_map_.find(input_node);
|
||||
if (iter == host_data_source_actor->data_node_position_map_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Cannot find node" << AnfAlgo::GetNodeDebugString(input_node) << " in data source actor";
|
||||
}
|
||||
const size_t pos = iter->second;
|
||||
const AnfNodePtr &backend_node = host_data_source_actor->data_nodes_[pos];
|
||||
(*host_tensors)[pos] = input_tensor;
|
||||
auto device_address = std::dynamic_pointer_cast<DeviceTensor>(input_tensor->device_address());
|
||||
if (device_address != nullptr) {
|
||||
AnfAlgo::SetOutputAddr(device_address, 0, backend_node.get());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -989,6 +999,7 @@ std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const Graph
|
|||
host_queue_ds_actor->device_contexts_.emplace_back(backend_iter->second.second);
|
||||
}
|
||||
}
|
||||
|
||||
return data_source_actors;
|
||||
}
|
||||
|
||||
|
@ -1918,7 +1929,7 @@ void GraphScheduler::LinkOutputResultArrowForSwitchActor(const GraphCompilerInfo
|
|||
}
|
||||
|
||||
for (const auto pos : iter->second) {
|
||||
auto op_arrow = std::make_shared<DataArrow>(input_pos[0], to_actor->GetAID(), pos);
|
||||
auto op_arrow = std::make_shared<DataArrow>(0, to_actor->GetAID(), pos);
|
||||
from_actor->output_branch_result_arrows_[i].emplace_back(op_arrow);
|
||||
}
|
||||
|
||||
|
@ -2211,8 +2222,22 @@ void GraphScheduler::LinkDataArrowByControlNode(const GraphCompilerInfo &graph_c
|
|||
} else if (IsKernelActor(input_node, graph_compiler_info.strategy_)) {
|
||||
// The actor input is a cnode.
|
||||
if (front_node_to_actor_.find(input_node) == front_node_to_actor_.end()) {
|
||||
const auto &kernel_with_index = AnfAlgo::VisitKernelWithReturnType(input_node, 0);
|
||||
const auto &backend_node =
|
||||
graph_compiler_info.control_node_parser_->GetBackendKernelByFrontKernel(kernel_with_index);
|
||||
if (backend_node.first == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Cannot find actor:" << to_actor->GetAID()
|
||||
<< " input_node:" << AnfAlgo::GetNodeDebugString(input_node);
|
||||
<< " input_node:" << AnfAlgo::GetNodeDebugString(input_node) << " addr:" << input_node;
|
||||
}
|
||||
const auto &actor_name = backend_node.first->fullname_with_scope();
|
||||
const auto &actor = FetchActor(actor_name);
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto op_arrow = std::make_shared<DataArrow>(backend_node.second, to_actor->GetAID(), to_index);
|
||||
auto from_actor = dynamic_cast<KernelActor *>(actor);
|
||||
from_actor->output_data_arrows_.emplace_back(op_arrow);
|
||||
auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->kernel_, backend_node.second, false);
|
||||
UpdateRefCount(device_tensor.get(), true);
|
||||
return;
|
||||
}
|
||||
|
||||
auto op_arrow = std::make_shared<DataArrow>(input_with_index.second, to_actor->GetAID(), to_index);
|
||||
|
|
|
@ -255,11 +255,10 @@ class GraphScheduler {
|
|||
void LinkBranchArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
|
||||
void LinkOutputResultArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
|
||||
void LinkOutputResultArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
|
||||
void PrepareDataForControlNode(const ControlNodeParserPtr &control_node_parser,
|
||||
void PrepareDataForControlNode(HostQueueDataSourceActor *host_data_source_actor,
|
||||
const ControlNodeParserPtr &control_node_parser,
|
||||
const std::vector<AnfNodePtr> &origin_parameters,
|
||||
const std::vector<TensorPtr> &tensors,
|
||||
const std::unordered_map<AnfNodePtr, size_t> &data_node_position_map,
|
||||
std::vector<TensorPtr> *host_tensors);
|
||||
const std::vector<TensorPtr> &tensors, std::vector<TensorPtr> *host_tensors);
|
||||
|
||||
// The processing of actors link dynamically.
|
||||
// Analyze necessary input data of current actor, generate and cache op arrow
|
||||
|
|
Loading…
Reference in New Issue