From d733f2c68c97f7cae9697cffd62aff0ebe79ce16 Mon Sep 17 00:00:00 2001 From: Nikita Popov Date: Mon, 13 Dec 2021 16:23:15 +0100 Subject: [PATCH] [OpenMPIRBuilder] Support opaque pointers in reduction handling Make the reduction handling in OpenMPIRBuilder compatible with opaque pointers by explicitly storing the element type in ReductionInfo, and also passing it to the atomic reduction callback, as at least the ones in the test need the type there. This doesn't make things fully compatible yet, there are other uses of element types in this class. I also left one getPointerElementType() call in mlir, because I'm not familiar with that area. Differential Revision: https://reviews.llvm.org/D115638 --- .../llvm/Frontend/OpenMP/OMPIRBuilder.h | 23 +++++---- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 10 ++-- .../Frontend/OpenMPIRBuilderTest.cpp | 48 +++++++++---------- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 13 +++-- 4 files changed, 48 insertions(+), 46 deletions(-) diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index 3263abd4f77c..9976d1961ed1 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -539,24 +539,27 @@ public: function_ref; /// Functions used to generate atomic reductions. Such functions take two - /// Values representing pointers to LHS and RHS of the reduction. They are - /// expected to atomically update the LHS to the reduced value. + /// Values representing pointers to LHS and RHS of the reduction, as well as + /// the element type of these pointers. They are expected to atomically + /// update the LHS to the reduced value. using AtomicReductionGenTy = - function_ref; + function_ref; /// Information about an OpenMP reduction. 
struct ReductionInfo { - ReductionInfo(Value *Variable, Value *PrivateVariable, + ReductionInfo(Type *ElementType, Value *Variable, Value *PrivateVariable, ReductionGenTy ReductionGen, AtomicReductionGenTy AtomicReductionGen) - : Variable(Variable), PrivateVariable(PrivateVariable), - ReductionGen(ReductionGen), AtomicReductionGen(AtomicReductionGen) {} - - /// Returns the type of the element being reduced. - Type *getElementType() const { - return Variable->getType()->getPointerElementType(); + : ElementType(ElementType), Variable(Variable), + PrivateVariable(PrivateVariable), ReductionGen(ReductionGen), + AtomicReductionGen(AtomicReductionGen) { + assert(cast(Variable->getType()) + ->isOpaqueOrPointeeTypeMatches(ElementType) && "Invalid elem type"); } + /// Reduction element type, must match pointee type of variable. + Type *ElementType; + /// Reduction variable of pointer type. Value *Variable; diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 66cc9dee117e..10634bc7b9fa 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -1156,7 +1156,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductions( Builder.SetInsertPoint(NonAtomicRedBlock); for (auto En : enumerate(ReductionInfos)) { const ReductionInfo &RI = En.value(); - Type *ValueType = RI.getElementType(); + Type *ValueType = RI.ElementType; Value *RedValue = Builder.CreateLoad(ValueType, RI.Variable, "red.value." 
+ Twine(En.index())); Value *PrivateRedValue = @@ -1181,8 +1181,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductions( Builder.SetInsertPoint(AtomicRedBlock); if (CanGenerateAtomic) { for (const ReductionInfo &RI : ReductionInfos) { - Builder.restoreIP(RI.AtomicReductionGen(Builder.saveIP(), RI.Variable, - RI.PrivateVariable)); + Builder.restoreIP(RI.AtomicReductionGen(Builder.saveIP(), RI.ElementType, + RI.Variable, RI.PrivateVariable)); if (!Builder.GetInsertBlock()) return InsertPointTy(); } @@ -1207,13 +1207,13 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductions( RedArrayTy, LHSArrayPtr, 0, En.index()); Value *LHSI8Ptr = Builder.CreateLoad(Builder.getInt8PtrTy(), LHSI8PtrPtr); Value *LHSPtr = Builder.CreateBitCast(LHSI8Ptr, RI.Variable->getType()); - Value *LHS = Builder.CreateLoad(RI.getElementType(), LHSPtr); + Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr); Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64( RedArrayTy, RHSArrayPtr, 0, En.index()); Value *RHSI8Ptr = Builder.CreateLoad(Builder.getInt8PtrTy(), RHSI8PtrPtr); Value *RHSPtr = Builder.CreateBitCast(RHSI8Ptr, RI.PrivateVariable->getType()); - Value *RHS = Builder.CreateLoad(RI.getElementType(), RHSPtr); + Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr); Value *Reduced; Builder.restoreIP(RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced)); if (!Builder.GetInsertBlock()) diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index 8373d69150c1..454dbb58c3a4 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -3028,10 +3028,10 @@ sumReduction(OpenMPIRBuilder::InsertPointTy IP, Value *LHS, Value *RHS, } static OpenMPIRBuilder::InsertPointTy -sumAtomicReduction(OpenMPIRBuilder::InsertPointTy IP, Value *LHS, Value *RHS) { +sumAtomicReduction(OpenMPIRBuilder::InsertPointTy IP, Type *Ty, Value *LHS, + Value *RHS) { IRBuilder<> 
Builder(IP.getBlock(), IP.getPoint()); - Value *Partial = Builder.CreateLoad(RHS->getType()->getPointerElementType(), - RHS, "red.partial"); + Value *Partial = Builder.CreateLoad(Ty, RHS, "red.partial"); Builder.CreateAtomicRMW(AtomicRMWInst::FAdd, LHS, Partial, None, AtomicOrdering::Monotonic); return Builder.saveIP(); @@ -3046,10 +3046,10 @@ xorReduction(OpenMPIRBuilder::InsertPointTy IP, Value *LHS, Value *RHS, } static OpenMPIRBuilder::InsertPointTy -xorAtomicReduction(OpenMPIRBuilder::InsertPointTy IP, Value *LHS, Value *RHS) { +xorAtomicReduction(OpenMPIRBuilder::InsertPointTy IP, Type *Ty, Value *LHS, + Value *RHS) { IRBuilder<> Builder(IP.getBlock(), IP.getPoint()); - Value *Partial = Builder.CreateLoad(RHS->getType()->getPointerElementType(), - RHS, "red.partial"); + Value *Partial = Builder.CreateLoad(Ty, RHS, "red.partial"); Builder.CreateAtomicRMW(AtomicRMWInst::Xor, LHS, Partial, None, AtomicOrdering::Monotonic); return Builder.saveIP(); @@ -3081,13 +3081,15 @@ TEST_F(OpenMPIRBuilderTest, CreateReductions) { // Create variables to be reduced. InsertPointTy OuterAllocaIP(&F->getEntryBlock(), F->getEntryBlock().getFirstInsertionPt()); + Type *SumType = Builder.getFloatTy(); + Type *XorType = Builder.getInt32Ty(); Value *SumReduced; Value *XorReduced; { IRBuilderBase::InsertPointGuard Guard(Builder); Builder.restoreIP(OuterAllocaIP); - SumReduced = Builder.CreateAlloca(Builder.getFloatTy()); - XorReduced = Builder.CreateAlloca(Builder.getInt32Ty()); + SumReduced = Builder.CreateAlloca(SumType); + XorReduced = Builder.CreateAlloca(XorType); } // Store initial values of reductions into global variables. 
@@ -3109,12 +3111,8 @@ TEST_F(OpenMPIRBuilderTest, CreateReductions) { Value *TID = OMPBuilder.getOrCreateThreadID(Ident); Value *SumLocal = Builder.CreateUIToFP(TID, Builder.getFloatTy(), "sum.local"); - Value *SumPartial = - Builder.CreateLoad(SumReduced->getType()->getPointerElementType(), - SumReduced, "sum.partial"); - Value *XorPartial = - Builder.CreateLoad(XorReduced->getType()->getPointerElementType(), - XorReduced, "xor.partial"); + Value *SumPartial = Builder.CreateLoad(SumType, SumReduced, "sum.partial"); + Value *XorPartial = Builder.CreateLoad(XorType, XorReduced, "xor.partial"); Value *Sum = Builder.CreateFAdd(SumPartial, SumLocal, "sum"); Value *Xor = Builder.CreateXor(XorPartial, TID, "xor"); Builder.CreateStore(Sum, SumReduced); @@ -3164,8 +3162,8 @@ TEST_F(OpenMPIRBuilderTest, CreateReductions) { Builder.restoreIP(AfterIP); OpenMPIRBuilder::ReductionInfo ReductionInfos[] = { - {SumReduced, SumPrivatized, sumReduction, sumAtomicReduction}, - {XorReduced, XorPrivatized, xorReduction, xorAtomicReduction}}; + {SumType, SumReduced, SumPrivatized, sumReduction, sumAtomicReduction}, + {XorType, XorReduced, XorPrivatized, xorReduction, xorAtomicReduction}}; OMPBuilder.createReductions(BodyIP, BodyAllocaIP, ReductionInfos); @@ -3319,13 +3317,15 @@ TEST_F(OpenMPIRBuilderTest, CreateTwoReductions) { // Create variables to be reduced. InsertPointTy OuterAllocaIP(&F->getEntryBlock(), F->getEntryBlock().getFirstInsertionPt()); + Type *SumType = Builder.getFloatTy(); + Type *XorType = Builder.getInt32Ty(); Value *SumReduced; Value *XorReduced; { IRBuilderBase::InsertPointGuard Guard(Builder); Builder.restoreIP(OuterAllocaIP); - SumReduced = Builder.CreateAlloca(Builder.getFloatTy()); - XorReduced = Builder.CreateAlloca(Builder.getInt32Ty()); + SumReduced = Builder.CreateAlloca(SumType); + XorReduced = Builder.CreateAlloca(XorType); } // Store initial values of reductions into global variables. 
@@ -3344,9 +3344,7 @@ TEST_F(OpenMPIRBuilderTest, CreateTwoReductions) { Value *TID = OMPBuilder.getOrCreateThreadID(Ident); Value *SumLocal = Builder.CreateUIToFP(TID, Builder.getFloatTy(), "sum.local"); - Value *SumPartial = - Builder.CreateLoad(SumReduced->getType()->getPointerElementType(), - SumReduced, "sum.partial"); + Value *SumPartial = Builder.CreateLoad(SumType, SumReduced, "sum.partial"); Value *Sum = Builder.CreateFAdd(SumPartial, SumLocal, "sum"); Builder.CreateStore(Sum, SumReduced); @@ -3364,9 +3362,7 @@ TEST_F(OpenMPIRBuilderTest, CreateTwoReductions) { Constant *SrcLocStr = OMPBuilder.getOrCreateSrcLocStr(Loc); Value *Ident = OMPBuilder.getOrCreateIdent(SrcLocStr); Value *TID = OMPBuilder.getOrCreateThreadID(Ident); - Value *XorPartial = - Builder.CreateLoad(XorReduced->getType()->getPointerElementType(), - XorReduced, "xor.partial"); + Value *XorPartial = Builder.CreateLoad(XorType, XorReduced, "xor.partial"); Value *Xor = Builder.CreateXor(XorPartial, TID, "xor"); Builder.CreateStore(Xor, XorReduced); @@ -3421,10 +3417,10 @@ TEST_F(OpenMPIRBuilderTest, CreateTwoReductions) { OMPBuilder.createReductions( FirstBodyIP, FirstBodyAllocaIP, - {{SumReduced, SumPrivatized, sumReduction, sumAtomicReduction}}); + {{SumType, SumReduced, SumPrivatized, sumReduction, sumAtomicReduction}}); OMPBuilder.createReductions( SecondBodyIP, SecondBodyAllocaIP, - {{XorReduced, XorPrivatized, xorReduction, xorAtomicReduction}}); + {{XorType, XorReduced, XorPrivatized, xorReduction, xorAtomicReduction}}); Builder.restoreIP(AfterIP); Builder.CreateRetVoid(); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 96b7644116e8..d8e6ce10b511 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -415,7 +415,8 @@ using OwningReductionGen = std::function; using 
OwningAtomicReductionGen = std::function; + llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *, + llvm::Value *)>; } // namespace /// Create an OpenMPIRBuilder-compatible reduction generator for the given @@ -462,7 +463,7 @@ makeAtomicReductionGen(omp::ReductionDeclareOp decl, // (which aren't actually mutating it), and we must capture decl by-value to // avoid the dangling reference after the parent function returns. OwningAtomicReductionGen atomicGen = - [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, + [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *, llvm::Value *lhs, llvm::Value *rhs) mutable { Region &atomicRegion = decl.atomicReductionRegion(); moduleTranslation.mapValue(atomicRegion.front().getArgument(0), lhs); @@ -763,9 +764,11 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder::AtomicReductionGenTy atomicGen = nullptr; if (owningAtomicReductionGens[i]) atomicGen = owningAtomicReductionGens[i]; - reductionInfos.push_back( - {moduleTranslation.lookupValue(loop.reduction_vars()[i]), - privateReductionVariables[i], owningReductionGens[i], atomicGen}); + llvm::Value *variable = + moduleTranslation.lookupValue(loop.reduction_vars()[i]); + reductionInfos.push_back({variable->getType()->getPointerElementType(), + variable, privateReductionVariables[i], + owningReductionGens[i], atomicGen}); } // The call to createReductions below expects the block to have a