!29125 Fix tuple in tuple bug.

Merge pull request !29125 from liangzelang/dev_master
This commit is contained in:
i-robot 2022-01-18 11:25:27 +00:00 committed by Gitee
commit bfdd040728
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 17 additions and 5 deletions

View File

@ -1254,16 +1254,28 @@ class AscendAutoMonadConverter {
return Assign(target, source, link, keep, output);
}
// Assign tuple.
std::vector<AnfNodePtr> targets = AnfAlgo::GetAllOutput(target, {prim::kPrimTupleGetItem});
std::vector<AnfNodePtr> sources = AnfAlgo::GetAllOutput(source, {prim::kPrimTupleGetItem});
std::vector<AnfNodePtr> targets = AnfAlgo::GetAllOutput(target);
std::vector<AnfNodePtr> sources = AnfAlgo::GetAllOutput(source);
if (targets.size() != sources.size()) {
MS_LOG(EXCEPTION) << "Target size " << targets.size() << " != source size " << sources.size();
}
AnfNodePtrList tuple_inputs;
tuple_inputs.reserve(targets.size() + 1);
auto source_item_with_index = AnfAlgo::VisitKernelWithReturnType(source, 0);
MS_EXCEPTION_IF_NULL(source_item_with_index.first);
auto source_cnode = source_item_with_index.first->cast<CNodePtr>();
auto target_cnode = target->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(source_cnode);
MS_EXCEPTION_IF_NULL(target_cnode);
if (!AnfAlgo::CheckPrimitiveType(source_cnode, prim::kPrimMakeTuple)) {
MS_LOG(WARNING) << "Source : " << source_cnode->DebugString() << " is not MakeTuple.";
}
tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
for (size_t i = 0; i < targets.size(); ++i) {
(void)tuple_inputs.emplace_back(Assign(targets[i], sources[i], link, keep, output));
for (size_t i = 1; i < target_cnode->inputs().size(); ++i) {
if (AnfAlgo::IsTupleOutput(target_cnode->input(i))) {
tuple_inputs.emplace_back(AssignAll(target_cnode->input(i), source_cnode->input(i), link, keep, output));
} else {
tuple_inputs.emplace_back(Assign(target_cnode->input(i), source_cnode->input(i), link, keep, output));
}
}
auto new_tuple = kernel_graph_->NewCNode(tuple_inputs);
// Set abstract for the MakeTuple node.