diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index c63a7e5882ef..b40a45807d9c 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -90,11 +90,8 @@ private:
                            ConversionPatternRewriter &rewriter) const {
     auto type = operand->getType().cast<LLVM::LLVMType>();
 
-    // Reduce elements within each warp to produce the intermediate results.
-    Value *warpReduce = createWarpReduce(loc, operand, rewriter);
-
     // Create shared memory array to store the warp reduction.
-    auto module = warpReduce->getDefiningOp()->getParentOfType<ModuleOp>();
+    auto module = operand->getDefiningOp()->getParentOfType<ModuleOp>();
     assert(module && "op must belong to a module");
     Value *sharedMemPtr =
         createSharedMemoryArray(loc, module, type, kWarpSize, rewriter);
@@ -105,21 +102,24 @@ private:
     Value *isFirstLane = rewriter.create<LLVM::ICmpOp>(
         loc, LLVM::ICmpPredicate::eq, laneId, zero);
     Value *threadIdx = getLinearThreadIndex(loc, rewriter);
+    Value *blockSize = getBlockSize(loc, rewriter);
+    Value *activeWidth = getActiveWidth(loc, threadIdx, blockSize, rewriter);
+
+    // Reduce elements within each warp to produce the intermediate results.
+    Value *warpReduce =
+        createWarpReduce(loc, activeWidth, laneId, operand, rewriter);
 
     // Write the intermediate results to shared memory, using the first lane of
     // each warp.
-    createPredicatedBlock(
-        loc, isFirstLane,
-        [&] {
-          Value *warpId = getDivideByWarpSize(threadIdx, rewriter);
-          Value *storeDst = rewriter.create<LLVM::GEPOp>(
-              loc, type, sharedMemPtr, ArrayRef<Value *>({zero, warpId}));
-          rewriter.create<LLVM::StoreOp>(loc, warpReduce, storeDst);
-        },
-        rewriter);
-
+    createPredicatedBlock(loc, rewriter, isFirstLane, [&] {
+      Value *warpId = getDivideByWarpSize(threadIdx, rewriter);
+      Value *storeDst = rewriter.create<LLVM::GEPOp>(
+          loc, type, sharedMemPtr, ArrayRef<Value *>({zero, warpId}));
+      rewriter.create<LLVM::StoreOp>(loc, warpReduce, storeDst);
+    });
     rewriter.create<NVVM::Barrier0Op>(loc);
-    Value *numWarps = getNumWarps(loc, rewriter);
+
+    Value *numWarps = getNumWarps(loc, blockSize, rewriter);
     Value *isValidWarp = rewriter.create<LLVM::ICmpOp>(
         loc, LLVM::ICmpPredicate::slt, threadIdx, numWarps);
     Value *resultPtr = rewriter.create<LLVM::GEPOp>(
@@ -127,95 +127,156 @@ private:
 
     // Use the first numWarps threads to reduce the intermediate results from
     // shared memory. The final result is written to shared memory again.
-    createPredicatedBlock(
-        loc, isValidWarp,
-        [&] {
-          Value *loadSrc = rewriter.create<LLVM::GEPOp>(
-              loc, type, sharedMemPtr, ArrayRef<Value *>({zero, threadIdx}));
-          Value *value = rewriter.create<LLVM::LoadOp>(loc, type, loadSrc);
-          Value *result = createWarpReduce(loc, value, rewriter);
-          rewriter.create<LLVM::StoreOp>(loc, result, resultPtr);
-        },
-        rewriter);
-
+    createPredicatedBlock(loc, rewriter, isValidWarp, [&] {
+      Value *loadSrc = rewriter.create<LLVM::GEPOp>(
+          loc, type, sharedMemPtr, ArrayRef<Value *>({zero, threadIdx}));
+      Value *value = rewriter.create<LLVM::LoadOp>(loc, type, loadSrc);
+      Value *result = createWarpReduce(loc, numWarps, laneId, value, rewriter);
+      rewriter.create<LLVM::StoreOp>(loc, result, resultPtr);
+    });
     rewriter.create<NVVM::Barrier0Op>(loc);
-    Value *result = rewriter.create<LLVM::LoadOp>(loc, type, resultPtr);
+
+    // Load and return result from shared memory.
+    Value *result = rewriter.create<LLVM::LoadOp>(loc, type, resultPtr);
     return result;
   }
 
-  // Creates an if-block skeleton to perform conditional execution of the
-  // instructions generated by predicatedOpsFactory.
+  // Creates an if-block skeleton and calls the two factories to generate the
+  // ops in the `then` and `else` block.
   //
   //     llvm.cond_br %condition, ^then, ^continue
   //   ^then:
-  //     ... code created in `predicatedOpsFactory()`
-  //     llvm.br ^continue
-  //   ^continue:
+  //     %then_operands = `thenOpsFactory()`
+  //     llvm.br ^continue(%then_operands)
+  //   ^else:
+  //     %else_operands = `elseOpsFactory()`
+  //     llvm.br ^continue(%else_operands)
+  //   ^continue(%block_operands):
   //
-  template <typename Func>
-  void createPredicatedBlock(Location loc, Value *condition,
-                             Func &&predicatedOpsFactory,
-                             ConversionPatternRewriter &rewriter) const {
+  template <typename ThenOpsFactory, typename ElseOpsFactory>
+  void createIf(Location loc, ConversionPatternRewriter &rewriter,
+                Value *condition, ThenOpsFactory &&thenOpsFactory,
+                ElseOpsFactory &&elseOpsFactory) const {
     Block *currentBlock = rewriter.getInsertionBlock();
     auto currentPoint = rewriter.getInsertionPoint();
 
     Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
-    Block *continueBlock = rewriter.splitBlock(thenBlock, currentPoint);
+    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
+    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());
 
     rewriter.setInsertionPointToEnd(currentBlock);
-    rewriter.create<LLVM::CondBrOp>(
-        loc, llvm::makeArrayRef(condition),
-        ArrayRef<Block *>{thenBlock, continueBlock});
+    rewriter.create<LLVM::CondBrOp>(loc, llvm::makeArrayRef(condition),
+                                    ArrayRef<Block *>{thenBlock, elseBlock});
 
-    rewriter.setInsertionPointToEnd(thenBlock);
-    predicatedOpsFactory();
-    rewriter.create<LLVM::BrOp>(loc, ArrayRef<Value *>(),
-                                llvm::makeArrayRef(continueBlock));
+    auto addBranch = [&](ArrayRef<Value *> operands) {
+      rewriter.create<LLVM::BrOp>(loc, ArrayRef<Value *>{},
+                                  llvm::makeArrayRef(continueBlock),
+                                  llvm::makeArrayRef(operands));
+    };
+    rewriter.setInsertionPointToStart(thenBlock);
+    auto thenOperands = thenOpsFactory();
+    addBranch(thenOperands);
+
+    rewriter.setInsertionPointToStart(elseBlock);
+    auto elseOperands = elseOpsFactory();
+    addBranch(elseOperands);
+
+    assert(thenOperands.size() == elseOperands.size());
     rewriter.setInsertionPointToStart(continueBlock);
+    for (auto *operand : thenOperands)
+      continueBlock->addArgument(operand->getType());
   }
 
-  // Creates an all_reduce across the warp. Creates a preamble
-  //
-  //     %active_mask = llvm.mlir.constant(-1 : i32) : !llvm.i32
-  //     %mask_and_clamp = llvm.mlir.constant(31 : i32) : !llvm.i32
-  //
-  // plus the accumulation for i = 1, 2, 4, 8, 16:
-  //
-  //     %offset = llvm.mlir.constant(i : i32) : !llvm.i32
-  //     %value = nvvm.shfl.sync.bfly
-  //        %active_mask, %operand, %offset, %mask_and_clamp
-  //     %operand = llvm.fadd %operand, %value
-  //
-  // Each thread returns the same result.
-  //
-  // Note: this currently only supports reducing exactly 32 values.
-  Value *createWarpReduce(Location loc, Value *operand,
-                          ConversionPatternRewriter &rewriter) const {
-    // TODO(csigg): Generalize to partial warps and other types of accumulation.
-    static_assert(kWarpSize == 32, "Only warp size of 32 is supported.");
-    auto activeMask = rewriter.create<LLVM::ConstantOp>(
-        loc, int32Type, rewriter.getI32IntegerAttr(~0u));
-    auto maskAndClamp = rewriter.create<LLVM::ConstantOp>(
-        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
+  // Shortcut for createIf with empty else block and no block operands.
+  template <typename Factory>
+  void createPredicatedBlock(Location loc, ConversionPatternRewriter &rewriter,
+                             Value *condition,
+                             Factory &&predicatedOpsFactory) const {
+    createIf(
+        loc, rewriter, condition,
+        [&] {
+          predicatedOpsFactory();
+          return ArrayRef<Value *>();
+        },
+        [&] { return ArrayRef<Value *>(); });
+  }
 
-    auto resultType = operand->getType();
-    for (int i = 1; i < kWarpSize; i <<= 1) {
-      auto offset = rewriter.create<LLVM::ConstantOp>(
-          loc, int32Type, rewriter.getI32IntegerAttr(i));
-      auto value = rewriter.create<NVVM::ShflBflyOp>(
-          loc, resultType, activeMask, operand, offset, maskAndClamp);
-      operand = rewriter.create<LLVM::FAddOp>(loc, resultType, operand, value);
-    }
-    return operand;
+  // Creates a reduction across the first activeWidth lanes of a warp.
+  // The first lane returns the result; other lanes return undefined values.
+  Value *createWarpReduce(Location loc, Value *activeWidth, Value *laneId,
+                          Value *operand,
+                          ConversionPatternRewriter &rewriter) const {
+    // TODO(csigg): Generalize to other types of accumulation.
+    Value *warpSize = rewriter.create<LLVM::ConstantOp>(
+        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize));
+    Value *maskAndClamp = rewriter.create<LLVM::ConstantOp>(
+        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
+    Value *isPartialWarp = rewriter.create<LLVM::ICmpOp>(
+        loc, LLVM::ICmpPredicate::slt, activeWidth, warpSize);
+    auto type = operand->getType();
+
+    createIf(
+        loc, rewriter, isPartialWarp,
+        // Generate reduction over a (potentially) partial warp.
+        [&] {
+          Value *value = operand;
+          Value *one = rewriter.create<LLVM::ConstantOp>(
+              loc, int32Type, rewriter.getI32IntegerAttr(1));
+          // Bit mask of active lanes: `(1 << activeWidth) - 1`.
+          Value *activeMask = rewriter.create<LLVM::SubOp>(
+              loc, int32Type,
+              rewriter.create<LLVM::ShlOp>(loc, int32Type, one, activeWidth),
+              one);
+          // Bound of offsets which read from a lane within the active range.
+          Value *offsetBound =
+              rewriter.create<LLVM::SubOp>(loc, activeWidth, laneId);
+
+          // Repeatedly shuffle value from 'laneId + i' and accumulate if source
+          // lane is within the active range. The first lane contains the final
+          // result, all other lanes contain some undefined partial result.
+          for (int i = 1; i < kWarpSize; i <<= 1) {
+            Value *offset = rewriter.create<LLVM::ConstantOp>(
+                loc, int32Type, rewriter.getI32IntegerAttr(i));
+            // ShflDownOp instead of ShflBflyOp would produce a scan. ShflBflyOp
+            // also produces the correct reduction on lane 0 though.
+            Value *shfl = rewriter.create<NVVM::ShflBflyOp>(
+                loc, type, activeMask, value, offset, maskAndClamp);
+            // TODO(csigg): use the second result from the shuffle op instead.
+            Value *isActiveSrcLane = rewriter.create<LLVM::ICmpOp>(
+                loc, LLVM::ICmpPredicate::slt, offset, offsetBound);
+            // Skip the accumulation if the shuffle op read from a lane outside
+            // of the active range.
+            createIf(
+                loc, rewriter, isActiveSrcLane,
+                [&] {
+                  return llvm::SmallVector<Value *, 1>{
+                      rewriter.create<LLVM::FAddOp>(loc, type, value, shfl)};
+                },
+                [&] { return llvm::makeArrayRef(value); });
+            value = rewriter.getInsertionBlock()->getArgument(0);
+          }
+          return llvm::SmallVector<Value *, 1>{value};
+        },
+        // Generate a reduction over the entire warp. This is a specialization
+        // of the above reduction with unconditional accumulation.
+        [&] {
+          Value *value = operand;
+          Value *activeMask = rewriter.create<LLVM::ConstantOp>(
+              loc, int32Type, rewriter.getI32IntegerAttr(~0u));
+          for (int i = 1; i < kWarpSize; i <<= 1) {
+            Value *offset = rewriter.create<LLVM::ConstantOp>(
+                loc, int32Type, rewriter.getI32IntegerAttr(i));
+            Value *shfl = rewriter.create<NVVM::ShflBflyOp>(
+                loc, type, activeMask, value, offset, maskAndClamp);
+            value = rewriter.create<LLVM::FAddOp>(loc, type, value, shfl);
+          }
+          return llvm::SmallVector<Value *, 1>{value};
+        });
+    return rewriter.getInsertionBlock()->getArgument(0);
   }
 
   // Creates a global array stored in shared memory.
-  //
-  //     llvm.mlir.global @reduce_buffer()
-  //         {addr_space = 3 : i32} : !llvm<"[32 x float]">
-  //
   Value *createSharedMemoryArray(Location loc, ModuleOp module,
                                  LLVM::LLVMType elementType, int numElements,
                                  ConversionPatternRewriter &rewriter) const {
@@ -247,16 +308,6 @@ private:
     return rewriter.create<LLVM::AddOp>(loc, int32Type, tmp3, idX);
   }
 
-  // Returns the number of warps in the block.
-  Value *getNumWarps(Location loc, ConversionPatternRewriter &rewriter) const {
-    auto blockSize = getBlockSize(loc, rewriter);
-    auto warpSizeMinusOne = rewriter.create<LLVM::ConstantOp>(
-        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
-    auto biasedBlockSize = rewriter.create<LLVM::AddOp>(
-        loc, int32Type, blockSize, warpSizeMinusOne);
-    return getDivideByWarpSize(biasedBlockSize, rewriter);
-  }
-
   // Returns the number of threads in the block.
   Value *getBlockSize(Location loc, ConversionPatternRewriter &rewriter) const {
     Value *dimX = rewriter.create<NVVM::BlockDimXOp>(loc, int32Type);
@@ -266,6 +317,27 @@ private:
     return rewriter.create<LLVM::MulOp>(loc, int32Type, dimXY, dimZ);
   }
 
+  // Returns the number of warps in the block.
+  Value *getNumWarps(Location loc, Value *blockSize,
+                     ConversionPatternRewriter &rewriter) const {
+    auto warpSizeMinusOne = rewriter.create<LLVM::ConstantOp>(
+        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
+    auto biasedBlockSize = rewriter.create<LLVM::AddOp>(
+        loc, int32Type, blockSize, warpSizeMinusOne);
+    return getDivideByWarpSize(biasedBlockSize, rewriter);
+  }
+
+  // Returns the number of active threads in the warp, not clamped to 32.
+  Value *getActiveWidth(Location loc, Value *threadIdx, Value *blockSize,
+                        ConversionPatternRewriter &rewriter) const {
+    Value *threadIdxMask = rewriter.create<LLVM::ConstantOp>(
+        loc, int32Type, rewriter.getI32IntegerAttr(~(kWarpSize - 1)));
+    Value *numThreadsWithSmallerWarpId =
+        rewriter.create<LLVM::AndOp>(loc, threadIdx, threadIdxMask);
+    return rewriter.create<LLVM::SubOp>(loc, blockSize,
+                                        numThreadsWithSmallerWarpId);
+  }
+
   // Returns value divided by the warp size (i.e. 32).
   Value *getDivideByWarpSize(Value *value,
                              ConversionPatternRewriter &rewriter) const {
@@ -277,7 +349,6 @@ private:
 
   LLVM::LLVMType int32Type;
 
-  // TODO(csigg): Support other warp sizes.
   static constexpr int kWarpSize = 32;
 };
 
diff --git a/mlir/test/mlir-cuda-runner/all-reduce.mlir b/mlir/test/mlir-cuda-runner/all-reduce.mlir
index d607870cc7a3..784447cbbb16 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce.mlir
@@ -1,9 +1,9 @@
 // RUN: mlir-cuda-runner %s --shared-libs=%cuda_wrapper_library_dir/libcuda-runtime-wrappers%shlibext --entry-point-result=void | FileCheck %s
 
-// CHECK: [8.128000e+03, 8.128000e+03, {{.*}}, 8.128000e+03, 8.128000e+03]
+// CHECK: [5.356000e+03, 5.356000e+03, {{.*}}, 5.356000e+03, 5.356000e+03]
 func @main() {
-  %arg = alloc() : memref<16x4x2xf32>
-  %dst = memref_cast %arg : memref<16x4x2xf32> to memref<?x?x?xf32>
+  %arg = alloc() : memref<13x4x2xf32>
+  %dst = memref_cast %arg : memref<13x4x2xf32> to memref<?x?x?xf32>
   %zero = constant 0 : i32
   %one = constant 1 : index
   %sx = dim %dst, 0 : memref<?x?x?xf32>
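
Note (illustration, not part of the patch): the `isPartialWarp` branch above emits a butterfly-shuffle reduction whose accumulation is predicated on the source lane lying inside the active range. The following CUDA sketch mirrors that emitted data flow; the names `warpReduceSum`, `laneId`, and `activeWidth` are hypothetical, and it assumes `0 < activeWidth < 32` as guaranteed by the partial-warp branch.

```cuda
// Sketch of the reduction emitted for a partial warp; only illustrative.
__device__ float warpReduceSum(float value, int laneId, int activeWidth) {
  unsigned activeMask = (1u << activeWidth) - 1u;  // bit mask of active lanes
  int offsetBound = activeWidth - laneId;          // offsets that stay in range
  for (int offset = 1; offset < 32; offset <<= 1) {
    // Counterpart of nvvm.shfl.sync.bfly with mask_and_clamp = 31.
    float shfl = __shfl_xor_sync(activeMask, value, offset, 32);
    if (offset < offsetBound)  // skip values shuffled in from inactive lanes
      value += shfl;
  }
  return value;  // only lane 0 is guaranteed to hold the full sum
}
```

The full-warp branch of the patch is the same loop with `activeMask = ~0u` and the accumulation performed unconditionally, so every lane of a complete warp ends up with the reduced value.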