From d90b7bf2c53d0315a13a81904862929252bb6824 Mon Sep 17 00:00:00 2001 From: Dominik Adamski Date: Thu, 28 Jul 2022 03:57:40 -0500 Subject: [PATCH] Add support for lowering simd if clause to LLVM IR Scope of changes: 1) Added new function to generate loop versioning 2) Added support for if clause to applySimd function 2) Added tests which confirm that lowering is successful If ifCond is specified, then collapsed loop is duplicated and if branch is added. Duplicated loop is executed if simd ifCond is evaluated to false. Reviewed By: Meinersbur Differential Revision: https://reviews.llvm.org/D129368 Signed-off-by: Dominik Adamski --- clang/lib/CodeGen/CGStmtOpenMP.cpp | 4 +- .../llvm/Frontend/OpenMP/OMPIRBuilder.h | 23 +++- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 130 +++++++++++++++--- .../Frontend/OpenMPIRBuilderTest.cpp | 53 ++++++- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 10 +- mlir/test/Target/LLVMIR/openmp-llvm.mlir | 28 ++++ 6 files changed, 218 insertions(+), 30 deletions(-) diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp index aa55cdaca5dc..962620f43a39 100644 --- a/clang/lib/CodeGen/CGStmtOpenMP.cpp +++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp @@ -2646,7 +2646,9 @@ void CodeGenFunction::EmitOMPSimdDirective(const OMPSimdDirective &S) { auto *Val = cast(Len.getScalarVal()); Simdlen = Val; } - OMPBuilder.applySimd(CLI, Simdlen); + // Add simd metadata to the collapsed loop. Do not generate + // another loop for if clause. Support for if clause is done earlier. + OMPBuilder.applySimd(CLI, /*IfCond*/ nullptr, Simdlen); return; } }; diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index 40ca2da4c911..5ae9baab0e5d 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -14,6 +14,7 @@ #ifndef LLVM_FRONTEND_OPENMP_OMPIRBUILDER_H #define LLVM_FRONTEND_OPENMP_OMPIRBUILDER_H +#include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" #include "llvm/IR/DebugLoc.h" #include "llvm/IR/IRBuilder.h" @@ -467,6 +468,20 @@ private: bool NeedsBarrier, Value *Chunk = nullptr); + /// Create alternative version of the loop to support if clause + /// + /// OpenMP if clause can require to generate second loop. This loop + /// will be executed when if clause condition is not met. createIfVersion + /// adds branch instruction to the copied loop if \p ifCond is not met. + /// + /// \param Loop Original loop which should be versioned. + /// \param IfCond Value which corresponds to if clause condition + /// \param VMap Value to value map to define relation between + /// original and copied loop values and loop blocks. + /// \param NamePrefix Optional name prefix for if.then if.else blocks. + void createIfVersion(CanonicalLoopInfo *Loop, Value *IfCond, + ValueToValueMapTy &VMap, const Twine &NamePrefix = ""); + public: /// Modifies the canonical loop to be a workshare loop. /// @@ -597,11 +612,15 @@ public: void unrollLoopPartial(DebugLoc DL, CanonicalLoopInfo *Loop, int32_t Factor, CanonicalLoopInfo **UnrolledCLI); - /// Add metadata to simd-ize a loop. + /// Add metadata to simd-ize a loop. If IfCond is not nullptr, the loop + /// is cloned. The metadata which prevents vectorization is added to + /// to the cloned loop. The cloned loop is executed when ifCond is evaluated + /// to false. /// /// \param Loop The loop to simd-ize. + /// \param IfCond The value which corresponds to the if clause condition. /// \param Simdlen The Simdlen length to apply to the simd loop. - void applySimd(CanonicalLoopInfo *Loop, ConstantInt *Simdlen); + void applySimd(CanonicalLoopInfo *Loop, Value *IfCond, ConstantInt *Simdlen); /// Generator for '#omp flush' /// diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index cee4cddab5e8..736976d40643 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -34,6 +34,7 @@ #include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetOptions.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/CodeExtractor.h" #include "llvm/Transforms/Utils/LoopPeel.h" #include "llvm/Transforms/Utils/UnrollLoop.h" @@ -2839,32 +2840,40 @@ OpenMPIRBuilder::tileLoops(DebugLoc DL, ArrayRef Loops, return Result; } +/// Attach metadata \p Properties to the basic block described by \p BB. If the +/// basic block already has metadata, the basic block properties are appended. +static void addBasicBlockMetadata(BasicBlock *BB, + ArrayRef Properties) { + // Nothing to do if no property to attach. + if (Properties.empty()) + return; + + LLVMContext &Ctx = BB->getContext(); + SmallVector NewProperties; + NewProperties.push_back(nullptr); + + // If the basic block already has metadata, prepend it to the new metadata. + MDNode *Existing = BB->getTerminator()->getMetadata(LLVMContext::MD_loop); + if (Existing) + append_range(NewProperties, drop_begin(Existing->operands(), 1)); + + append_range(NewProperties, Properties); + MDNode *BasicBlockID = MDNode::getDistinct(Ctx, NewProperties); + BasicBlockID->replaceOperandWith(0, BasicBlockID); + + BB->getTerminator()->setMetadata(LLVMContext::MD_loop, BasicBlockID); +} + /// Attach loop metadata \p Properties to the loop described by \p Loop. If the /// loop already has metadata, the loop properties are appended. static void addLoopMetadata(CanonicalLoopInfo *Loop, ArrayRef Properties) { assert(Loop->isValid() && "Expecting a valid CanonicalLoopInfo"); - // Nothing to do if no property to attach. - if (Properties.empty()) - return; - - LLVMContext &Ctx = Loop->getFunction()->getContext(); - SmallVector NewLoopProperties; - NewLoopProperties.push_back(nullptr); - - // If the loop already has metadata, prepend it to the new metadata. + // Attach metadata to the loop's latch BasicBlock *Latch = Loop->getLatch(); assert(Latch && "A valid CanonicalLoopInfo must have a unique latch"); - MDNode *Existing = Latch->getTerminator()->getMetadata(LLVMContext::MD_loop); - if (Existing) - append_range(NewLoopProperties, drop_begin(Existing->operands(), 1)); - - append_range(NewLoopProperties, Properties); - MDNode *LoopID = MDNode::getDistinct(Ctx, NewLoopProperties); - LoopID->replaceOperandWith(0, LoopID); - - Latch->getTerminator()->setMetadata(LLVMContext::MD_loop, LoopID); + addBasicBlockMetadata(Latch, Properties); } /// Attach llvm.access.group metadata to the memref instructions of \p Block @@ -2895,12 +2904,77 @@ void OpenMPIRBuilder::unrollLoopHeuristic(DebugLoc, CanonicalLoopInfo *Loop) { }); } -void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop, +void OpenMPIRBuilder::createIfVersion(CanonicalLoopInfo *CanonicalLoop, + Value *IfCond, ValueToValueMapTy &VMap, + const Twine &NamePrefix) { + Function *F = CanonicalLoop->getFunction(); + + // Define where if branch should be inserted + Instruction *SplitBefore; + if (Instruction::classof(IfCond)) { + SplitBefore = dyn_cast(IfCond); + } else { + SplitBefore = CanonicalLoop->getPreheader()->getTerminator(); + } + + // TODO: We should not rely on pass manager. Currently we use pass manager + // only for getting llvm::Loop which corresponds to given CanonicalLoopInfo + // object. We should have a method which returns all blocks between + // CanonicalLoopInfo::getHeader() and CanonicalLoopInfo::getAfter() + FunctionAnalysisManager FAM; + FAM.registerPass([]() { return DominatorTreeAnalysis(); }); + FAM.registerPass([]() { return LoopAnalysis(); }); + FAM.registerPass([]() { return PassInstrumentationAnalysis(); }); + + // Get the loop which needs to be cloned + LoopAnalysis LIA; + LoopInfo &&LI = LIA.run(*F, FAM); + Loop *L = LI.getLoopFor(CanonicalLoop->getHeader()); + + // Create additional blocks for the if statement + BasicBlock *Head = SplitBefore->getParent(); + Instruction *HeadOldTerm = Head->getTerminator(); + llvm::LLVMContext &C = Head->getContext(); + llvm::BasicBlock *ThenBlock = llvm::BasicBlock::Create( + C, NamePrefix + ".if.then", Head->getParent(), Head->getNextNode()); + llvm::BasicBlock *ElseBlock = llvm::BasicBlock::Create( + C, NamePrefix + ".if.else", Head->getParent(), CanonicalLoop->getExit()); + + // Create if condition branch. + Builder.SetInsertPoint(HeadOldTerm); + Instruction *BrInstr = + Builder.CreateCondBr(IfCond, ThenBlock, /*ifFalse*/ ElseBlock); + InsertPointTy IP{BrInstr->getParent(), ++BrInstr->getIterator()}; + // Then block contains branch to omp loop which needs to be vectorized + spliceBB(IP, ThenBlock, false); + ThenBlock->replaceSuccessorsPhiUsesWith(Head, ThenBlock); + + Builder.SetInsertPoint(ElseBlock); + + // Clone loop for the else branch + SmallVector NewBlocks; + + VMap[CanonicalLoop->getPreheader()] = ElseBlock; + for (BasicBlock *Block : L->getBlocks()) { + BasicBlock *NewBB = CloneBasicBlock(Block, VMap, "", F); + NewBB->moveBefore(CanonicalLoop->getExit()); + VMap[Block] = NewBB; + NewBlocks.push_back(NewBB); + } + remapInstructionsInBlocks(NewBlocks, VMap); + Builder.CreateBr(NewBlocks.front()); +} + +void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop, Value *IfCond, ConstantInt *Simdlen) { LLVMContext &Ctx = Builder.getContext(); Function *F = CanonicalLoop->getFunction(); + // TODO: We should not rely on pass manager. Currently we use pass manager + // only for getting llvm::Loop which corresponds to given CanonicalLoopInfo + // object. We should have a method which returns all blocks between + // CanonicalLoopInfo::getHeader() and CanonicalLoopInfo::getAfter() FunctionAnalysisManager FAM; FAM.registerPass([]() { return DominatorTreeAnalysis(); }); FAM.registerPass([]() { return LoopAnalysis(); }); @@ -2911,6 +2985,24 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop, Loop *L = LI.getLoopFor(CanonicalLoop->getHeader()); + if (IfCond) { + ValueToValueMapTy VMap; + createIfVersion(CanonicalLoop, IfCond, VMap, "simd"); + // Add metadata to the cloned loop which disables vectorization + Value *MappedLatch = VMap.lookup(CanonicalLoop->getLatch()); + assert(MappedLatch && + "Cannot find value which corresponds to original loop latch"); + assert(isa(MappedLatch) && + "Cannot cast mapped latch block value to BasicBlock"); + BasicBlock *NewLatchBlock = dyn_cast(MappedLatch); + ConstantAsMetadata *BoolConst = + ConstantAsMetadata::get(ConstantInt::getFalse(Type::getInt1Ty(Ctx))); + addBasicBlockMetadata( + NewLatchBlock, + {MDNode::get(Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.enable"), + BoolConst})}); + } + SmallSet Reachable; // Get the basic blocks from the loop in which memref instructions diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index 55afea6e89d5..7e3b5481e7bd 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -1771,7 +1771,7 @@ TEST_F(OpenMPIRBuilderTest, ApplySimd) { CanonicalLoopInfo *CLI = buildSingleLoopFunction(DL, OMPBuilder, 32); // Simd-ize the loop. - OMPBuilder.applySimd(CLI, nullptr); + OMPBuilder.applySimd(CLI, /* IfCond */ nullptr, /* Simdlen */ nullptr); OMPBuilder.finalize(); EXPECT_FALSE(verifyModule(*M, &errs())); @@ -1802,7 +1802,8 @@ TEST_F(OpenMPIRBuilderTest, ApplySimdlen) { CanonicalLoopInfo *CLI = buildSingleLoopFunction(DL, OMPBuilder, 32); // Simd-ize the loop. - OMPBuilder.applySimd(CLI, ConstantInt::get(Type::getInt32Ty(Ctx), 3)); + OMPBuilder.applySimd(CLI, /*IfCond */ nullptr, + ConstantInt::get(Type::getInt32Ty(Ctx), 3)); OMPBuilder.finalize(); EXPECT_FALSE(verifyModule(*M, &errs())); @@ -1828,6 +1829,54 @@ TEST_F(OpenMPIRBuilderTest, ApplySimdlen) { })); } +TEST_F(OpenMPIRBuilderTest, ApplySimdLoopIf) { + OpenMPIRBuilder OMPBuilder(*M); + IRBuilder<> Builder(BB); + AllocaInst *Alloc1 = Builder.CreateAlloca(Builder.getInt32Ty()); + AllocaInst *Alloc2 = Builder.CreateAlloca(Builder.getInt32Ty()); + + // Generation of if condition + Builder.CreateStore(ConstantInt::get(Type::getInt32Ty(Ctx), 0U), Alloc1); + Builder.CreateStore(ConstantInt::get(Type::getInt32Ty(Ctx), 1U), Alloc2); + LoadInst *Load1 = Builder.CreateLoad(Alloc1->getAllocatedType(), Alloc1); + LoadInst *Load2 = Builder.CreateLoad(Alloc2->getAllocatedType(), Alloc2); + + Value *IfCmp = Builder.CreateICmpNE(Load1, Load2); + + CanonicalLoopInfo *CLI = buildSingleLoopFunction(DL, OMPBuilder, 32); + + // Simd-ize the loop with if condition + OMPBuilder.applySimd(CLI, IfCmp, ConstantInt::get(Type::getInt32Ty(Ctx), 3)); + + OMPBuilder.finalize(); + EXPECT_FALSE(verifyModule(*M, &errs())); + + PassBuilder PB; + FunctionAnalysisManager FAM; + PB.registerFunctionAnalyses(FAM); + LoopInfo &LI = FAM.getResult(*F); + + // Check if there are two loops (one with enabled vectorization) + const std::vector &TopLvl = LI.getTopLevelLoops(); + EXPECT_EQ(TopLvl.size(), 2u); + + Loop *L = TopLvl[0]; + EXPECT_TRUE(findStringMetadataForLoop(L, "llvm.loop.parallel_accesses")); + EXPECT_TRUE(getBooleanLoopAttribute(L, "llvm.loop.vectorize.enable")); + EXPECT_EQ(getIntLoopAttribute(L, "llvm.loop.vectorize.width"), 3); + + // The second loop should have disabled vectorization + L = TopLvl[1]; + EXPECT_FALSE(findStringMetadataForLoop(L, "llvm.loop.parallel_accesses")); + EXPECT_FALSE(getBooleanLoopAttribute(L, "llvm.loop.vectorize.enable")); + // Check for llvm.access.group metadata attached to the printf + // function in the loop body. + BasicBlock *LoopBody = CLI->getBody(); + EXPECT_TRUE(any_of(*LoopBody, [](Instruction &I) { + return I.getMetadata("llvm.access.group") != nullptr; + })); +} + TEST_F(OpenMPIRBuilderTest, UnrollLoopFull) { OpenMPIRBuilder OMPBuilder(*M); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 34c532311fa4..85ec47aae400 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -912,11 +912,6 @@ convertOmpSimdLoop(Operation &opInst, llvm::IRBuilderBase &builder, SmallVector loopInfos; SmallVector bodyInsertPoints; LogicalResult bodyGenStatus = success(); - - // TODO: The code generation for if clause is not supported yet. - if (loop.if_expr()) - return failure(); - auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) { // Make sure further conversions know about the induction variable. moduleTranslation.mapValue( @@ -975,7 +970,10 @@ convertOmpSimdLoop(Operation &opInst, llvm::IRBuilderBase &builder, if (llvm::Optional simdlenVar = loop.simdlen()) simdlen = builder.getInt64(simdlenVar.value()); - ompBuilder->applySimd(loopInfo, simdlen); + ompBuilder->applySimd( + loopInfo, + loop.if_expr() ? moduleTranslation.lookupValue(loop.if_expr()) : nullptr, + simdlen); builder.restoreIP(afterIP); return success(); diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir index 4b36ed03b549..a3e6aa2d9b8b 100644 --- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir @@ -750,6 +750,34 @@ llvm.func @simdloop_simple_multiple_simdlen(%lb1 : i64, %ub1 : i64, %step1 : i64 // ----- +// CHECK-LABEL: @simdloop_if +llvm.func @simdloop_if(%arg0: !llvm.ptr {fir.bindc_name = "n"}, %arg1: !llvm.ptr {fir.bindc_name = "threshold"}) { + %0 = llvm.mlir.constant(1 : i64) : i64 + %1 = llvm.alloca %0 x i32 {adapt.valuebyref, in_type = i32, operand_segment_sizes = dense<0> : vector<2xi32>} : (i64) -> !llvm.ptr + %2 = llvm.mlir.constant(1 : i64) : i64 + %3 = llvm.alloca %2 x i32 {bindc_name = "i", in_type = i32, operand_segment_sizes = dense<0> : vector<2xi32>, uniq_name = "_QFtest_simdEi"} : (i64) -> !llvm.ptr + %4 = llvm.mlir.constant(0 : i32) : i32 + %5 = llvm.load %arg0 : !llvm.ptr + %6 = llvm.mlir.constant(1 : i32) : i32 + %7 = llvm.load %arg0 : !llvm.ptr + %8 = llvm.load %arg1 : !llvm.ptr + %9 = llvm.icmp "sge" %7, %8 : i32 + omp.simdloop if(%9) for (%arg2) : i32 = (%4) to (%5) inclusive step (%6) { + // The form of the emitted IR is controlled by OpenMPIRBuilder and + // tested there. Just check that the right metadata is added. + // CHECK: llvm.access.group + llvm.store %arg2, %1 : !llvm.ptr + omp.yield + } + llvm.return +} +// Be sure that llvm.loop.vectorize.enable metadata appears twice +// CHECK: llvm.loop.parallel_accesses +// CHECK-NEXT: llvm.loop.vectorize.enable +// CHECK: llvm.loop.vectorize.enable + +// ----- + llvm.func @body(i64) llvm.func @test_omp_wsloop_ordered(%lb : i64, %ub : i64, %step : i64) -> () {