From 3522167efd80e2fef42a865cdf7481d60d062603 Mon Sep 17 00:00:00 2001 From: Xun Li Date: Thu, 17 Jun 2021 19:06:10 -0700 Subject: [PATCH] [Coroutine] Properly deal with byval and noalias parameters This patch is to address https://bugs.llvm.org/show_bug.cgi?id=48857. Previous attempts can be found in D104007 and D101980. A lot of discussions can be found in those two patches. To summarize the bug: When Clang emits IR for coroutines, the first thing it does is to make a copy of every argument to the local stack, so that uses of the arguments in the function will all refer to the local copies instead of the arguments directly. However, in some cases we find that arguments are still directly used: When Clang emits IR for a function that has pass-by-value arguments, sometimes it emits an argument with byval attribute. A byval attribute is considered to be local to the function (just like alloca) and hence it can be easily determined that it does not alias other values. If in the IR there exists a memcpy from a byval argument to a local alloca, and then from that local alloca to another alloca, MemCpyOpt will optimize out the first memcpy because byval argument's content will not change. This causes issues because after a coroutine suspension, the byval argument may die outside of the function, and latter uses will lead to memory use-after-free. This is only a problem for arguments with either byval attribute or noalias attribute, because only these two kinds are considered local. Arguments without these two attributes will be considered to alias coro_suspend and hence we won't have this problem. So we need to be able to deal with these two attributes in coroutines properly. For noalias arguments, since coro_suspend may potentially change the value of any argument outside of the function, we simply shouldn't mark any argument in a coroutiune as noalias. This can be taken care of in CoroEarly pass. For byval arguments, if such an argument needs to live across suspensions, we will have to copy their value content to the frame, not just the pointer. Differential Revision: https://reviews.llvm.org/D104184 --- llvm/lib/Transforms/Coroutines/CoroEarly.cpp | 9 ++ llvm/lib/Transforms/Coroutines/CoroFrame.cpp | 31 ++++- .../Transforms/Coroutines/coro-byval-param.ll | 127 ++++++++++++++++++ .../Coroutines/coro-noalias-param.ll | 40 ++++++ 4 files changed, 202 insertions(+), 5 deletions(-) create mode 100644 llvm/test/Transforms/Coroutines/coro-byval-param.ll create mode 100644 llvm/test/Transforms/Coroutines/coro-noalias-param.ll diff --git a/llvm/lib/Transforms/Coroutines/CoroEarly.cpp b/llvm/lib/Transforms/Coroutines/CoroEarly.cpp index 1660e41ba830..5e5e513cdfda 100644 --- a/llvm/lib/Transforms/Coroutines/CoroEarly.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroEarly.cpp @@ -149,6 +149,7 @@ bool Lowerer::lowerEarlyIntrinsics(Function &F) { bool Changed = false; CoroIdInst *CoroId = nullptr; SmallVector CoroFrees; + bool HasCoroSuspend = false; for (auto IB = inst_begin(F), IE = inst_end(F); IB != IE;) { Instruction &I = *IB++; if (auto *CB = dyn_cast(&I)) { @@ -163,6 +164,7 @@ bool Lowerer::lowerEarlyIntrinsics(Function &F) { // pass expects that there is at most one final suspend point. if (cast(&I)->isFinal()) CB->setCannotDuplicate(); + HasCoroSuspend = true; break; case Intrinsic::coro_end_async: case Intrinsic::coro_end: @@ -213,6 +215,13 @@ bool Lowerer::lowerEarlyIntrinsics(Function &F) { if (CoroId) for (CoroFreeInst *CF : CoroFrees) CF->setArgOperand(0, CoroId); + // Coroutine suspention could potentially lead to any argument modified + // outside of the function, hence arguments should not have noalias + // attributes. + if (HasCoroSuspend) + for (Argument &A : F.args()) + if (A.hasNoAliasAttr()) + A.removeAttr(Attribute::NoAlias); return Changed; } diff --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp index 51cf5b22021c..5dcfc4525e7a 100644 --- a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp @@ -1137,7 +1137,13 @@ static StructType *buildFrameType(Function &F, coro::Shape &Shape, PromiseAlloca, DenseMap>{}, false); // Create an entry for every spilled value. for (auto &S : FrameData.Spills) { - FieldIDType Id = B.addField(S.first->getType(), None); + Type *FieldType = S.first->getType(); + // For byval arguments, we need to store the pointed value in the frame, + // instead of the pointer itself. + if (const Argument *A = dyn_cast(S.first)) + if (A->hasByValAttr()) + FieldType = FieldType->getPointerElementType(); + FieldIDType Id = B.addField(FieldType, None); FrameData.setFieldIndex(S.first, Id); } @@ -1543,6 +1549,7 @@ static Instruction *insertSpills(const FrameDataInfo &FrameData, // Create a store instruction storing the value into the // coroutine frame. Instruction *InsertPt = nullptr; + bool NeedToCopyArgPtrValue = false; if (auto *Arg = dyn_cast(Def)) { // For arguments, we will place the store instruction right after // the coroutine frame pointer instruction, i.e. bitcast of @@ -1553,6 +1560,9 @@ static Instruction *insertSpills(const FrameDataInfo &FrameData, // from the coroutine function. Arg->getParent()->removeParamAttr(Arg->getArgNo(), Attribute::NoCapture); + if (Arg->hasByValAttr()) + NeedToCopyArgPtrValue = true; + } else if (auto *CSI = dyn_cast(Def)) { // Don't spill immediately after a suspend; splitting assumes // that the suspend will be followed by a branch. @@ -1587,7 +1597,15 @@ static Instruction *insertSpills(const FrameDataInfo &FrameData, Builder.SetInsertPoint(InsertPt); auto *G = Builder.CreateConstInBoundsGEP2_32( FrameTy, FramePtr, 0, Index, Def->getName() + Twine(".spill.addr")); - Builder.CreateStore(Def, G); + if (NeedToCopyArgPtrValue) { + // For byval arguments, we need to store the pointed value in the frame, + // instead of the pointer itself. + auto *Value = + Builder.CreateLoad(Def->getType()->getPointerElementType(), Def); + Builder.CreateStore(Value, G); + } else { + Builder.CreateStore(Def, G); + } BasicBlock *CurrentBlock = nullptr; Value *CurrentReload = nullptr; @@ -1601,9 +1619,12 @@ static Instruction *insertSpills(const FrameDataInfo &FrameData, auto *GEP = GetFramePointer(E.first); GEP->setName(E.first->getName() + Twine(".reload.addr")); - CurrentReload = Builder.CreateLoad( - FrameTy->getElementType(FrameData.getFieldIndex(E.first)), GEP, - E.first->getName() + Twine(".reload")); + if (NeedToCopyArgPtrValue) + CurrentReload = GEP; + else + CurrentReload = Builder.CreateLoad( + FrameTy->getElementType(FrameData.getFieldIndex(E.first)), GEP, + E.first->getName() + Twine(".reload")); TinyPtrVector DIs = FindDbgDeclareUses(Def); for (DbgDeclareInst *DDI : DIs) { diff --git a/llvm/test/Transforms/Coroutines/coro-byval-param.ll b/llvm/test/Transforms/Coroutines/coro-byval-param.ll new file mode 100644 index 000000000000..6c3c4582fc8b --- /dev/null +++ b/llvm/test/Transforms/Coroutines/coro-byval-param.ll @@ -0,0 +1,127 @@ +; RUN: opt < %s -passes=coro-split -S | FileCheck %s +%promise_type = type { i8 } +%struct.A = type <{ i64, i64, i32, [4 x i8] }> + +; Function Attrs: noinline ssp uwtable mustprogress +define %promise_type* @foo(%struct.A* nocapture readonly byval(%struct.A) align 8 %a1) #0 { +entry: + %__promise = alloca %promise_type, align 1 + %a2 = alloca %struct.A, align 8 + %0 = getelementptr inbounds %promise_type, %promise_type* %__promise, i64 0, i32 0 + %1 = call token @llvm.coro.id(i32 16, i8* nonnull %0, i8* bitcast (%promise_type* (%struct.A*)* @foo to i8*), i8* null) + %2 = call i1 @llvm.coro.alloc(token %1) + br i1 %2, label %coro.alloc, label %coro.init + +coro.alloc: ; preds = %entry + %3 = call i64 @llvm.coro.size.i64() + %call = call noalias nonnull i8* @_Znwm(i64 %3) #9 + br label %coro.init + +coro.init: ; preds = %coro.alloc, %entry + %4 = phi i8* [ null, %entry ], [ %call, %coro.alloc ] + %5 = call i8* @llvm.coro.begin(token %1, i8* %4) #10 + %6 = bitcast %struct.A* %a1 to i8* + call void @llvm.lifetime.start.p0i8(i64 1, i8* nonnull %0) #2 + %call2 = call %promise_type* @_ZN4task12promise_type17get_return_objectEv(%promise_type* nonnull dereferenceable(1) %__promise) + call void @initial_suspend(%promise_type* nonnull dereferenceable(1) %__promise) + %7 = call token @llvm.coro.save(i8* null) + call fastcc void @_ZNSt12experimental13coroutines_v116coroutine_handleIN4task12promise_typeEE12from_addressEPv(i8* %5) #2 + %8 = call i8 @llvm.coro.suspend(token %7, i1 false) + switch i8 %8, label %coro.ret [ + i8 0, label %init.ready + i8 1, label %cleanup33 + ] + +init.ready: ; preds = %coro.init + %9 = bitcast %struct.A* %a2 to i8* + call void @llvm.lifetime.start.p0i8(i64 24, i8* nonnull %9) #2 + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %9, i8* align 8 %6, i64 24, i1 false) + call void @llvm.lifetime.end.p0i8(i64 24, i8* nonnull %9) #2 + call void @_ZN4task12promise_type13final_suspendEv(%promise_type* nonnull dereferenceable(1) %__promise) #2 + %10 = call token @llvm.coro.save(i8* null) + call fastcc void @_ZNSt12experimental13coroutines_v116coroutine_handleIN4task12promise_typeEE12from_addressEPv(i8* %5) #2 + %11 = call i8 @llvm.coro.suspend(token %10, i1 true) #10 + %switch = icmp ult i8 %11, 2 + br i1 %switch, label %cleanup33, label %coro.ret + +cleanup33: ; preds = %init.ready, %coro.init + call void @llvm.lifetime.end.p0i8(i64 1, i8* nonnull %0) #2 + %12 = call i8* @llvm.coro.free(token %1, i8* %5) + %.not = icmp eq i8* %12, null + br i1 %.not, label %coro.ret, label %coro.free + +coro.free: ; preds = %cleanup33 + call void @_ZdlPv(i8* nonnull %12) #2 + br label %coro.ret + +coro.ret: ; preds = %coro.free, %cleanup33, %init.ready, %coro.init + %13 = call i1 @llvm.coro.end(i8* null, i1 false) #10 + ret %promise_type* %call2 +} + +; check that the frame contains the entire struct, instead of just the struct pointer +; CHECK: %foo.Frame = type { void (%foo.Frame*)*, void (%foo.Frame*)*, %promise_type, %struct.A, i1 } + +; Function Attrs: argmemonly nounwind readonly +declare token @llvm.coro.id(i32, i8* readnone, i8* nocapture readonly, i8*) #1 + +; Function Attrs: nounwind +declare i1 @llvm.coro.alloc(token) #2 + +; Function Attrs: nobuiltin nofree allocsize(0) +declare nonnull i8* @_Znwm(i64) local_unnamed_addr #3 + +; Function Attrs: nounwind readnone +declare i64 @llvm.coro.size.i64() #4 + +; Function Attrs: nounwind +declare i8* @llvm.coro.begin(token, i8* writeonly) #2 + +; Function Attrs: argmemonly nofree nosync nounwind willreturn +declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture) #5 + +; Function Attrs: argmemonly nofree nounwind willreturn +declare void @llvm.memcpy.p0i8.p0i8.i64(i8* noalias nocapture writeonly, i8* noalias nocapture readonly, i64, i1 immarg) #6 + +; Function Attrs: noinline nounwind ssp uwtable willreturn mustprogress +declare %promise_type* @_ZN4task12promise_type17get_return_objectEv(%promise_type* nonnull dereferenceable(1)) local_unnamed_addr #7 align 2 + +; Function Attrs: noinline nounwind ssp uwtable willreturn mustprogress +declare void @initial_suspend(%promise_type* nonnull dereferenceable(1)) local_unnamed_addr #7 align 2 + +; Function Attrs: nounwind +declare token @llvm.coro.save(i8*) #2 + +; Function Attrs: noinline nounwind ssp uwtable willreturn mustprogress +declare hidden fastcc void @_ZNSt12experimental13coroutines_v116coroutine_handleIN4task12promise_typeEE12from_addressEPv(i8*) unnamed_addr #7 align 2 + +; Function Attrs: argmemonly nofree nosync nounwind willreturn +declare void @llvm.lifetime.end.p0i8(i64 immarg, i8* nocapture) #5 + +; Function Attrs: nounwind +declare i8 @llvm.coro.suspend(token, i1) #2 + +; Function Attrs: noinline nounwind ssp uwtable willreturn mustprogress +declare void @_ZN4task12promise_type13final_suspendEv(%promise_type* nonnull dereferenceable(1)) local_unnamed_addr #7 align 2 + +; Function Attrs: nounwind +declare i1 @llvm.coro.end(i8*, i1) #2 + +; Function Attrs: nobuiltin nounwind +declare void @_ZdlPv(i8*) local_unnamed_addr #8 + +; Function Attrs: argmemonly nounwind readonly +declare i8* @llvm.coro.free(token, i8* nocapture readonly) #1 + +attributes #0 = { noinline ssp uwtable mustprogress "coroutine.presplit"="1" "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="penryn" "target-features"="+cx16,+cx8,+fxsr,+mmx,+sahf,+sse,+sse2,+sse3,+sse4.1,+ssse3,+x87" "tune-cpu"="generic" } +attributes #1 = { argmemonly nounwind readonly } +attributes #2 = { nounwind } +attributes #3 = { nobuiltin nofree allocsize(0) "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="penryn" "target-features"="+cx16,+cx8,+fxsr,+mmx,+sahf,+sse,+sse2,+sse3,+sse4.1,+ssse3,+x87" "tune-cpu"="generic" } +attributes #4 = { nounwind readnone } +attributes #5 = { argmemonly nofree nosync nounwind willreturn } +attributes #6 = { argmemonly nofree nounwind willreturn } +attributes #7 = { noinline nounwind ssp uwtable willreturn mustprogress "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="penryn" "target-features"="+cx16,+cx8,+fxsr,+mmx,+sahf,+sse,+sse2,+sse3,+sse4.1,+ssse3,+x87" "tune-cpu"="generic" } +attributes #8 = { nobuiltin nounwind "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="penryn" "target-features"="+cx16,+cx8,+fxsr,+mmx,+sahf,+sse,+sse2,+sse3,+sse4.1,+ssse3,+x87" "tune-cpu"="generic" } +attributes #9 = { allocsize(0) } +attributes #10 = { noduplicate } + diff --git a/llvm/test/Transforms/Coroutines/coro-noalias-param.ll b/llvm/test/Transforms/Coroutines/coro-noalias-param.ll new file mode 100644 index 000000000000..0b9a70ad0366 --- /dev/null +++ b/llvm/test/Transforms/Coroutines/coro-noalias-param.ll @@ -0,0 +1,40 @@ +; RUN: opt < %s -S -passes=coro-early | FileCheck %s +%struct.A = type <{ i64, i64, i32, [4 x i8] }> + +define void @f(%struct.A* nocapture readonly noalias align 8 %a) { + %id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null) + %size = call i32 @llvm.coro.size.i32() + %alloc = call i8* @malloc(i32 %size) + %hdl = call i8* @llvm.coro.begin(token %id, i8* %alloc) + call void @print(i32 0) + %s1 = call i8 @llvm.coro.suspend(token none, i1 false) + switch i8 %s1, label %suspend [i8 0, label %resume + i8 1, label %cleanup] +resume: + call void @print(i32 1) + br label %cleanup + +cleanup: + %mem = call i8* @llvm.coro.free(token %id, i8* %hdl) + call void @free(i8* %mem) + br label %suspend +suspend: + call i1 @llvm.coro.end(i8* %hdl, i1 0) + ret void +} + +; check that the noalias attribute is removed from the argument +; CHECK: define void @f(%struct.A* nocapture readonly align 8 %a) + +declare token @llvm.coro.id(i32, i8*, i8*, i8*) +declare i8* @llvm.coro.begin(token, i8*) +declare i8* @llvm.coro.free(token, i8*) +declare i32 @llvm.coro.size.i32() +declare i8 @llvm.coro.suspend(token, i1) +declare void @llvm.coro.resume(i8*) +declare void @llvm.coro.destroy(i8*) +declare i1 @llvm.coro.end(i8*, i1) + +declare noalias i8* @malloc(i32) +declare void @print(i32) +declare void @free(i8*)