diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp index 12eb16789825..742e3868796d 100644 --- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp @@ -509,12 +509,87 @@ static void simplifySuspendPoints(coro::Shape &Shape) { S.resize(N); } +static SmallPtrSet getCoroBeginPredBlocks(CoroBeginInst *CB) { + // Collect all blocks that we need to look for instructions to relocate. + SmallPtrSet RelocBlocks; + SmallVector Work; + Work.push_back(CB->getParent()); + + do { + BasicBlock *Current = Work.pop_back_val(); + for (BasicBlock *BB : predecessors(Current)) + if (RelocBlocks.count(BB) == 0) { + RelocBlocks.insert(BB); + Work.push_back(BB); + } + } while (!Work.empty()); + return RelocBlocks; +} + +static SmallPtrSet +getNotRelocatableInstructions(CoroBeginInst *CoroBegin, + SmallPtrSetImpl &RelocBlocks) { + SmallPtrSet DoNotRelocate; + // Collect all instructions that we should not relocate + SmallVector Work; + + // Start with CoroBegin and terminators of all preceding blocks. + Work.push_back(CoroBegin); + BasicBlock *CoroBeginBB = CoroBegin->getParent(); + for (BasicBlock *BB : RelocBlocks) + if (BB != CoroBeginBB) + Work.push_back(BB->getTerminator()); + + // For every instruction in the Work list, place its operands in DoNotRelocate + // set. + do { + Instruction *Current = Work.pop_back_val(); + DoNotRelocate.insert(Current); + for (Value *U : Current->operands()) { + auto *I = dyn_cast(U); + if (!I) + continue; + if (isa(U)) + continue; + if (DoNotRelocate.count(I) == 0) { + Work.push_back(I); + DoNotRelocate.insert(I); + } + } + } while (!Work.empty()); + return DoNotRelocate; +} + +static void relocateInstructionBefore(CoroBeginInst *CoroBegin, Function &F) { + // Analyze which non-alloca instructions are needed for allocation and + // relocate the rest to after coro.begin. We need to do it, since some of the + // targets of those instructions may be placed into coroutine frame memory + // for which becomes available after coro.begin intrinsic. + + auto BlockSet = getCoroBeginPredBlocks(CoroBegin); + auto DoNotRelocateSet = getNotRelocatableInstructions(CoroBegin, BlockSet); + + Instruction *InsertPt = CoroBegin->getNextNode(); + BasicBlock &BB = F.getEntryBlock(); // TODO: Look at other blocks as well. + for (auto B = BB.begin(), E = BB.end(); B != E;) { + Instruction &I = *B++; + if (isa(&I)) + continue; + if (&I == CoroBegin) + break; + if (DoNotRelocateSet.count(&I)) + continue; + I.moveBefore(InsertPt); + } +} + static void splitCoroutine(Function &F, CallGraph &CG, CallGraphSCC &SCC) { coro::Shape Shape(F); if (!Shape.CoroBegin) return; simplifySuspendPoints(Shape); + relocateInstructionBefore(Shape.CoroBegin, F); buildCoroutineFrame(F, Shape); replaceFrameSize(Shape); diff --git a/llvm/test/Transforms/Coroutines/coro-frame.ll b/llvm/test/Transforms/Coroutines/coro-frame.ll index 001012fcd0c9..826d3a04fa1e 100644 --- a/llvm/test/Transforms/Coroutines/coro-frame.ll +++ b/llvm/test/Transforms/Coroutines/coro-frame.ll @@ -1,8 +1,11 @@ ; Check that we can handle spills of the result of the invoke instruction ; RUN: opt < %s -coro-split -S | FileCheck %s -define i8* @f() "coroutine.presplit"="1" personality i32 0 { +define i8* @f(i64 %this) "coroutine.presplit"="1" personality i32 0 { entry: + %this.addr = alloca i64 + store i64 %this, i64* %this.addr + %this1 = load i64, i64* %this.addr %id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null) %size = call i32 @llvm.coro.size.i32() %alloc = call i8* @malloc(i32 %size) @@ -15,6 +18,7 @@ cont: i8 1, label %cleanup] resume: call double @print(double %r) + call void @print2(i64 %this1) br label %cleanup cleanup: @@ -30,12 +34,12 @@ pad: } ; See if the float was added to the frame -; CHECK-LABEL: %f.Frame = type { void (%f.Frame*)*, void (%f.Frame*)*, i1, i1, double } +; CHECK-LABEL: %f.Frame = type { void (%f.Frame*)*, void (%f.Frame*)*, i1, i1, i64, double } ; See if the float was spilled into the frame ; CHECK-LABEL: @f( ; CHECK: %r = call double @print( -; CHECK: %r.spill.addr = getelementptr inbounds %f.Frame, %f.Frame* %FramePtr, i32 0, i32 4 +; CHECK: %r.spill.addr = getelementptr inbounds %f.Frame, %f.Frame* %FramePtr, i32 0, i32 5 ; CHECK: store double %r, double* %r.spill.addr ; CHECK: ret i8* %hdl @@ -58,4 +62,5 @@ declare i1 @llvm.coro.end(i8*, i1) declare noalias i8* @malloc(i32) declare double @print(double) +declare void @print2(i64) declare void @free(i8*)