!19184 Fix weight input of gather actor.

Merge pull request !19184 from gaoyong10/new_runtime18
This commit is contained in:
i-robot 2021-07-01 12:00:12 +00:00 committed by Gitee
commit 0ac3cd3aef
2 changed files with 89 additions and 64 deletions

View File

@ -410,6 +410,46 @@ AnfNodePtr FetchSourceNodeByAutoMonad(const AnfNodePtr &node) {
}
return node;
}
// Fetch all parameters in control node of root funcgraph.
std::vector<AnfNodePtr> FetchParameterByControlNode(const std::vector<AnfNodePtr> &control_nodes) {
std::vector<AnfNodePtr> parameters;
for (const auto &control_node : control_nodes) {
CNodePtr cnode = control_node->cast<CNodePtr>();
const auto &inputs = cnode->inputs();
if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
break;
} else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimPartial)) {
for (size_t i = kPartialInputStartPos; i < inputs.size(); ++i) {
if (inputs[i]->isa<Parameter>()) {
parameters.emplace_back(inputs[i]);
}
}
} else if (cnode->input(0)->isa<CNode>() || IsValueNode<FuncGraph>(cnode->input(0))) {
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
if (inputs[i]->isa<Parameter>()) {
parameters.emplace_back(inputs[i]);
}
}
} else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch)) {
if (inputs.size() != kSwitchInputNum) {
MS_LOG(EXCEPTION) << "Invalid switch node:" << AnfAlgo::GetNodeDebugString(control_node);
}
if (inputs[kSwitchCondPos]->isa<Parameter>()) {
parameters.emplace_back(inputs[kSwitchCondPos]);
}
} else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
if (inputs.size() != kSwitchLayerInputNum) {
MS_LOG(EXCEPTION) << "Invalid switch node:" << AnfAlgo::GetNodeDebugString(control_node);
}
if (inputs[kSwitchLayerCondPos]->isa<Parameter>()) {
parameters.emplace_back(inputs[kSwitchLayerCondPos]);
}
}
}
return parameters;
}
} // namespace
// Return true if the node has Ref abstract.
@ -430,20 +470,25 @@ bool IsCallNode(const AnfNodePtr &node) {
return inputs[0]->isa<CNode>() || (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0]));
}
std::vector<AnfNodePtr> FetchInputsByMakeTuple(const AnfNodePtr &node) {
std::vector<AnfNodePtr> FetchAllRealInputNodeByParameter(const AnfNodePtr &node) {
std::vector<AnfNodePtr> parameters;
if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
const auto &parameter = AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem}).first;
parameters.emplace_back(parameter);
return parameters;
}
const auto real_node = AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem}).first;
const auto &inputs = node->cast<CNodePtr>()->inputs();
if (real_node->isa<Parameter>()) {
if (!HasAbstractRef(real_node) && !HasAbstractMonad(real_node)) {
parameters.emplace_back(real_node);
}
} else if (HasAbstractMonad(real_node)) {
return parameters;
} else if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimMakeTuple)) {
const auto &inputs = real_node->cast<CNodePtr>()->inputs();
for (size_t i = kMakeTupleInputStartPos; i < inputs.size(); ++i) {
const auto &sub_parameters = FetchInputsByMakeTuple(inputs[i]);
const auto &sub_parameters = FetchAllRealInputNodeByParameter(inputs[i]);
parameters.insert(parameters.end(), sub_parameters.begin(), sub_parameters.end());
}
} else {
parameters.emplace_back(real_node);
}
return parameters;
}
@ -622,12 +667,12 @@ void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, cons
FetchFrontToBackendKernel(graphs, device_contexts);
FetchCallInputKernelGraph(graphs, device_contexts);
control_node_parameters_ = FetchControlNodeParameter(control_nodes, device_contexts[0]);
FetchFuncGraphCallNum(control_nodes);
FetchCallInputKernelGraph(graphs, device_contexts);
FetchBackendInputNode(graphs, device_contexts, real_to_formal_front_parameters, formal_to_real_front_parameters);
FetchAutoMonadNode(control_nodes);
@ -907,38 +952,24 @@ void ControlNodeParser::FetchFrontToFrontParameter(
std::vector<AnfNodePtr> ControlNodeParser::FetchControlNodeParameter(const std::vector<AnfNodePtr> &control_nodes,
DeviceContext *device_context) {
std::vector<AnfNodePtr> parameters;
std::vector<AnfNodePtr> parameters = FetchParameterByControlNode(control_nodes);
for (const auto &control_node : control_nodes) {
CNodePtr cnode = control_node->cast<CNodePtr>();
const auto &inputs = cnode->inputs();
if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
break;
} else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimPartial)) {
for (size_t i = kPartialInputStartPos; i < inputs.size(); ++i) {
if (inputs[i]->isa<Parameter>()) {
parameters.emplace_back(inputs[i]);
for (const auto &graph_with_device_context : call_input_kernel_graphs_) {
const auto &graph = graph_with_device_context.first;
const auto &func_graph = graph->GetFuncGraph();
if (func_graph == nullptr) {
MS_LOG(WARNING) << "Cannot get funcgraph by kernel graph:" << graph->ToString();
continue;
}
if (func_graph != root_func_graph_) {
continue;
}
} else if (cnode->input(0)->isa<CNode>() || IsValueNode<FuncGraph>(cnode->input(0))) {
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
if (inputs[i]->isa<Parameter>()) {
parameters.emplace_back(inputs[i]);
}
}
} else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch)) {
if (inputs.size() != kSwitchInputNum) {
MS_LOG(EXCEPTION) << "Invalid switch node:" << AnfAlgo::GetNodeDebugString(control_node);
}
if (inputs[kSwitchCondPos]->isa<Parameter>()) {
parameters.emplace_back(inputs[kSwitchCondPos]);
}
} else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
if (inputs.size() != kSwitchLayerInputNum) {
MS_LOG(EXCEPTION) << "Invalid switch node:" << AnfAlgo::GetNodeDebugString(control_node);
}
if (inputs[kSwitchLayerCondPos]->isa<Parameter>()) {
parameters.emplace_back(inputs[kSwitchLayerCondPos]);
const auto &inputs = graph->input_nodes();
for (const auto &input : inputs) {
const auto &front_node = graph->GetFrontAnfByBackendAnf(input);
if (front_node != nullptr && front_node->isa<Parameter>() && (!HasAbstractRef(front_node))) {
parameters.emplace_back(front_node);
}
}
}
@ -1056,31 +1087,19 @@ std::vector<AnfNodePtr> FetchParameterbyKernelGraph(const KernelGraphPtr &graph)
const auto &graph_parameters = graph->input_nodes();
for (const auto &graph_parameter : graph_parameters) {
const auto &front_node = graph->GetFrontAnfByBackendAnf(graph_parameter);
if (front_node != nullptr) {
parameters.emplace_back(front_node);
continue;
}
const auto &front_node_with_index = graph->GetFrontNodeByInternalParameter(graph_parameter);
if (front_node_with_index.first == nullptr) {
const auto &external_front_node = graph->GetFrontAnfByBackendAnf(graph_parameter);
const auto &internal_front_node = graph->GetFrontNodeByInternalParameter(graph_parameter).first;
if (external_front_node == nullptr && internal_front_node == nullptr) {
MS_LOG(WARNING) << "Invalid parameter of kernel graph, parameter :"
<< AnfAlgo::GetNodeDebugString(graph_parameter);
continue;
}
const auto real_front_node = AnfAlgo::VisitKernelWithReturnType(front_node_with_index.first, 0).first;
if ((real_front_node->isa<Parameter>() && HasAbstractRef(real_front_node)) ||
HasAbstractMonad(front_node_with_index.first)) {
continue;
}
if (AnfAlgo::CheckPrimitiveType(front_node_with_index.first, prim::kPrimMakeTuple)) {
const auto &sub_parameters = FetchInputsByMakeTuple(front_node_with_index.first);
const auto &front_node = (external_front_node != nullptr) ? external_front_node : internal_front_node;
const auto real_front_node = AnfAlgo::VisitKernelWithReturnType(front_node, 0).first;
const auto &sub_parameters = FetchAllRealInputNodeByParameter(real_front_node);
parameters.insert(parameters.end(), sub_parameters.begin(), sub_parameters.end());
continue;
}
parameters.emplace_back(front_node_with_index.first);
}
return parameters;

View File

@ -1152,6 +1152,12 @@ void GraphScheduler::LinkDataArrow(KernelActor *to_actor, const GraphCompilerInf
auto front_node = GetFrontNodeByBackendNode(from_kernel);
if (from_kernel->isa<Parameter>() && graph_compiler_info.control_node_parser_->IsCallInputKernelGraph(graph)) {
if (HasAbstractRef(from_kernel)) {
const auto devcie_tensor_store_key = FetchFrontNodeByBackendNode(from_kernel, graph);
to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, devcie_tensor_store_key.get());
return;
}
// When there is a call input in the kernel graph, all the inputs of the kernel graph needs to be sent by gather.
const auto actor_name = graph->ToString();
auto actor = FetchActor(actor_name);