Support reduction of partial warps.

gpu.all_reduce now supports block sizes that are not a multiple of 32.

PiperOrigin-RevId: 273255204
Commit: 7c765d97f9 (parent: 77672c9777)
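The change generalizes the NVVM lowering of gpu.all_reduce so that the last warp of a block may be only partially populated. For illustration only (not part of the commit), here is a small host-side C++ sketch of the per-thread integer arithmetic the new helpers emit; the names mirror getActiveWidth, getNumWarps and kWarpSize from the diff below.

// Host-side sketch of the per-thread arithmetic the new lowering emits.
// Names mirror the helpers added in the diff; illustration only.
#include <cassert>
#include <cstdio>

constexpr int kWarpSize = 32;

// Mirrors getActiveWidth: block size minus the number of threads that live in
// warps with a smaller warp id. Not clamped to 32 for threads in full warps.
int activeWidth(int threadIdx, int blockSize) {
  int numThreadsWithSmallerWarpId = threadIdx & ~(kWarpSize - 1);
  return blockSize - numThreadsWithSmallerWarpId;
}

// Mirrors getNumWarps: block size divided by the warp size, rounded up.
int numWarps(int blockSize) { return (blockSize + kWarpSize - 1) / kWarpSize; }

int main() {
  const int blockSize = 13 * 4 * 2; // 104 threads, as in the updated test.
  assert(numWarps(blockSize) == 4);
  assert(activeWidth(/*threadIdx=*/31, blockSize) == 104); // full warp
  assert(activeWidth(/*threadIdx=*/96, blockSize) == 8);   // partial warp
  // Mask of active lanes used by the partial-warp path: (1 << activeWidth) - 1.
  unsigned activeMask = (1u << activeWidth(96, blockSize)) - 1;
  std::printf("active mask of the last warp: %#x\n", activeMask); // 0xff
  return 0;
}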
@@ -90,11 +90,8 @@ private:
                            ConversionPatternRewriter &rewriter) const {
     auto type = operand->getType().cast<LLVM::LLVMType>();
 
-    // Reduce elements within each warp to produce the intermediate results.
-    Value *warpReduce = createWarpReduce(loc, operand, rewriter);
-
     // Create shared memory array to store the warp reduction.
-    auto module = warpReduce->getDefiningOp()->getParentOfType<ModuleOp>();
+    auto module = operand->getDefiningOp()->getParentOfType<ModuleOp>();
     assert(module && "op must belong to a module");
     Value *sharedMemPtr =
         createSharedMemoryArray(loc, module, type, kWarpSize, rewriter);
@@ -105,21 +102,24 @@ private:
     Value *isFirstLane = rewriter.create<LLVM::ICmpOp>(
         loc, LLVM::ICmpPredicate::eq, laneId, zero);
     Value *threadIdx = getLinearThreadIndex(loc, rewriter);
+    Value *blockSize = getBlockSize(loc, rewriter);
+    Value *activeWidth = getActiveWidth(loc, threadIdx, blockSize, rewriter);
+
+    // Reduce elements within each warp to produce the intermediate results.
+    Value *warpReduce =
+        createWarpReduce(loc, activeWidth, laneId, operand, rewriter);
 
     // Write the intermediate results to shared memory, using the first lane of
     // each warp.
-    createPredicatedBlock(
-        loc, isFirstLane,
-        [&] {
-          Value *warpId = getDivideByWarpSize(threadIdx, rewriter);
-          Value *storeDst = rewriter.create<LLVM::GEPOp>(
-              loc, type, sharedMemPtr, ArrayRef<Value *>({zero, warpId}));
-          rewriter.create<LLVM::StoreOp>(loc, warpReduce, storeDst);
-        },
-        rewriter);
+    createPredicatedBlock(loc, rewriter, isFirstLane, [&] {
+      Value *warpId = getDivideByWarpSize(threadIdx, rewriter);
+      Value *storeDst = rewriter.create<LLVM::GEPOp>(
+          loc, type, sharedMemPtr, ArrayRef<Value *>({zero, warpId}));
+      rewriter.create<LLVM::StoreOp>(loc, warpReduce, storeDst);
+    });
     rewriter.create<NVVM::Barrier0Op>(loc);
 
-    Value *numWarps = getNumWarps(loc, rewriter);
+    Value *numWarps = getNumWarps(loc, blockSize, rewriter);
     Value *isValidWarp = rewriter.create<LLVM::ICmpOp>(
         loc, LLVM::ICmpPredicate::slt, threadIdx, numWarps);
     Value *resultPtr = rewriter.create<LLVM::GEPOp>(
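The hunk above, together with the one below, wires up a two-stage schedule: every (possibly partial) warp reduces its own lanes, lane 0 of each warp stores the partial sum into the shared-memory buffer, and after the barrier the first numWarps threads reduce those partials. A host-side C++ sketch of that schedule follows (illustration only, not part of the commit; the 104-thread block matches the updated test further down).

// Host-side sketch of the two-stage block reduction the rewritten kernel
// performs for a 104-thread block (illustration only).
#include <cstdio>
#include <vector>

constexpr int kWarpSize = 32;

int main() {
  const int blockSize = 104; // 13 * 4 * 2, as in the updated test.
  std::vector<float> input(blockSize);
  for (int i = 0; i < blockSize; ++i)
    input[i] = static_cast<float>(i);

  // Stage 1: each (possibly partial) warp reduces its own lanes; the first
  // lane of every warp stores the partial sum into the shared buffer.
  const int warps = (blockSize + kWarpSize - 1) / kWarpSize; // getNumWarps
  std::vector<float> shared(warps, 0.0f);
  for (int warp = 0; warp < warps; ++warp)
    for (int lane = 0; lane < kWarpSize; ++lane)
      if (warp * kWarpSize + lane < blockSize)
        shared[warp] += input[warp * kWarpSize + lane];

  // (nvvm.barrier0 in the real kernel.)

  // Stage 2: the first `warps` threads reduce the partial sums; the lowering
  // does this with a second createWarpReduce call whose active width is
  // numWarps.
  float result = 0.0f;
  for (int warp = 0; warp < warps; ++warp)
    result += shared[warp];

  std::printf("%g\n", result); // 5356, the value the new CHECK line expects.
  return 0;
}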
@@ -127,95 +127,156 @@ private:
 
     // Use the first numWarps threads to reduce the intermediate results from
     // shared memory. The final result is written to shared memory again.
-    createPredicatedBlock(
-        loc, isValidWarp,
-        [&] {
-          Value *loadSrc = rewriter.create<LLVM::GEPOp>(
-              loc, type, sharedMemPtr, ArrayRef<Value *>({zero, threadIdx}));
-          Value *value = rewriter.create<LLVM::LoadOp>(loc, type, loadSrc);
-          Value *result = createWarpReduce(loc, value, rewriter);
-          rewriter.create<LLVM::StoreOp>(loc, result, resultPtr);
-        },
-        rewriter);
+    createPredicatedBlock(loc, rewriter, isValidWarp, [&] {
+      Value *loadSrc = rewriter.create<LLVM::GEPOp>(
+          loc, type, sharedMemPtr, ArrayRef<Value *>({zero, threadIdx}));
+      Value *value = rewriter.create<LLVM::LoadOp>(loc, type, loadSrc);
+      Value *result = createWarpReduce(loc, numWarps, laneId, value, rewriter);
+      rewriter.create<LLVM::StoreOp>(loc, result, resultPtr);
+    });
     rewriter.create<NVVM::Barrier0Op>(loc);
-    Value *result = rewriter.create<LLVM::LoadOp>(loc, type, resultPtr);
 
+    // Load and return result from shared memory.
+    Value *result = rewriter.create<LLVM::LoadOp>(loc, type, resultPtr);
     return result;
   }
 
-  // Creates an if-block skeleton to perform conditional execution of the
-  // instructions generated by predicatedOpsFactory.
+  // Creates an if-block skeleton and calls the two factories to generate the
+  // ops in the `then` and `else` block..
   //
   // llvm.cond_br %condition, ^then, ^continue
   // ^then:
-  //   ... code created in `predicatedOpsFactory()`
-  //   llvm.br ^continue
-  // ^continue:
+  //   %then_operands = `thenOpsFactory()`
+  //   llvm.br ^continue(%then_operands)
+  // ^else:
+  //   %else_operands = `elseOpsFactory()`
+  //   llvm.br ^continue(%else_operands)
+  // ^continue(%block_operands):
   //
-  template <typename Func>
-  void createPredicatedBlock(Location loc, Value *condition,
-                             Func &&predicatedOpsFactory,
-                             ConversionPatternRewriter &rewriter) const {
+  template <typename ThenOpsFactory, typename ElseOpsFactory>
+  void createIf(Location loc, ConversionPatternRewriter &rewriter,
+                Value *condition, ThenOpsFactory &&thenOpsFactory,
+                ElseOpsFactory &&elseOpsFactory) const {
     Block *currentBlock = rewriter.getInsertionBlock();
     auto currentPoint = rewriter.getInsertionPoint();
 
     Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
-    Block *continueBlock = rewriter.splitBlock(thenBlock, currentPoint);
+    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
+    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());
 
     rewriter.setInsertionPointToEnd(currentBlock);
-    rewriter.create<LLVM::CondBrOp>(
-        loc, llvm::makeArrayRef(condition),
-        ArrayRef<Block *>{thenBlock, continueBlock});
+    rewriter.create<LLVM::CondBrOp>(loc, llvm::makeArrayRef(condition),
+                                    ArrayRef<Block *>{thenBlock, elseBlock});
 
-    rewriter.setInsertionPointToEnd(thenBlock);
-    predicatedOpsFactory();
-    rewriter.create<LLVM::BrOp>(loc, ArrayRef<Value *>(),
-                                llvm::makeArrayRef(continueBlock));
+    auto addBranch = [&](ArrayRef<Value *> operands) {
+      rewriter.create<LLVM::BrOp>(loc, ArrayRef<Value *>{},
+                                  llvm::makeArrayRef(continueBlock),
+                                  llvm::makeArrayRef(operands));
+    };
+
+    rewriter.setInsertionPointToStart(thenBlock);
+    auto thenOperands = thenOpsFactory();
+    addBranch(thenOperands);
+
+    rewriter.setInsertionPointToStart(elseBlock);
+    auto elseOperands = elseOpsFactory();
+    addBranch(elseOperands);
+
+    assert(thenOperands.size() == elseOperands.size());
+    rewriter.setInsertionPointToStart(continueBlock);
+    for (auto *operand : thenOperands)
+      continueBlock->addArgument(operand->getType());
   }
 
-  // Creates an all_reduce across the warp. Creates a preamble
-  //
-  // %active_mask = llvm.mlir.constant(-1 : i32) : !llvm.i32
-  // %mask_and_clamp = llvm.mlir.constant(31 : i32) : !llvm.i32
-  //
-  // plus the accumulation for i = 1, 2, 4, 8, 16:
-  //
-  // %offset = llvm.mlir.constant(i : i32) : !llvm.i32
-  // %value = nvvm.shfl.sync.bfly
-  //    %active_mask, %operand, %offset, %mask_and_clamp
-  // %operand = llvm.fadd %operand, %value
-  //
-  // Each thread returns the same result.
-  //
-  // Note: this currently only supports reducing exactly 32 values.
-  Value *createWarpReduce(Location loc, Value *operand,
-                          ConversionPatternRewriter &rewriter) const {
-    // TODO(csigg): Generalize to partial warps and other types of accumulation.
-    static_assert(kWarpSize == 32, "Only warp size of 32 is supported.");
-    auto activeMask = rewriter.create<LLVM::ConstantOp>(
-        loc, int32Type, rewriter.getI32IntegerAttr(~0u));
-    auto maskAndClamp = rewriter.create<LLVM::ConstantOp>(
-        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
+  // Shortcut for createIf with empty else block and no block operands.
+  template <typename Factory>
+  void createPredicatedBlock(Location loc, ConversionPatternRewriter &rewriter,
+                             Value *condition,
+                             Factory &&predicatedOpsFactory) const {
+    createIf(
+        loc, rewriter, condition,
+        [&] {
+          predicatedOpsFactory();
+          return ArrayRef<Value *>();
+        },
+        [&] { return ArrayRef<Value *>(); });
+  }
 
-    auto resultType = operand->getType();
-    for (int i = 1; i < kWarpSize; i <<= 1) {
-      auto offset = rewriter.create<LLVM::ConstantOp>(
-          loc, int32Type, rewriter.getI32IntegerAttr(i));
-      auto value = rewriter.create<NVVM::ShflBflyOp>(
-          loc, resultType, activeMask, operand, offset, maskAndClamp);
-      operand = rewriter.create<LLVM::FAddOp>(loc, resultType, operand, value);
-    }
-    return operand;
+  // Creates a reduction across the first activeWidth lanes of a warp.
+  // The first lane returns the result, all others return values are undefined.
+  Value *createWarpReduce(Location loc, Value *activeWidth, Value *laneId,
+                          Value *operand,
+                          ConversionPatternRewriter &rewriter) const {
+    // TODO(csigg): Generalize to other types of accumulation.
+    Value *warpSize = rewriter.create<LLVM::ConstantOp>(
+        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize));
+    Value *maskAndClamp = rewriter.create<LLVM::ConstantOp>(
+        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
+    Value *isPartialWarp = rewriter.create<LLVM::ICmpOp>(
+        loc, LLVM::ICmpPredicate::slt, activeWidth, warpSize);
+    auto type = operand->getType();
+
+    createIf(
+        loc, rewriter, isPartialWarp,
+        // Generate reduction over a (potentially) partial warp.
+        [&] {
+          Value *value = operand;
+          Value *one = rewriter.create<LLVM::ConstantOp>(
+              loc, int32Type, rewriter.getI32IntegerAttr(1));
+          // Bit mask of active lanes: `(1 << activeWidth) - 1`.
+          Value *activeMask = rewriter.create<LLVM::SubOp>(
+              loc, int32Type,
+              rewriter.create<LLVM::ShlOp>(loc, int32Type, one, activeWidth),
+              one);
+          // Bound of offsets which read from a lane within the active range.
+          Value *offsetBound =
+              rewriter.create<LLVM::SubOp>(loc, activeWidth, laneId);
+
+          // Repeatedly shuffle value from 'laneId + i' and accumulate if source
+          // lane is within the active range. The first lane contains the final
+          // result, all other lanes contain some undefined partial result.
+          for (int i = 1; i < kWarpSize; i <<= 1) {
+            Value *offset = rewriter.create<LLVM::ConstantOp>(
+                loc, int32Type, rewriter.getI32IntegerAttr(i));
+            // ShflDownOp instead of ShflBflyOp would produce a scan. ShflBflyOp
+            // also produces the correct reduction on lane 0 though.
+            Value *shfl = rewriter.create<NVVM::ShflBflyOp>(
+                loc, type, activeMask, value, offset, maskAndClamp);
+            // TODO(csigg): use the second result from the shuffle op instead.
+            Value *isActiveSrcLane = rewriter.create<LLVM::ICmpOp>(
+                loc, LLVM::ICmpPredicate::slt, offset, offsetBound);
+            // Skip the accumulation if the shuffle op read from a lane outside
+            // of the active range.
+            createIf(
+                loc, rewriter, isActiveSrcLane,
+                [&] {
+                  return llvm::SmallVector<Value *, 1>{
+                      rewriter.create<LLVM::FAddOp>(loc, type, value, shfl)};
+                },
+                [&] { return llvm::makeArrayRef(value); });
+            value = rewriter.getInsertionBlock()->getArgument(0);
+          }
+          return llvm::SmallVector<Value *, 1>{value};
+        },
+        // Generate a reduction over the entire warp. This is a specialization
+        // of the above reduction with unconditional accumulation.
+        [&] {
+          Value *value = operand;
+          Value *activeMask = rewriter.create<LLVM::ConstantOp>(
+              loc, int32Type, rewriter.getI32IntegerAttr(~0u));
+          for (int i = 1; i < kWarpSize; i <<= 1) {
+            Value *offset = rewriter.create<LLVM::ConstantOp>(
+                loc, int32Type, rewriter.getI32IntegerAttr(i));
+            Value *shfl = rewriter.create<NVVM::ShflBflyOp>(
+                loc, type, activeMask, value, offset, maskAndClamp);
+            value = rewriter.create<LLVM::FAddOp>(loc, type, value, shfl);
+          }
+          return llvm::SmallVector<Value *, 1>{value};
+        });
+    return rewriter.getInsertionBlock()->getArgument(0);
   }
 
   // Creates a global array stored in shared memory.
   //
   // llvm.mlir.global @reduce_buffer()
   //     {addr_space = 3 : i32} : !llvm<"[32 x float]">
   //
   Value *createSharedMemoryArray(Location loc, ModuleOp module,
                                  LLVM::LLVMType elementType, int numElements,
                                  ConversionPatternRewriter &rewriter) const {
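The heart of the hunk above is the new createWarpReduce: for a partial warp it builds an active-lane mask `(1 << activeWidth) - 1`, shuffles with nvvm.shfl.sync.bfly, and accumulates only when the source lane is inside the active range, so only lane 0 is guaranteed to hold the complete sum. Below is a host-side C++ simulation of that butterfly (illustration only, not part of the commit): the shuffle source is modeled as `lane ^ offset` and the guard as `offset < activeWidth - lane`, mirroring offsetBound and isActiveSrcLane in the diff. On real hardware the inactive lanes are undefined rather than zero; the guard keeps them out of the sum either way.

// Host-side simulation of the partial-warp butterfly reduction emitted by the
// new createWarpReduce (illustration only).
#include <cassert>
#include <vector>

constexpr int kWarpSize = 32;

float partialWarpReduce(std::vector<float> value, int activeWidth) {
  for (int offset = 1; offset < kWarpSize; offset <<= 1) {
    std::vector<float> shfl(kWarpSize);
    for (int lane = 0; lane < kWarpSize; ++lane)
      shfl[lane] = value[lane ^ offset]; // butterfly shuffle
    for (int lane = 0; lane < kWarpSize; ++lane)
      if (offset < activeWidth - lane) // mirrors isActiveSrcLane in the diff
        value[lane] += shfl[lane];
  }
  return value[0]; // only the first lane's result is defined
}

int main() {
  // The last warp of a 104-thread block has 8 active lanes; give them the
  // values 96..103, matching the test's assumed iota initialization.
  std::vector<float> lanes(kWarpSize, 0.0f);
  for (int lane = 0; lane < 8; ++lane)
    lanes[lane] = 96.0f + lane;
  assert(partialWarpReduce(lanes, /*activeWidth=*/8) == 796.0f); // 96+...+103
  return 0;
}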
@@ -247,16 +308,6 @@ private:
     return rewriter.create<LLVM::AddOp>(loc, int32Type, tmp3, idX);
   }
 
-  // Returns the number of warps in the block.
-  Value *getNumWarps(Location loc, ConversionPatternRewriter &rewriter) const {
-    auto blockSize = getBlockSize(loc, rewriter);
-    auto warpSizeMinusOne = rewriter.create<LLVM::ConstantOp>(
-        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
-    auto biasedBlockSize = rewriter.create<LLVM::AddOp>(
-        loc, int32Type, blockSize, warpSizeMinusOne);
-    return getDivideByWarpSize(biasedBlockSize, rewriter);
-  }
-
   // Returns the number of threads in the block.
   Value *getBlockSize(Location loc, ConversionPatternRewriter &rewriter) const {
     Value *dimX = rewriter.create<NVVM::BlockDimXOp>(loc, int32Type);
@@ -266,6 +317,27 @@ private:
     return rewriter.create<LLVM::MulOp>(loc, int32Type, dimXY, dimZ);
   }
 
+  // Returns the number of warps in the block.
+  Value *getNumWarps(Location loc, Value *blockSize,
+                     ConversionPatternRewriter &rewriter) const {
+    auto warpSizeMinusOne = rewriter.create<LLVM::ConstantOp>(
+        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
+    auto biasedBlockSize = rewriter.create<LLVM::AddOp>(
+        loc, int32Type, blockSize, warpSizeMinusOne);
+    return getDivideByWarpSize(biasedBlockSize, rewriter);
+  }
+
+  // Returns the number of active threads in the warp, not clamped to 32.
+  Value *getActiveWidth(Location loc, Value *threadIdx, Value *blockSize,
+                        ConversionPatternRewriter &rewriter) const {
+    Value *threadIdxMask = rewriter.create<LLVM::ConstantOp>(
+        loc, int32Type, rewriter.getI32IntegerAttr(~(kWarpSize - 1)));
+    Value *numThreadsWithSmallerWarpId =
+        rewriter.create<LLVM::AndOp>(loc, threadIdx, threadIdxMask);
+    return rewriter.create<LLVM::SubOp>(loc, blockSize,
+                                        numThreadsWithSmallerWarpId);
+  }
+
   // Returns value divided by the warp size (i.e. 32).
   Value *getDivideByWarpSize(Value *value,
                              ConversionPatternRewriter &rewriter) const {
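Worked example for the two helpers added above (illustration, not part of the commit): with blockSize = 104, a thread with threadIdx = 100 computes 100 & ~31 = 96, so getActiveWidth returns 104 - 96 = 8 and createWarpReduce takes the partial-warp branch; a thread with threadIdx = 40 gets 104 - 32 = 72, which is not less than kWarpSize, so the unconditional full-warp reduction runs (hence "not clamped to 32" in the comment). getNumWarps computes (104 + 31) / 32 = 4.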
@@ -277,7 +349,6 @@ private:
 
   LLVM::LLVMType int32Type;
 
-  // TODO(csigg): Support other warp sizes.
   static constexpr int kWarpSize = 32;
 };
 
@@ -1,9 +1,9 @@
 // RUN: mlir-cuda-runner %s --shared-libs=%cuda_wrapper_library_dir/libcuda-runtime-wrappers%shlibext --entry-point-result=void | FileCheck %s
 
-// CHECK: [8.128000e+03, 8.128000e+03, {{.*}}, 8.128000e+03, 8.128000e+03]
+// CHECK: [5.356000e+03, 5.356000e+03, {{.*}}, 5.356000e+03, 5.356000e+03]
 func @main() {
-  %arg = alloc() : memref<16x4x2xf32>
-  %dst = memref_cast %arg : memref<16x4x2xf32> to memref<?x?x?xf32>
+  %arg = alloc() : memref<13x4x2xf32>
+  %dst = memref_cast %arg : memref<13x4x2xf32> to memref<?x?x?xf32>
   %zero = constant 0 : i32
   %one = constant 1 : index
   %sx = dim %dst, 0 : memref<?x?x?xf32>
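The updated CHECK value in the mlir-cuda-runner test above is consistent with the new buffer shape: 13 * 4 * 2 = 104 elements and, assuming the (not shown) initialization fills the buffer with the values 0 through 103, as the old expectation suggests (0 + 1 + ... + 127 = 8128 for the former 16x4x2 buffer), the all-reduce sum is 0 + 1 + ... + 103 = 103 * 104 / 2 = 5356, i.e. 5.356000e+03. The new size 104 is deliberately not a multiple of 32, so the test now exercises three full warps plus one partial warp of 8 lanes.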