Ignore Partial(DeadNode) in backend routine.
This commit is contained in:
parent
ff7646453a
commit
f764f15278
|
@ -373,7 +373,6 @@ py::object GetSymbolObject(const NameSpacePtr &name_space, const SymbolPtr &symb
|
|||
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph is nullptr.";
|
||||
}
|
||||
SymbolResolver symbol_resolver(name_space, symbol, node);
|
||||
symbol_resolver.Resolve();
|
||||
if (!symbol_resolver.Resolve()) {
|
||||
MS_LOG(EXCEPTION) << "Fail to resolve node, NodeInfo.";
|
||||
}
|
||||
|
|
|
@ -1000,16 +1000,19 @@ void ControlNodeParser::ParseDeviceContextForPartialNode(const std::vector<AnfNo
|
|||
if (inputs.size() <= kPartialFuncGraphPos) {
|
||||
MS_LOG(EXCEPTION) << "Invalid input size for partial node:" << cnode->DebugString();
|
||||
}
|
||||
|
||||
auto &func_node = inputs[kPartialFuncGraphPos];
|
||||
// Ignore if the node is 'Partial(DeadNode,)'.
|
||||
auto func_value = GetValueNode<StringImmPtr>(func_node);
|
||||
if (func_value != nullptr && func_value->value() == kDeadNodeName) {
|
||||
MS_LOG(DEBUG) << "Ignore partial dead node:" << cnode->DebugString();
|
||||
continue;
|
||||
}
|
||||
// Fetch the funcgraph in partial node.
|
||||
const auto &func_graph_node = inputs[kPartialFuncGraphPos];
|
||||
MS_EXCEPTION_IF_NULL(func_graph_node);
|
||||
if ((!func_graph_node->isa<ValueNode>()) || (!IsValueNode<FuncGraph>(func_graph_node))) {
|
||||
MS_LOG(EXCEPTION) << "Invalid funcgraph node:" << func_graph_node->DebugString()
|
||||
const auto &func_graph = GetValueNode<FuncGraphPtr>(func_node);
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Invalid funcgraph node:" << func_node->DebugString()
|
||||
<< " for partial node:" << cnode->DebugString();
|
||||
}
|
||||
const auto &func_graph = GetValueNode<FuncGraphPtr>(func_graph_node);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
||||
// Fetch the device contexts for the formal parameters in the funcgraph of partial node.
|
||||
auto iter = func_graph_to_device_contexts_.find(func_graph);
|
||||
|
@ -1325,12 +1328,21 @@ void ControlNodeParser::ParseFormalToRealParameter(const std::vector<AnfNodePtr>
|
|||
const auto &cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
const auto &inputs = cnode->inputs();
|
||||
if (inputs.size() <= kPartialFuncGraphPos || (!inputs[kPartialFuncGraphPos]->isa<ValueNode>()) ||
|
||||
(!IsValueNode<FuncGraph>(inputs[kPartialFuncGraphPos]))) {
|
||||
MS_LOG(EXCEPTION) << "Invalid partial node:" << node->DebugString();
|
||||
if (inputs.size() <= kPartialFuncGraphPos) {
|
||||
MS_LOG(EXCEPTION) << "Invalid input size for partial node:" << node->DebugString();
|
||||
}
|
||||
auto &func_node = inputs[kPartialFuncGraphPos];
|
||||
// Ignore if the node is 'Partial(DeadNode,)'.
|
||||
auto func_value = GetValueNode<StringImmPtr>(func_node);
|
||||
if (func_value != nullptr && func_value->value() == kDeadNodeName) {
|
||||
MS_LOG(DEBUG) << "Ignore partial dead node:" << node->DebugString();
|
||||
continue;
|
||||
}
|
||||
const auto &func_graph = GetValueNode<FuncGraphPtr>(func_node);
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Invalid funcgraph node:" << func_node->DebugString()
|
||||
<< " for partial node:" << node->DebugString();
|
||||
}
|
||||
const auto &func_graph = GetValueNode<FuncGraphPtr>(inputs[kPartialFuncGraphPos]);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
const auto ¶meters = func_graph->parameters();
|
||||
if (inputs.size() - kPartialInputStartPos > parameters.size()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid partial input size:" << inputs.size()
|
||||
|
|
|
@ -138,6 +138,30 @@ std::vector<SwitchActorPtr> ControlNodeScheduler::BuildSwitchActor(const GraphCo
|
|||
return switch_actors;
|
||||
}
|
||||
|
||||
namespace {
|
||||
bool IsValidPartialCNode(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = dyn_cast<CNode>(node);
|
||||
if (cnode == nullptr) {
|
||||
return false;
|
||||
}
|
||||
const auto &inputs = cnode->inputs();
|
||||
if (inputs.size() <= kPartialFuncGraphPos) {
|
||||
return false;
|
||||
}
|
||||
if (!IsPrimitive(inputs[kAnfPrimitiveIndex], prim::kPrimPartial)) {
|
||||
return false;
|
||||
}
|
||||
// Ignore if the node is 'Partial(DeadNode,)'.
|
||||
auto func_value = GetValueNode<StringImmPtr>(inputs[kPartialFuncGraphPos]);
|
||||
if (func_value != nullptr && func_value->value() == kDeadNodeName) {
|
||||
MS_LOG(DEBUG) << "Ignore partial dead node:" << cnode->DebugString();
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::vector<GatherActorPtr> ControlNodeScheduler::BuildGatherActor(const GraphCompilerInfo &graph_compiler_info) {
|
||||
std::vector<GatherActorPtr> gather_actors;
|
||||
const auto &control_nodes = graph_compiler_info.control_nodes_;
|
||||
|
@ -146,7 +170,7 @@ std::vector<GatherActorPtr> ControlNodeScheduler::BuildGatherActor(const GraphCo
|
|||
|
||||
for (const auto &control_node : control_nodes) {
|
||||
// Partial node and call node will be converted to gather actor.
|
||||
if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimPartial) || AnfAlgo::IsCallNode(control_node)) {
|
||||
if (IsValidPartialCNode(control_node) || AnfAlgo::IsCallNode(control_node)) {
|
||||
const auto &actor_name = GetActorName(control_node);
|
||||
const auto ¶meters = FetchInputNodeByCNode(control_node);
|
||||
const auto &gather_actor =
|
||||
|
@ -634,7 +658,10 @@ void ControlNodeScheduler::LinkArrowbyFormalParameter(ControlActor *const to_act
|
|||
// Link arrow from gather actor
|
||||
const auto &actor_name = GetActorName(from_node);
|
||||
const auto &actor = FetchActor(actor_name);
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
if (actor == nullptr) {
|
||||
MS_LOG(DEBUG) << "No actor of " << actor_name;
|
||||
return;
|
||||
}
|
||||
const auto &gather_actor = dynamic_cast<GatherActor *>(actor);
|
||||
MS_EXCEPTION_IF_NULL(gather_actor);
|
||||
LinkPartialArrow(gather_actor, to_actor, from_node_with_index.second, to_node_with_index.second);
|
||||
|
|
Loading…
Reference in New Issue