Ignore Partial(DeadNode) in backend routine.

This commit is contained in:
Zhang Qinghua 2022-01-13 17:03:48 +08:00
parent ff7646453a
commit f764f15278
3 changed files with 53 additions and 15 deletions

View File

@ -373,7 +373,6 @@ py::object GetSymbolObject(const NameSpacePtr &name_space, const SymbolPtr &symb
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph is nullptr.";
}
SymbolResolver symbol_resolver(name_space, symbol, node);
symbol_resolver.Resolve();
if (!symbol_resolver.Resolve()) {
MS_LOG(EXCEPTION) << "Fail to resolve node, NodeInfo.";
}

View File

@ -1000,16 +1000,19 @@ void ControlNodeParser::ParseDeviceContextForPartialNode(const std::vector<AnfNo
if (inputs.size() <= kPartialFuncGraphPos) {
MS_LOG(EXCEPTION) << "Invalid input size for partial node:" << cnode->DebugString();
}
auto &func_node = inputs[kPartialFuncGraphPos];
// Ignore if the node is 'Partial(DeadNode,)'.
auto func_value = GetValueNode<StringImmPtr>(func_node);
if (func_value != nullptr && func_value->value() == kDeadNodeName) {
MS_LOG(DEBUG) << "Ignore partial dead node:" << cnode->DebugString();
continue;
}
// Fetch the funcgraph in partial node.
const auto &func_graph_node = inputs[kPartialFuncGraphPos];
MS_EXCEPTION_IF_NULL(func_graph_node);
if ((!func_graph_node->isa<ValueNode>()) || (!IsValueNode<FuncGraph>(func_graph_node))) {
MS_LOG(EXCEPTION) << "Invalid funcgraph node:" << func_graph_node->DebugString()
const auto &func_graph = GetValueNode<FuncGraphPtr>(func_node);
if (func_graph == nullptr) {
MS_LOG(EXCEPTION) << "Invalid funcgraph node:" << func_node->DebugString()
<< " for partial node:" << cnode->DebugString();
}
const auto &func_graph = GetValueNode<FuncGraphPtr>(func_graph_node);
MS_EXCEPTION_IF_NULL(func_graph);
// Fetch the device contexts for the formal parameters in the funcgraph of partial node.
auto iter = func_graph_to_device_contexts_.find(func_graph);
@ -1325,12 +1328,21 @@ void ControlNodeParser::ParseFormalToRealParameter(const std::vector<AnfNodePtr>
const auto &cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
const auto &inputs = cnode->inputs();
if (inputs.size() <= kPartialFuncGraphPos || (!inputs[kPartialFuncGraphPos]->isa<ValueNode>()) ||
(!IsValueNode<FuncGraph>(inputs[kPartialFuncGraphPos]))) {
MS_LOG(EXCEPTION) << "Invalid partial node:" << node->DebugString();
if (inputs.size() <= kPartialFuncGraphPos) {
MS_LOG(EXCEPTION) << "Invalid input size for partial node:" << node->DebugString();
}
auto &func_node = inputs[kPartialFuncGraphPos];
// Ignore if the node is 'Partial(DeadNode,)'.
auto func_value = GetValueNode<StringImmPtr>(func_node);
if (func_value != nullptr && func_value->value() == kDeadNodeName) {
MS_LOG(DEBUG) << "Ignore partial dead node:" << node->DebugString();
continue;
}
const auto &func_graph = GetValueNode<FuncGraphPtr>(func_node);
if (func_graph == nullptr) {
MS_LOG(EXCEPTION) << "Invalid funcgraph node:" << func_node->DebugString()
<< " for partial node:" << node->DebugString();
}
const auto &func_graph = GetValueNode<FuncGraphPtr>(inputs[kPartialFuncGraphPos]);
MS_EXCEPTION_IF_NULL(func_graph);
const auto &parameters = func_graph->parameters();
if (inputs.size() - kPartialInputStartPos > parameters.size()) {
MS_LOG(EXCEPTION) << "Invalid partial input size:" << inputs.size()

View File

@ -138,6 +138,30 @@ std::vector<SwitchActorPtr> ControlNodeScheduler::BuildSwitchActor(const GraphCo
return switch_actors;
}
namespace {
bool IsValidPartialCNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = dyn_cast<CNode>(node);
if (cnode == nullptr) {
return false;
}
const auto &inputs = cnode->inputs();
if (inputs.size() <= kPartialFuncGraphPos) {
return false;
}
if (!IsPrimitive(inputs[kAnfPrimitiveIndex], prim::kPrimPartial)) {
return false;
}
// Ignore if the node is 'Partial(DeadNode,)'.
auto func_value = GetValueNode<StringImmPtr>(inputs[kPartialFuncGraphPos]);
if (func_value != nullptr && func_value->value() == kDeadNodeName) {
MS_LOG(DEBUG) << "Ignore partial dead node:" << cnode->DebugString();
return false;
}
return true;
}
} // namespace
std::vector<GatherActorPtr> ControlNodeScheduler::BuildGatherActor(const GraphCompilerInfo &graph_compiler_info) {
std::vector<GatherActorPtr> gather_actors;
const auto &control_nodes = graph_compiler_info.control_nodes_;
@ -146,7 +170,7 @@ std::vector<GatherActorPtr> ControlNodeScheduler::BuildGatherActor(const GraphCo
for (const auto &control_node : control_nodes) {
// Partial node and call node will be converted to gather actor.
if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimPartial) || AnfAlgo::IsCallNode(control_node)) {
if (IsValidPartialCNode(control_node) || AnfAlgo::IsCallNode(control_node)) {
const auto &actor_name = GetActorName(control_node);
const auto &parameters = FetchInputNodeByCNode(control_node);
const auto &gather_actor =
@ -634,7 +658,10 @@ void ControlNodeScheduler::LinkArrowbyFormalParameter(ControlActor *const to_act
// Link arrow from gather actor
const auto &actor_name = GetActorName(from_node);
const auto &actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
if (actor == nullptr) {
MS_LOG(DEBUG) << "No actor of " << actor_name;
return;
}
const auto &gather_actor = dynamic_cast<GatherActor *>(actor);
MS_EXCEPTION_IF_NULL(gather_actor);
LinkPartialArrow(gather_actor, to_actor, from_node_with_index.second, to_node_with_index.second);