Fix switch actor input tensor.
This commit is contained in:
parent
ebc32c0f46
commit
ab8ac29f0c
|
@ -619,6 +619,55 @@ abstract::AbstractBasePtr FetchAbstractByIndex(const AbstractBasePtr &abstract,
|
|||
MS_LOG(EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString();
|
||||
}
|
||||
|
||||
bool IsPartialInput(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
const auto &abstract = node->abstract();
|
||||
if (abstract != nullptr) {
|
||||
if (abstract->isa<abstract::AbstractFunction>()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!node->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If the abstract is empty and the node is a cnode, check its true branch.
|
||||
const auto &cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
||||
const auto &inputs = cnode->inputs();
|
||||
if (inputs.size() < kSwitchTrueBranchIndex + 1) {
|
||||
MS_LOG(EXCEPTION) << "Invalid switch node:" << node->DebugString();
|
||||
}
|
||||
const auto &branch_node = inputs[kSwitchTrueBranchIndex];
|
||||
MS_EXCEPTION_IF_NULL(branch_node);
|
||||
const auto &branch_abstract = branch_node->abstract();
|
||||
// If abstract is empty, the default is true.
|
||||
if (branch_abstract == nullptr) {
|
||||
MS_LOG(WARNING) << "Failed to get abstract by true branch input of switch node:" << node->DebugString();
|
||||
return true;
|
||||
}
|
||||
|
||||
if (branch_abstract->isa<abstract::AbstractFunction>()) {
|
||||
return true;
|
||||
} else if (branch_abstract->isa<abstract::AbstractTuple>()) {
|
||||
// In switch layer, the true branch input is a make tuple.
|
||||
auto tuple_abstract = branch_abstract->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||
const auto &sub_abstracts = tuple_abstract->elements();
|
||||
if (sub_abstracts.empty() || sub_abstracts[0] == nullptr) {
|
||||
MS_LOG(WARNING) << "Failed to get abstract by true branch input of switch node:" << node->DebugString();
|
||||
return true;
|
||||
}
|
||||
if (sub_abstracts[0]->isa<abstract::AbstractFunction>()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, const std::vector<KernelGraphPtr> &graphs,
|
||||
const std::vector<DeviceContext *> &device_contexts, const FuncGraphPtr &root_graph,
|
||||
const FuncGraphToKernelGraphGroup &func_graph_to_kernel_graphs) {
|
||||
|
@ -1588,13 +1637,25 @@ void ControlNodeParser::ParseFrontToBackendKernel(const std::vector<KernelGraphP
|
|||
|
||||
void ControlNodeParser::ParseFirstControlNodeAndKernelGraphForFuncGraph(const std::vector<AnfNodePtr> &control_nodes) {
|
||||
for (const auto &control_node : control_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(control_node);
|
||||
const auto &func_graph = control_node->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
// In the funcgraph with recursive call node, the call node is marked as level1, and the entrance actor is
|
||||
// notified to send data after the call node execute ends. At this time, it is necessary to ensure that the
|
||||
// data of all actors in the graph has been processed, so all control nodes of level0 need link control arrow
|
||||
// to entrance actor.
|
||||
if (common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch)) {
|
||||
auto iter = node_to_level_.find(control_node);
|
||||
if (iter != node_to_level_.end() && iter->second == 0 && (!IsPartialInput(control_node))) {
|
||||
(void)func_graph_to_first_control_nodes_[func_graph].emplace(control_node);
|
||||
}
|
||||
}
|
||||
|
||||
std::set<AnfNodePtr> checked_nodes;
|
||||
if (((common::AnfAlgo::IsCallNode(control_node) &&
|
||||
unrecursion_call_nodes_.find(control_node) == unrecursion_call_nodes_.end()) ||
|
||||
common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) &&
|
||||
IsFirstControlNode(control_node, &checked_nodes, unrecursion_call_nodes_)) {
|
||||
const auto &func_graph = control_node->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
(void)func_graph_to_first_control_nodes_[func_graph].emplace(control_node);
|
||||
MS_LOG(DEBUG) << "Add first control node:" << control_node->DebugString()
|
||||
<< " for funcgraph:" << func_graph->ToString();
|
||||
|
|
|
@ -121,6 +121,8 @@ KernelWithIndex FetchRealNodeByGetItem(const KernelWithIndex &node_with_index);
|
|||
// Check if the partial node is valid.
|
||||
// Invalid partial nodes are those partial cnodes whose funcgraph is deadnode.
|
||||
bool IsInvalidPartial(const AnfNodePtr &node);
|
||||
// Check whether the switch node abstract is functional.
|
||||
bool IsPartialInput(const AnfNodePtr &node);
|
||||
// ControlNodeParser is used to parse control nodes, and get the edges between nodes.
|
||||
class ControlNodeParser {
|
||||
public:
|
||||
|
|
|
@ -162,56 +162,6 @@ bool IsControlArrowExistForCallNode(const AnfNodePtr &node, const AbstractActor
|
|||
const auto &arrows = arrow_iter->second;
|
||||
return std::find(arrows.begin(), arrows.end(), to_actor->GetAID()) != arrows.end();
|
||||
}
|
||||
|
||||
// Check whether the switch node abstract is functional.
|
||||
bool IsPartialInput(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
const auto &abstract = node->abstract();
|
||||
if (abstract != nullptr) {
|
||||
if (abstract->isa<abstract::AbstractFunction>()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!node->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If the abstract is empty and the node is a cnode, check its true branch.
|
||||
const auto &cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
||||
const auto &inputs = cnode->inputs();
|
||||
if (inputs.size() < kSwitchTrueBranchIndex + 1) {
|
||||
MS_LOG(EXCEPTION) << "Invalid switch node:" << node->DebugString();
|
||||
}
|
||||
const auto &branch_node = inputs[kSwitchTrueBranchIndex];
|
||||
MS_EXCEPTION_IF_NULL(branch_node);
|
||||
const auto &branch_abstract = branch_node->abstract();
|
||||
// If abstract is empty, the default is true.
|
||||
if (branch_abstract == nullptr) {
|
||||
MS_LOG(WARNING) << "Failed to get abstract by true branch input of switch node:" << node->DebugString();
|
||||
return true;
|
||||
}
|
||||
|
||||
if (branch_abstract->isa<abstract::AbstractFunction>()) {
|
||||
return true;
|
||||
} else if (branch_abstract->isa<abstract::AbstractTuple>()) {
|
||||
// In switch layer, the true branch input is a make tuple.
|
||||
auto tuple_abstract = branch_abstract->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||
const auto &sub_abstracts = tuple_abstract->elements();
|
||||
if (sub_abstracts.empty() || sub_abstracts[0] == nullptr) {
|
||||
MS_LOG(WARNING) << "Failed to get abstract by true branch input of switch node:" << node->DebugString();
|
||||
return true;
|
||||
}
|
||||
if (sub_abstracts[0]->isa<abstract::AbstractFunction>()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
ControlActorSetPtr ControlNodeScheduler::Build(const GraphCompilerInfo &graph_compiler_info,
|
||||
|
@ -1550,8 +1500,11 @@ void ControlNodeScheduler::LinkDataArrowByKernelGraph(const KernelGraphPtr &grap
|
|||
auto to_actor = FetchActor(type, "", kernel, graph);
|
||||
MS_EXCEPTION_IF_NULL(to_actor);
|
||||
size_t from_index = 0;
|
||||
if (common::AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimSwitch) ||
|
||||
common::AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimSwitchLayer)) {
|
||||
// If the input is a switch node and the graph does not need a stack, then the data arrow needs to be connected
|
||||
// from the switch actor.
|
||||
if ((common::AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimSwitch) ||
|
||||
common::AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimSwitchLayer)) &&
|
||||
(from_actor->type() != KernelTransformType::kStackActor)) {
|
||||
const auto &actor_name = GetActorName(from_node);
|
||||
auto actor = FetchActor(actor_name);
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
|
|
Loading…
Reference in New Issue