Fix ocean and sponge net.

This commit is contained in:
gaoyong10 2021-07-03 12:04:13 +08:00
parent 0022d07d6e
commit 03021379d4
8 changed files with 117 additions and 30 deletions

View File

@ -339,6 +339,7 @@ class KernelGraph : public FuncGraph {
bool is_all_nop_node() const { return is_all_nop_node_; }
void set_is_all_nop_node(bool is_all_nop_node) { is_all_nop_node_ = is_all_nop_node; }
std::map<AnfWithOutIndex, AnfWithOutIndex> graph_output_map() { return graph_output_to_front_node_map_; }
private:
// remove value node from graph

View File

@ -114,7 +114,7 @@ void GatherActor::SendOutput(OpContext<DeviceTensor> *context) const {
size_t from_index = result_arrow->from_output_index_;
const auto &front_node = data_nodes_[from_index];
for (const auto &backend_node : front_to_backend_parameter_.at(front_node)) {
if (AnfAlgo::GetMutableOutputAddr(backend_node.first, backend_node.second).get() ==
if (AnfAlgo::GetMutableOutputAddr(backend_node.first, backend_node.second, false).get() ==
input_device_tensors_[from_index]) {
Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, backend_node.first, backend_node.second,
result_arrow->to_input_index_, context);

View File

@ -397,16 +397,34 @@ void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
for (size_t i = 0; i < output_branch_result_arrow.size(); ++i) {
auto &result_arrow = output_branch_result_arrow[i];
MS_EXCEPTION_IF_NULL(result_arrow);
size_t from_index = result_arrow->from_output_index_;
if (result_arrow->from_output_index_ >= SizeToInt(branch_inputs_pos_[index].size())) {
MS_LOG(EXCEPTION) << "Invalid from index in switch actor, from index:" << result_arrow->from_output_index_
<< " total:" << branch_inputs_pos_[index].size() << " actor:" << GetAID();
}
size_t from_index = branch_inputs_pos_[index][result_arrow->from_output_index_];
MS_LOG(DEBUG) << "Switch actor:" << GetAID() << " send result addr:" << input_device_tensors_[from_index];
bool is_send = false;
for (const auto &backend_node : backend_parameters_[from_index]) {
if (AnfAlgo::GetMutableOutputAddr(backend_node.first, backend_node.second).get() ==
input_device_tensors_[from_index]) {
Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, backend_node.first, backend_node.second,
for (size_t j = 0; j < AnfAlgo::GetOutputTensorNum(backend_node.first); ++j) {
if (AnfAlgo::OutputAddrExist(backend_node.first, j, false) &&
AnfAlgo::GetMutableOutputAddr(backend_node.first, j, false).get() == input_device_tensors_[from_index]) {
Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, backend_node.first, j,
result_arrow->to_input_index_, context);
is_send = true;
MS_LOG(DEBUG) << "Switch actor:" << GetAID() << " send result addr:" << input_device_tensors_[from_index]
<< " succeed";
break;
}
}
}
if (!is_send) {
MS_LOG(EXCEPTION) << "Failed to get backend node of switch actor output, actor:" << GetAID()
<< " branch:" << index << " index:" << result_arrow->from_output_index_ << " output pos"
<< branch_inputs_pos_[index][result_arrow->from_output_index_] << " output index"
<< result_arrow->to_input_index_;
}
}
// Send output control.
auto source_aid = const_cast<AID *>(&GetAID());

View File

@ -166,6 +166,7 @@ void CreateDeviceTensorForValueNode(const AnfNodePtr &front_node, const AnfNodeP
device::DeviceAddressPtr address =
device_context->CreateDeviceAddress(nullptr, tensor_size, output_format, output_type_id);
MS_EXCEPTION_IF_NULL(address);
MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(front_node) << " addr:" << address;
AnfAlgo::SetOutputAddr(address, 0, front_node.get());
}
@ -190,6 +191,7 @@ void CreateDeviceTensorForFrontParameter(const AnfNodePtr &node, const DeviceCon
// Create device tensor.
device::DeviceAddressPtr address = device_context->CreateDeviceAddress(nullptr, size, kOpFormat_DEFAULT, type_id);
MS_EXCEPTION_IF_NULL(address);
MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(node) << " addr:" << address;
AnfAlgo::SetOutputAddr(address, 0, node.get());
}
@ -1217,11 +1219,25 @@ void ControlNodeParser::FetchFrontToBackendKernel(const std::vector<KernelGraphP
if (IsKernelActor(kernel) && (!IsSkippedKernelActor(kernel))) {
auto front_node = graph->GetFrontAnfByBackendAnf(kernel);
if (front_node != nullptr) {
front_to_backend_kernels_[front_node] = {kernel, device_context};
for (size_t j = 0; j < AnfAlgo::GetOutputTensorNum(kernel); ++j) {
front_to_backend_kernels_[{front_node, j}] = {{kernel, j}, device_context};
MS_LOG(DEBUG) << "Add front to backend kernel, front:" << AnfAlgo::GetNodeDebugString(front_node)
<< "index:" << j << " addr:" << front_node
<< " second:" << AnfAlgo::GetNodeDebugString(kernel) << "index:" << j << " addr:" << kernel;
}
}
}
}
const auto graph_output_map = graph->graph_output_map();
for (const auto &output_pair : graph_output_map) {
front_to_backend_kernels_[output_pair.second] = {output_pair.first, device_context};
MS_LOG(DEBUG) << "Add front to backend kernel, front:" << AnfAlgo::GetNodeDebugString(output_pair.second.first)
<< "index:" << output_pair.second.second << " addr:" << output_pair.second.first
<< " second:" << AnfAlgo::GetNodeDebugString(output_pair.first.first)
<< "index:" << output_pair.first.second << " addr:" << output_pair.first.first;
}
}
}
void ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_output,
@ -1230,6 +1246,12 @@ void ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_
std::set<KernelWithIndex> *results) {
if (front_output->isa<ValueNode>()) {
(*results).insert({front_output, 0});
const auto &iter = formal_to_real_parameters_.find(front_output);
if (iter != formal_to_real_parameters_.end()) {
for (const auto &node : iter->second) {
(*results).insert(node);
}
}
} else if (front_output->isa<Parameter>()) {
// Output is a parameter.
const auto iter = formal_to_real_parameters_.find(front_output);
@ -1265,11 +1287,10 @@ void ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_
}
} else if (front_output->isa<CNode>()) {
// Output is a kernel.
const auto iter = front_to_backend_kernels_.find(front_output);
const auto iter = front_to_backend_kernels_.find(AnfAlgo::VisitKernelWithReturnType(front_output, 0));
if (iter != front_to_backend_kernels_.end()) {
const auto &output_with_index = AnfAlgo::VisitKernelWithReturnType(iter->second.first, 0);
(*results).insert(output_with_index);
(*results).insert(iter->second.first);
} else {
MS_LOG(EXCEPTION) << "Cannot find backend node for front kernel:" << AnfAlgo::GetNodeDebugString(front_output);
}
@ -1298,11 +1319,11 @@ void ControlNodeParser::FetchBackendInputNodebyFrontNode(
}
formal_to_real_parameters_[formal_parameter].push_back({iter->second.first, 0});
} else {
const auto iter = front_to_backend_kernels_.find(node_with_index.first);
const auto iter = front_to_backend_kernels_.find(node_with_index);
if (iter == front_to_backend_kernels_.end()) {
MS_LOG(EXCEPTION) << "Cannot find actor of front node:" << AnfAlgo::GetNodeDebugString(node_with_index.first);
}
formal_to_real_parameters_[formal_parameter].push_back({iter->second.first, node_with_index.second});
formal_to_real_parameters_[formal_parameter].emplace_back(iter->second.first);
}
}
} else if (real_parameter->isa<ValueNode>()) {
@ -1315,11 +1336,11 @@ void ControlNodeParser::FetchBackendInputNodebyFrontNode(
} else {
// Input node is a cnode.
const auto node_with_index = AnfAlgo::VisitKernelWithReturnType(real_parameter, 0);
const auto iter = front_to_backend_kernels_.find(node_with_index.first);
const auto iter = front_to_backend_kernels_.find(node_with_index);
if (iter == front_to_backend_kernels_.end()) {
MS_LOG(EXCEPTION) << "Cannot find backend node of node:" << AnfAlgo::GetNodeDebugString(node_with_index.first);
}
formal_to_real_parameters_[formal_parameter].push_back({iter->second.first, node_with_index.second});
formal_to_real_parameters_[formal_parameter].emplace_back(iter->second.first);
}
}
@ -1373,6 +1394,17 @@ void ControlNodeParser::FetchBackendInputNode(const std::vector<KernelGraphPtr>
FetchBackendParameterNode(graphs, device_contexts, real_to_formal_front_parameters, formal_to_real_front_parameters,
&front_to_backend_parameters);
for (size_t i = 0; i < graphs.size(); ++i) {
const auto &graph = graphs[i];
for (const auto &value_node : graph->graph_value_nodes()) {
auto front_node = graph->GetFrontAnfByBackendAnf(value_node);
if (front_node != nullptr) {
formal_to_real_parameters_[front_node].push_back({value_node, 0});
}
}
}
for (const auto &func_graph_to_parameters : func_graph_to_parameters_) {
const auto &func_graph = func_graph_to_parameters.first;
std::vector<AnfNodePtr> graph_inputs;
@ -1418,9 +1450,10 @@ void ControlNodeParser::FetchAutoMonadNode(const std::vector<AnfNodePtr> &contro
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
if (AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimUpdateState)) {
const auto &node = FetchSourceNodeByAutoMonad(inputs[i]);
const auto &iter = front_to_backend_kernels_.find(node);
const auto &iter = front_to_backend_kernels_.find(AnfAlgo::VisitKernelWithReturnType(node, 0));
if (iter != front_to_backend_kernels_.end()) {
kernel_to_call_nodes_[iter->second.first] = control_node;
kernel_to_call_nodes_[iter->second.first.first] = control_node;
MS_LOG(DEBUG) << "Add auto monad control arrow for node:" << AnfAlgo::GetNodeDebugString(node);
}
}
}

View File

@ -21,6 +21,7 @@
#include <string>
#include <memory>
#include <set>
#include <map>
#include <utility>
#include <unordered_map>
#include <algorithm>
@ -38,6 +39,7 @@ constexpr int kMainBranchID = 0;
constexpr int kSubBranchStartID = 1;
using FrontToBackendNodeWithContext = std::unordered_map<AnfNodePtr, std::pair<AnfNodePtr, DeviceContext *>>;
using FrontToBackendKernelWithContext = std::map<KernelWithIndex, std::pair<KernelWithIndex, DeviceContext *>>;
using FuncGraphToParameter = std::unordered_map<FuncGraphPtr, std::vector<std::vector<AnfNodePtr>>>;
using HostParameterToWeight = std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>>;
using NodeWithDeviceContext = std::vector<std::pair<AnfNodePtr, DeviceContext *>>;
@ -117,6 +119,10 @@ class ControlNodeParser {
// Get the backend node corresponding to the weight node in the subgraph.
AnfNodePtr FetchBackendNodebyWeightNode(const AnfNodePtr &node);
// Look up the backend kernel (with output index) recorded for a front kernel output.
// Returns a value-initialized KernelWithIndex (null node) when no mapping exists, so callers
// must check `.first` against nullptr before using the result.
// The lookup uses find() rather than operator[] so that a miss does not insert an empty entry
// into front_to_backend_kernels_; such an entry would make later find() calls on the map succeed
// with a null backend node.
KernelWithIndex GetBackendKernelByFrontKernel(const KernelWithIndex &front_node_with_index) const {
  const auto iter = front_to_backend_kernels_.find(front_node_with_index);
  if (iter == front_to_backend_kernels_.end()) {
    return KernelWithIndex();
  }
  return iter->second.first;
}
private:
friend class GraphScheduler;
@ -193,7 +199,7 @@ class ControlNodeParser {
std::unordered_map<AnfNodePtr, std::vector<KernelWithIndex>> formal_to_real_parameters_;
// Relationship between the front and backend of the executable kernel in all kernel graphs.
FrontToBackendNodeWithContext front_to_backend_kernels_;
FrontToBackendKernelWithContext front_to_backend_kernels_;
// The funcgraph to parameters map records the input parameters of funcgraph and is used to initialize
// the input node of gather.

View File

@ -91,6 +91,7 @@ void CreateParameterDeviceAddress(const DeviceContext *device_context, const Ker
size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
auto device_address = device_context->CreateDeviceAddress(nullptr, tensor_size,
AnfAlgo::GetOutputFormat(item, index), output_type_id);
MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(item) << " addr:" << device_address;
AnfAlgo::SetOutputAddr(device_address, index, item.get());
}
}
@ -131,6 +132,7 @@ void CreateDeviceAddressForTensorValue(const DeviceContext *device_context, cons
device::DeviceAddressPtr address =
device_context->CreateDeviceAddress(nullptr, tensor_size, output_format, output_type_id);
MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(value_node) << " addr:" << address;
MS_EXCEPTION_IF_NULL(address);
AnfAlgo::SetOutputAddr(address, output_idx++, value_node.get());
}
@ -154,6 +156,7 @@ void CreateValueNodeDeviceAddress(const DeviceContext *device_context, const Ker
size_t tensor_size = value.size();
auto address = device_context->CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
MS_EXCEPTION_IF_NULL(address);
MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(value_node) << " addr:" << address;
AnfAlgo::SetOutputAddr(address, 0, value_node.get());
}
@ -176,6 +179,7 @@ void CreateKernelOutputDeviceAddress(const DeviceContext *device_context, const
std::string output_format = AnfAlgo::GetOutputFormat(kernel, i);
auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
auto device_address = device_context->CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(kernel) << " addr:" << device_address;
AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
}
}
@ -191,6 +195,7 @@ void CreateKernelWorkspaceDeviceAddress(const DeviceContext *device_context, con
auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
for (size_t i = 0; i < workspace_sizes.size(); ++i) {
auto device_address = device_context->CreateDeviceAddress(nullptr, workspace_sizes[i], "", kTypeUnknown);
MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(kernel) << " addr:" << device_address;
AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get());
}
}

View File

@ -613,8 +613,8 @@ void GraphScheduler::PrepareRun(const ActorSet *actor_set, const GraphCompilerIn
}
// 3.Prepare the data which belongs to control node.
PrepareDataForControlNode(graph_compiler_info.control_node_parser_, graph_compiler_info.origin_parameters_order_,
input_tensors.back(), host_data_source_actor->data_node_position_map_, &host_tensors);
PrepareDataForControlNode(host_data_source_actor, graph_compiler_info.control_node_parser_,
graph_compiler_info.origin_parameters_order_, input_tensors.back(), &host_tensors);
// 4.Prepare the data of host tensor queue(non weighted parameters of graph).
if (host_data_source_actor != nullptr) {
@ -670,10 +670,10 @@ void GraphScheduler::PrepareRunOp(const ActorSet *actor_set, const GraphCompiler
}
}
void GraphScheduler::PrepareDataForControlNode(const ControlNodeParserPtr &control_node_parser,
void GraphScheduler::PrepareDataForControlNode(HostQueueDataSourceActor *host_data_source_actor,
const ControlNodeParserPtr &control_node_parser,
const std::vector<AnfNodePtr> &origin_parameters,
const std::vector<TensorPtr> &tensors,
const std::unordered_map<AnfNodePtr, size_t> &data_node_position_map,
std::vector<TensorPtr> *host_tensors) {
const auto &control_node_parameters = control_node_parser->GetControlNodeParameter();
@ -692,7 +692,17 @@ void GraphScheduler::PrepareDataForControlNode(const ControlNodeParserPtr &contr
PrepareDataForControlWeightNode(node_with_context.first, input_node, input_tensor, node_with_context.second,
control_node_parser->host_parameter_to_weights_);
} else if (find(origin_parameters.begin(), origin_parameters.end(), input_node) != origin_parameters.end()) {
PrepareDataForHostDataSourceActor(data_node_position_map, input_node, input_tensor, host_tensors);
const auto &iter = host_data_source_actor->data_node_position_map_.find(input_node);
if (iter == host_data_source_actor->data_node_position_map_.end()) {
MS_LOG(EXCEPTION) << "Cannot find node" << AnfAlgo::GetNodeDebugString(input_node) << " in data source actor";
}
const size_t pos = iter->second;
const AnfNodePtr &backend_node = host_data_source_actor->data_nodes_[pos];
(*host_tensors)[pos] = input_tensor;
auto device_address = std::dynamic_pointer_cast<DeviceTensor>(input_tensor->device_address());
if (device_address != nullptr) {
AnfAlgo::SetOutputAddr(device_address, 0, backend_node.get());
}
}
}
@ -989,6 +999,7 @@ std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const Graph
host_queue_ds_actor->device_contexts_.emplace_back(backend_iter->second.second);
}
}
return data_source_actors;
}
@ -1918,7 +1929,7 @@ void GraphScheduler::LinkOutputResultArrowForSwitchActor(const GraphCompilerInfo
}
for (const auto pos : iter->second) {
auto op_arrow = std::make_shared<DataArrow>(input_pos[0], to_actor->GetAID(), pos);
auto op_arrow = std::make_shared<DataArrow>(0, to_actor->GetAID(), pos);
from_actor->output_branch_result_arrows_[i].emplace_back(op_arrow);
}
@ -2211,8 +2222,22 @@ void GraphScheduler::LinkDataArrowByControlNode(const GraphCompilerInfo &graph_c
} else if (IsKernelActor(input_node, graph_compiler_info.strategy_)) {
// The actor input is a cnode.
if (front_node_to_actor_.find(input_node) == front_node_to_actor_.end()) {
const auto &kernel_with_index = AnfAlgo::VisitKernelWithReturnType(input_node, 0);
const auto &backend_node =
graph_compiler_info.control_node_parser_->GetBackendKernelByFrontKernel(kernel_with_index);
if (backend_node.first == nullptr) {
MS_LOG(EXCEPTION) << "Cannot find actor:" << to_actor->GetAID()
<< " input_node:" << AnfAlgo::GetNodeDebugString(input_node);
<< " input_node:" << AnfAlgo::GetNodeDebugString(input_node) << " addr:" << input_node;
}
const auto &actor_name = backend_node.first->fullname_with_scope();
const auto &actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto op_arrow = std::make_shared<DataArrow>(backend_node.second, to_actor->GetAID(), to_index);
auto from_actor = dynamic_cast<KernelActor *>(actor);
from_actor->output_data_arrows_.emplace_back(op_arrow);
auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->kernel_, backend_node.second, false);
UpdateRefCount(device_tensor.get(), true);
return;
}
auto op_arrow = std::make_shared<DataArrow>(input_with_index.second, to_actor->GetAID(), to_index);

View File

@ -255,11 +255,10 @@ class GraphScheduler {
void LinkBranchArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
void LinkOutputResultArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
void LinkOutputResultArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
void PrepareDataForControlNode(const ControlNodeParserPtr &control_node_parser,
void PrepareDataForControlNode(HostQueueDataSourceActor *host_data_source_actor,
const ControlNodeParserPtr &control_node_parser,
const std::vector<AnfNodePtr> &origin_parameters,
const std::vector<TensorPtr> &tensors,
const std::unordered_map<AnfNodePtr, size_t> &data_node_position_map,
std::vector<TensorPtr> *host_tensors);
const std::vector<TensorPtr> &tensors, std::vector<TensorPtr> *host_tensors);
// The processing of actors link dynamically.
// Analyze necessary input data of current actor, generate and cache op arrow