!19184 Fix weight input of gather actor.
Merge pull request !19184 from gaoyong10/new_runtime18
This commit is contained in:
commit
0ac3cd3aef
|
@ -410,6 +410,46 @@ AnfNodePtr FetchSourceNodeByAutoMonad(const AnfNodePtr &node) {
|
|||
}
|
||||
return node;
|
||||
}
|
||||
|
||||
// Fetch all parameters in control node of root funcgraph.
|
||||
std::vector<AnfNodePtr> FetchParameterByControlNode(const std::vector<AnfNodePtr> &control_nodes) {
|
||||
std::vector<AnfNodePtr> parameters;
|
||||
|
||||
for (const auto &control_node : control_nodes) {
|
||||
CNodePtr cnode = control_node->cast<CNodePtr>();
|
||||
const auto &inputs = cnode->inputs();
|
||||
if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
|
||||
break;
|
||||
} else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimPartial)) {
|
||||
for (size_t i = kPartialInputStartPos; i < inputs.size(); ++i) {
|
||||
if (inputs[i]->isa<Parameter>()) {
|
||||
parameters.emplace_back(inputs[i]);
|
||||
}
|
||||
}
|
||||
} else if (cnode->input(0)->isa<CNode>() || IsValueNode<FuncGraph>(cnode->input(0))) {
|
||||
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
|
||||
if (inputs[i]->isa<Parameter>()) {
|
||||
parameters.emplace_back(inputs[i]);
|
||||
}
|
||||
}
|
||||
} else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch)) {
|
||||
if (inputs.size() != kSwitchInputNum) {
|
||||
MS_LOG(EXCEPTION) << "Invalid switch node:" << AnfAlgo::GetNodeDebugString(control_node);
|
||||
}
|
||||
if (inputs[kSwitchCondPos]->isa<Parameter>()) {
|
||||
parameters.emplace_back(inputs[kSwitchCondPos]);
|
||||
}
|
||||
} else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
|
||||
if (inputs.size() != kSwitchLayerInputNum) {
|
||||
MS_LOG(EXCEPTION) << "Invalid switch node:" << AnfAlgo::GetNodeDebugString(control_node);
|
||||
}
|
||||
if (inputs[kSwitchLayerCondPos]->isa<Parameter>()) {
|
||||
parameters.emplace_back(inputs[kSwitchLayerCondPos]);
|
||||
}
|
||||
}
|
||||
}
|
||||
return parameters;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Return true if the node has Ref abstract.
|
||||
|
@ -430,20 +470,25 @@ bool IsCallNode(const AnfNodePtr &node) {
|
|||
return inputs[0]->isa<CNode>() || (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0]));
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> FetchInputsByMakeTuple(const AnfNodePtr &node) {
|
||||
std::vector<AnfNodePtr> FetchAllRealInputNodeByParameter(const AnfNodePtr &node) {
|
||||
std::vector<AnfNodePtr> parameters;
|
||||
if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
|
||||
const auto ¶meter = AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem}).first;
|
||||
parameters.emplace_back(parameter);
|
||||
return parameters;
|
||||
}
|
||||
const auto real_node = AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem}).first;
|
||||
|
||||
const auto &inputs = node->cast<CNodePtr>()->inputs();
|
||||
if (real_node->isa<Parameter>()) {
|
||||
if (!HasAbstractRef(real_node) && !HasAbstractMonad(real_node)) {
|
||||
parameters.emplace_back(real_node);
|
||||
}
|
||||
} else if (HasAbstractMonad(real_node)) {
|
||||
return parameters;
|
||||
} else if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimMakeTuple)) {
|
||||
const auto &inputs = real_node->cast<CNodePtr>()->inputs();
|
||||
for (size_t i = kMakeTupleInputStartPos; i < inputs.size(); ++i) {
|
||||
const auto &sub_parameters = FetchInputsByMakeTuple(inputs[i]);
|
||||
const auto &sub_parameters = FetchAllRealInputNodeByParameter(inputs[i]);
|
||||
parameters.insert(parameters.end(), sub_parameters.begin(), sub_parameters.end());
|
||||
}
|
||||
|
||||
} else {
|
||||
parameters.emplace_back(real_node);
|
||||
}
|
||||
return parameters;
|
||||
}
|
||||
|
||||
|
@ -622,12 +667,12 @@ void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, cons
|
|||
|
||||
FetchFrontToBackendKernel(graphs, device_contexts);
|
||||
|
||||
FetchCallInputKernelGraph(graphs, device_contexts);
|
||||
|
||||
control_node_parameters_ = FetchControlNodeParameter(control_nodes, device_contexts[0]);
|
||||
|
||||
FetchFuncGraphCallNum(control_nodes);
|
||||
|
||||
FetchCallInputKernelGraph(graphs, device_contexts);
|
||||
|
||||
FetchBackendInputNode(graphs, device_contexts, real_to_formal_front_parameters, formal_to_real_front_parameters);
|
||||
|
||||
FetchAutoMonadNode(control_nodes);
|
||||
|
@ -907,38 +952,24 @@ void ControlNodeParser::FetchFrontToFrontParameter(
|
|||
|
||||
std::vector<AnfNodePtr> ControlNodeParser::FetchControlNodeParameter(const std::vector<AnfNodePtr> &control_nodes,
|
||||
DeviceContext *device_context) {
|
||||
std::vector<AnfNodePtr> parameters;
|
||||
std::vector<AnfNodePtr> parameters = FetchParameterByControlNode(control_nodes);
|
||||
|
||||
for (const auto &control_node : control_nodes) {
|
||||
CNodePtr cnode = control_node->cast<CNodePtr>();
|
||||
const auto &inputs = cnode->inputs();
|
||||
if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
|
||||
break;
|
||||
} else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimPartial)) {
|
||||
for (size_t i = kPartialInputStartPos; i < inputs.size(); ++i) {
|
||||
if (inputs[i]->isa<Parameter>()) {
|
||||
parameters.emplace_back(inputs[i]);
|
||||
for (const auto &graph_with_device_context : call_input_kernel_graphs_) {
|
||||
const auto &graph = graph_with_device_context.first;
|
||||
const auto &func_graph = graph->GetFuncGraph();
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(WARNING) << "Cannot get funcgraph by kernel graph:" << graph->ToString();
|
||||
continue;
|
||||
}
|
||||
if (func_graph != root_func_graph_) {
|
||||
continue;
|
||||
}
|
||||
} else if (cnode->input(0)->isa<CNode>() || IsValueNode<FuncGraph>(cnode->input(0))) {
|
||||
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
|
||||
if (inputs[i]->isa<Parameter>()) {
|
||||
parameters.emplace_back(inputs[i]);
|
||||
}
|
||||
}
|
||||
} else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch)) {
|
||||
if (inputs.size() != kSwitchInputNum) {
|
||||
MS_LOG(EXCEPTION) << "Invalid switch node:" << AnfAlgo::GetNodeDebugString(control_node);
|
||||
}
|
||||
if (inputs[kSwitchCondPos]->isa<Parameter>()) {
|
||||
parameters.emplace_back(inputs[kSwitchCondPos]);
|
||||
}
|
||||
} else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
|
||||
if (inputs.size() != kSwitchLayerInputNum) {
|
||||
MS_LOG(EXCEPTION) << "Invalid switch node:" << AnfAlgo::GetNodeDebugString(control_node);
|
||||
}
|
||||
if (inputs[kSwitchLayerCondPos]->isa<Parameter>()) {
|
||||
parameters.emplace_back(inputs[kSwitchLayerCondPos]);
|
||||
|
||||
const auto &inputs = graph->input_nodes();
|
||||
for (const auto &input : inputs) {
|
||||
const auto &front_node = graph->GetFrontAnfByBackendAnf(input);
|
||||
if (front_node != nullptr && front_node->isa<Parameter>() && (!HasAbstractRef(front_node))) {
|
||||
parameters.emplace_back(front_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1056,31 +1087,19 @@ std::vector<AnfNodePtr> FetchParameterbyKernelGraph(const KernelGraphPtr &graph)
|
|||
const auto &graph_parameters = graph->input_nodes();
|
||||
|
||||
for (const auto &graph_parameter : graph_parameters) {
|
||||
const auto &front_node = graph->GetFrontAnfByBackendAnf(graph_parameter);
|
||||
if (front_node != nullptr) {
|
||||
parameters.emplace_back(front_node);
|
||||
continue;
|
||||
}
|
||||
const auto &front_node_with_index = graph->GetFrontNodeByInternalParameter(graph_parameter);
|
||||
if (front_node_with_index.first == nullptr) {
|
||||
const auto &external_front_node = graph->GetFrontAnfByBackendAnf(graph_parameter);
|
||||
const auto &internal_front_node = graph->GetFrontNodeByInternalParameter(graph_parameter).first;
|
||||
|
||||
if (external_front_node == nullptr && internal_front_node == nullptr) {
|
||||
MS_LOG(WARNING) << "Invalid parameter of kernel graph, parameter :"
|
||||
<< AnfAlgo::GetNodeDebugString(graph_parameter);
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto real_front_node = AnfAlgo::VisitKernelWithReturnType(front_node_with_index.first, 0).first;
|
||||
if ((real_front_node->isa<Parameter>() && HasAbstractRef(real_front_node)) ||
|
||||
HasAbstractMonad(front_node_with_index.first)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (AnfAlgo::CheckPrimitiveType(front_node_with_index.first, prim::kPrimMakeTuple)) {
|
||||
const auto &sub_parameters = FetchInputsByMakeTuple(front_node_with_index.first);
|
||||
const auto &front_node = (external_front_node != nullptr) ? external_front_node : internal_front_node;
|
||||
const auto real_front_node = AnfAlgo::VisitKernelWithReturnType(front_node, 0).first;
|
||||
const auto &sub_parameters = FetchAllRealInputNodeByParameter(real_front_node);
|
||||
parameters.insert(parameters.end(), sub_parameters.begin(), sub_parameters.end());
|
||||
continue;
|
||||
}
|
||||
|
||||
parameters.emplace_back(front_node_with_index.first);
|
||||
}
|
||||
|
||||
return parameters;
|
||||
|
|
|
@ -1152,6 +1152,12 @@ void GraphScheduler::LinkDataArrow(KernelActor *to_actor, const GraphCompilerInf
|
|||
auto front_node = GetFrontNodeByBackendNode(from_kernel);
|
||||
|
||||
if (from_kernel->isa<Parameter>() && graph_compiler_info.control_node_parser_->IsCallInputKernelGraph(graph)) {
|
||||
if (HasAbstractRef(from_kernel)) {
|
||||
const auto devcie_tensor_store_key = FetchFrontNodeByBackendNode(from_kernel, graph);
|
||||
to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, devcie_tensor_store_key.get());
|
||||
return;
|
||||
}
|
||||
|
||||
// When there is a call input in the kernel graph, all the inputs of the kernel graph needs to be sent by gather.
|
||||
const auto actor_name = graph->ToString();
|
||||
auto actor = FetchActor(actor_name);
|
||||
|
|
Loading…
Reference in New Issue