[OpenMPIRBuilder] Support opaque pointers in reduction handling

Make the reduction handling in OpenMPIRBuilder compatible with
opaque pointers by explicitly storing the element type in ReductionInfo,
and also passing it to the atomic reduction callback, as at least
the ones in the test need the type there.

This doesn't make things fully compatible yet, there are other
uses of element types in this class. I also left one
getPointerElementType() call in mlir, because I'm not familiar
with that area.

Differential Revison: https://reviews.llvm.org/D115638
This commit is contained in:
Nikita Popov 2021-12-13 16:23:15 +01:00
parent 26f6fbe2be
commit d733f2c68c
4 changed files with 48 additions and 46 deletions

View File

@ -539,24 +539,27 @@ public:
function_ref<InsertPointTy(InsertPointTy, Value *, Value *, Value *&)>;
/// Functions used to generate atomic reductions. Such functions take two
/// Values representing pointers to LHS and RHS of the reduction. They are
/// expected to atomically update the LHS to the reduced value.
/// Values representing pointers to LHS and RHS of the reduction, as well as
/// the element type of these pointers. They are expected to atomically
/// update the LHS to the reduced value.
using AtomicReductionGenTy =
function_ref<InsertPointTy(InsertPointTy, Value *, Value *)>;
function_ref<InsertPointTy(InsertPointTy, Type *, Value *, Value *)>;
/// Information about an OpenMP reduction.
struct ReductionInfo {
ReductionInfo(Value *Variable, Value *PrivateVariable,
ReductionInfo(Type *ElementType, Value *Variable, Value *PrivateVariable,
ReductionGenTy ReductionGen,
AtomicReductionGenTy AtomicReductionGen)
: Variable(Variable), PrivateVariable(PrivateVariable),
ReductionGen(ReductionGen), AtomicReductionGen(AtomicReductionGen) {}
/// Returns the type of the element being reduced.
Type *getElementType() const {
return Variable->getType()->getPointerElementType();
: ElementType(ElementType), Variable(Variable),
PrivateVariable(PrivateVariable), ReductionGen(ReductionGen),
AtomicReductionGen(AtomicReductionGen) {
assert(cast<PointerType>(Variable->getType())
->isOpaqueOrPointeeTypeMatches(ElementType) && "Invalid elem type");
}
/// Reduction element type, must match pointee type of variable.
Type *ElementType;
/// Reduction variable of pointer type.
Value *Variable;

View File

@ -1156,7 +1156,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductions(
Builder.SetInsertPoint(NonAtomicRedBlock);
for (auto En : enumerate(ReductionInfos)) {
const ReductionInfo &RI = En.value();
Type *ValueType = RI.getElementType();
Type *ValueType = RI.ElementType;
Value *RedValue = Builder.CreateLoad(ValueType, RI.Variable,
"red.value." + Twine(En.index()));
Value *PrivateRedValue =
@ -1181,8 +1181,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductions(
Builder.SetInsertPoint(AtomicRedBlock);
if (CanGenerateAtomic) {
for (const ReductionInfo &RI : ReductionInfos) {
Builder.restoreIP(RI.AtomicReductionGen(Builder.saveIP(), RI.Variable,
RI.PrivateVariable));
Builder.restoreIP(RI.AtomicReductionGen(Builder.saveIP(), RI.ElementType,
RI.Variable, RI.PrivateVariable));
if (!Builder.GetInsertBlock())
return InsertPointTy();
}
@ -1207,13 +1207,13 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductions(
RedArrayTy, LHSArrayPtr, 0, En.index());
Value *LHSI8Ptr = Builder.CreateLoad(Builder.getInt8PtrTy(), LHSI8PtrPtr);
Value *LHSPtr = Builder.CreateBitCast(LHSI8Ptr, RI.Variable->getType());
Value *LHS = Builder.CreateLoad(RI.getElementType(), LHSPtr);
Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
RedArrayTy, RHSArrayPtr, 0, En.index());
Value *RHSI8Ptr = Builder.CreateLoad(Builder.getInt8PtrTy(), RHSI8PtrPtr);
Value *RHSPtr =
Builder.CreateBitCast(RHSI8Ptr, RI.PrivateVariable->getType());
Value *RHS = Builder.CreateLoad(RI.getElementType(), RHSPtr);
Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
Value *Reduced;
Builder.restoreIP(RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced));
if (!Builder.GetInsertBlock())

View File

@ -3028,10 +3028,10 @@ sumReduction(OpenMPIRBuilder::InsertPointTy IP, Value *LHS, Value *RHS,
}
static OpenMPIRBuilder::InsertPointTy
sumAtomicReduction(OpenMPIRBuilder::InsertPointTy IP, Value *LHS, Value *RHS) {
sumAtomicReduction(OpenMPIRBuilder::InsertPointTy IP, Type *Ty, Value *LHS,
Value *RHS) {
IRBuilder<> Builder(IP.getBlock(), IP.getPoint());
Value *Partial = Builder.CreateLoad(RHS->getType()->getPointerElementType(),
RHS, "red.partial");
Value *Partial = Builder.CreateLoad(Ty, RHS, "red.partial");
Builder.CreateAtomicRMW(AtomicRMWInst::FAdd, LHS, Partial, None,
AtomicOrdering::Monotonic);
return Builder.saveIP();
@ -3046,10 +3046,10 @@ xorReduction(OpenMPIRBuilder::InsertPointTy IP, Value *LHS, Value *RHS,
}
static OpenMPIRBuilder::InsertPointTy
xorAtomicReduction(OpenMPIRBuilder::InsertPointTy IP, Value *LHS, Value *RHS) {
xorAtomicReduction(OpenMPIRBuilder::InsertPointTy IP, Type *Ty, Value *LHS,
Value *RHS) {
IRBuilder<> Builder(IP.getBlock(), IP.getPoint());
Value *Partial = Builder.CreateLoad(RHS->getType()->getPointerElementType(),
RHS, "red.partial");
Value *Partial = Builder.CreateLoad(Ty, RHS, "red.partial");
Builder.CreateAtomicRMW(AtomicRMWInst::Xor, LHS, Partial, None,
AtomicOrdering::Monotonic);
return Builder.saveIP();
@ -3081,13 +3081,15 @@ TEST_F(OpenMPIRBuilderTest, CreateReductions) {
// Create variables to be reduced.
InsertPointTy OuterAllocaIP(&F->getEntryBlock(),
F->getEntryBlock().getFirstInsertionPt());
Type *SumType = Builder.getFloatTy();
Type *XorType = Builder.getInt32Ty();
Value *SumReduced;
Value *XorReduced;
{
IRBuilderBase::InsertPointGuard Guard(Builder);
Builder.restoreIP(OuterAllocaIP);
SumReduced = Builder.CreateAlloca(Builder.getFloatTy());
XorReduced = Builder.CreateAlloca(Builder.getInt32Ty());
SumReduced = Builder.CreateAlloca(SumType);
XorReduced = Builder.CreateAlloca(XorType);
}
// Store initial values of reductions into global variables.
@ -3109,12 +3111,8 @@ TEST_F(OpenMPIRBuilderTest, CreateReductions) {
Value *TID = OMPBuilder.getOrCreateThreadID(Ident);
Value *SumLocal =
Builder.CreateUIToFP(TID, Builder.getFloatTy(), "sum.local");
Value *SumPartial =
Builder.CreateLoad(SumReduced->getType()->getPointerElementType(),
SumReduced, "sum.partial");
Value *XorPartial =
Builder.CreateLoad(XorReduced->getType()->getPointerElementType(),
XorReduced, "xor.partial");
Value *SumPartial = Builder.CreateLoad(SumType, SumReduced, "sum.partial");
Value *XorPartial = Builder.CreateLoad(XorType, XorReduced, "xor.partial");
Value *Sum = Builder.CreateFAdd(SumPartial, SumLocal, "sum");
Value *Xor = Builder.CreateXor(XorPartial, TID, "xor");
Builder.CreateStore(Sum, SumReduced);
@ -3164,8 +3162,8 @@ TEST_F(OpenMPIRBuilderTest, CreateReductions) {
Builder.restoreIP(AfterIP);
OpenMPIRBuilder::ReductionInfo ReductionInfos[] = {
{SumReduced, SumPrivatized, sumReduction, sumAtomicReduction},
{XorReduced, XorPrivatized, xorReduction, xorAtomicReduction}};
{SumType, SumReduced, SumPrivatized, sumReduction, sumAtomicReduction},
{XorType, XorReduced, XorPrivatized, xorReduction, xorAtomicReduction}};
OMPBuilder.createReductions(BodyIP, BodyAllocaIP, ReductionInfos);
@ -3319,13 +3317,15 @@ TEST_F(OpenMPIRBuilderTest, CreateTwoReductions) {
// Create variables to be reduced.
InsertPointTy OuterAllocaIP(&F->getEntryBlock(),
F->getEntryBlock().getFirstInsertionPt());
Type *SumType = Builder.getFloatTy();
Type *XorType = Builder.getInt32Ty();
Value *SumReduced;
Value *XorReduced;
{
IRBuilderBase::InsertPointGuard Guard(Builder);
Builder.restoreIP(OuterAllocaIP);
SumReduced = Builder.CreateAlloca(Builder.getFloatTy());
XorReduced = Builder.CreateAlloca(Builder.getInt32Ty());
SumReduced = Builder.CreateAlloca(SumType);
XorReduced = Builder.CreateAlloca(XorType);
}
// Store initial values of reductions into global variables.
@ -3344,9 +3344,7 @@ TEST_F(OpenMPIRBuilderTest, CreateTwoReductions) {
Value *TID = OMPBuilder.getOrCreateThreadID(Ident);
Value *SumLocal =
Builder.CreateUIToFP(TID, Builder.getFloatTy(), "sum.local");
Value *SumPartial =
Builder.CreateLoad(SumReduced->getType()->getPointerElementType(),
SumReduced, "sum.partial");
Value *SumPartial = Builder.CreateLoad(SumType, SumReduced, "sum.partial");
Value *Sum = Builder.CreateFAdd(SumPartial, SumLocal, "sum");
Builder.CreateStore(Sum, SumReduced);
@ -3364,9 +3362,7 @@ TEST_F(OpenMPIRBuilderTest, CreateTwoReductions) {
Constant *SrcLocStr = OMPBuilder.getOrCreateSrcLocStr(Loc);
Value *Ident = OMPBuilder.getOrCreateIdent(SrcLocStr);
Value *TID = OMPBuilder.getOrCreateThreadID(Ident);
Value *XorPartial =
Builder.CreateLoad(XorReduced->getType()->getPointerElementType(),
XorReduced, "xor.partial");
Value *XorPartial = Builder.CreateLoad(XorType, XorReduced, "xor.partial");
Value *Xor = Builder.CreateXor(XorPartial, TID, "xor");
Builder.CreateStore(Xor, XorReduced);
@ -3421,10 +3417,10 @@ TEST_F(OpenMPIRBuilderTest, CreateTwoReductions) {
OMPBuilder.createReductions(
FirstBodyIP, FirstBodyAllocaIP,
{{SumReduced, SumPrivatized, sumReduction, sumAtomicReduction}});
{{SumType, SumReduced, SumPrivatized, sumReduction, sumAtomicReduction}});
OMPBuilder.createReductions(
SecondBodyIP, SecondBodyAllocaIP,
{{XorReduced, XorPrivatized, xorReduction, xorAtomicReduction}});
{{XorType, XorReduced, XorPrivatized, xorReduction, xorAtomicReduction}});
Builder.restoreIP(AfterIP);
Builder.CreateRetVoid();

View File

@ -415,7 +415,8 @@ using OwningReductionGen = std::function<llvm::OpenMPIRBuilder::InsertPointTy(
llvm::Value *&)>;
using OwningAtomicReductionGen =
std::function<llvm::OpenMPIRBuilder::InsertPointTy(
llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *)>;
llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
llvm::Value *)>;
} // namespace
/// Create an OpenMPIRBuilder-compatible reduction generator for the given
@ -462,7 +463,7 @@ makeAtomicReductionGen(omp::ReductionDeclareOp decl,
// (which aren't actually mutating it), and we must capture decl by-value to
// avoid the dangling reference after the parent function returns.
OwningAtomicReductionGen atomicGen =
[&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
[&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
llvm::Value *lhs, llvm::Value *rhs) mutable {
Region &atomicRegion = decl.atomicReductionRegion();
moduleTranslation.mapValue(atomicRegion.front().getArgument(0), lhs);
@ -763,9 +764,11 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder::AtomicReductionGenTy atomicGen = nullptr;
if (owningAtomicReductionGens[i])
atomicGen = owningAtomicReductionGens[i];
reductionInfos.push_back(
{moduleTranslation.lookupValue(loop.reduction_vars()[i]),
privateReductionVariables[i], owningReductionGens[i], atomicGen});
llvm::Value *variable =
moduleTranslation.lookupValue(loop.reduction_vars()[i]);
reductionInfos.push_back({variable->getType()->getPointerElementType(),
variable, privateReductionVariables[i],
owningReductionGens[i], atomicGen});
}
// The call to createReductions below expects the block to have a