Fix switch actor input tensor.

This commit is contained in:
gaoyong10 2022-06-10 16:00:14 +08:00
parent ebc32c0f46
commit ab8ac29f0c
3 changed files with 70 additions and 54 deletions

View File

@ -619,6 +619,55 @@ abstract::AbstractBasePtr FetchAbstractByIndex(const AbstractBasePtr &abstract,
MS_LOG(EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString();
}
bool IsPartialInput(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
const auto &abstract = node->abstract();
if (abstract != nullptr) {
if (abstract->isa<abstract::AbstractFunction>()) {
return true;
}
return false;
}
if (!node->isa<CNode>()) {
return false;
}
// If the abstract is empty and the node is a cnode, check its true branch.
const auto &cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
const auto &inputs = cnode->inputs();
if (inputs.size() < kSwitchTrueBranchIndex + 1) {
MS_LOG(EXCEPTION) << "Invalid switch node:" << node->DebugString();
}
const auto &branch_node = inputs[kSwitchTrueBranchIndex];
MS_EXCEPTION_IF_NULL(branch_node);
const auto &branch_abstract = branch_node->abstract();
// If abstract is empty, the default is true.
if (branch_abstract == nullptr) {
MS_LOG(WARNING) << "Failed to get abstract by true branch input of switch node:" << node->DebugString();
return true;
}
if (branch_abstract->isa<abstract::AbstractFunction>()) {
return true;
} else if (branch_abstract->isa<abstract::AbstractTuple>()) {
// In switch layer, the true branch input is a make tuple.
auto tuple_abstract = branch_abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
const auto &sub_abstracts = tuple_abstract->elements();
if (sub_abstracts.empty() || sub_abstracts[0] == nullptr) {
MS_LOG(WARNING) << "Failed to get abstract by true branch input of switch node:" << node->DebugString();
return true;
}
if (sub_abstracts[0]->isa<abstract::AbstractFunction>()) {
return true;
}
}
return false;
}
void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts, const FuncGraphPtr &root_graph,
const FuncGraphToKernelGraphGroup &func_graph_to_kernel_graphs) {
@ -1588,13 +1637,25 @@ void ControlNodeParser::ParseFrontToBackendKernel(const std::vector<KernelGraphP
void ControlNodeParser::ParseFirstControlNodeAndKernelGraphForFuncGraph(const std::vector<AnfNodePtr> &control_nodes) {
for (const auto &control_node : control_nodes) {
MS_EXCEPTION_IF_NULL(control_node);
const auto &func_graph = control_node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
// In the funcgraph with recursive call node, the call node is marked as level1, and the entrance actor is
// notified to send data after the call node execute ends. At this time, it is necessary to ensure that the
// data of all actors in the graph has been processed, so all control nodes of level0 need link control arrow
// to entrance actor.
if (common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch)) {
auto iter = node_to_level_.find(control_node);
if (iter != node_to_level_.end() && iter->second == 0 && (!IsPartialInput(control_node))) {
(void)func_graph_to_first_control_nodes_[func_graph].emplace(control_node);
}
}
std::set<AnfNodePtr> checked_nodes;
if (((common::AnfAlgo::IsCallNode(control_node) &&
unrecursion_call_nodes_.find(control_node) == unrecursion_call_nodes_.end()) ||
common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) &&
IsFirstControlNode(control_node, &checked_nodes, unrecursion_call_nodes_)) {
const auto &func_graph = control_node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
(void)func_graph_to_first_control_nodes_[func_graph].emplace(control_node);
MS_LOG(DEBUG) << "Add first control node:" << control_node->DebugString()
<< " for funcgraph:" << func_graph->ToString();

View File

@ -121,6 +121,8 @@ KernelWithIndex FetchRealNodeByGetItem(const KernelWithIndex &node_with_index);
// Check if the partial node is valid.
// Invalid partial nodes are those partial cnodes whose funcgraph is deadnode.
bool IsInvalidPartial(const AnfNodePtr &node);
// Check whether the switch node abstract is functional.
bool IsPartialInput(const AnfNodePtr &node);
// ControlNodeParser is used to parse control nodes, and get the edges between nodes.
class ControlNodeParser {
public:

View File

@ -162,56 +162,6 @@ bool IsControlArrowExistForCallNode(const AnfNodePtr &node, const AbstractActor
const auto &arrows = arrow_iter->second;
return std::find(arrows.begin(), arrows.end(), to_actor->GetAID()) != arrows.end();
}
// Check whether the switch node abstract is functional.
bool IsPartialInput(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
const auto &abstract = node->abstract();
if (abstract != nullptr) {
if (abstract->isa<abstract::AbstractFunction>()) {
return true;
}
return false;
}
if (!node->isa<CNode>()) {
return false;
}
// If the abstract is empty and the node is a cnode, check its true branch.
const auto &cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
const auto &inputs = cnode->inputs();
if (inputs.size() < kSwitchTrueBranchIndex + 1) {
MS_LOG(EXCEPTION) << "Invalid switch node:" << node->DebugString();
}
const auto &branch_node = inputs[kSwitchTrueBranchIndex];
MS_EXCEPTION_IF_NULL(branch_node);
const auto &branch_abstract = branch_node->abstract();
// If abstract is empty, the default is true.
if (branch_abstract == nullptr) {
MS_LOG(WARNING) << "Failed to get abstract by true branch input of switch node:" << node->DebugString();
return true;
}
if (branch_abstract->isa<abstract::AbstractFunction>()) {
return true;
} else if (branch_abstract->isa<abstract::AbstractTuple>()) {
// In switch layer, the true branch input is a make tuple.
auto tuple_abstract = branch_abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
const auto &sub_abstracts = tuple_abstract->elements();
if (sub_abstracts.empty() || sub_abstracts[0] == nullptr) {
MS_LOG(WARNING) << "Failed to get abstract by true branch input of switch node:" << node->DebugString();
return true;
}
if (sub_abstracts[0]->isa<abstract::AbstractFunction>()) {
return true;
}
}
return false;
}
} // namespace
ControlActorSetPtr ControlNodeScheduler::Build(const GraphCompilerInfo &graph_compiler_info,
@ -1550,8 +1500,11 @@ void ControlNodeScheduler::LinkDataArrowByKernelGraph(const KernelGraphPtr &grap
auto to_actor = FetchActor(type, "", kernel, graph);
MS_EXCEPTION_IF_NULL(to_actor);
size_t from_index = 0;
if (common::AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimSwitch) ||
common::AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimSwitchLayer)) {
// If the input is a switch node and the graph does not need a stack, then the data arrow needs to be connected
// from the switch actor.
if ((common::AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimSwitch) ||
common::AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimSwitchLayer)) &&
(from_actor->type() != KernelTransformType::kStackActor)) {
const auto &actor_name = GetActorName(from_node);
auto actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);