Support reduction of partial warps.

gpu.all_reduce now supports block sizes that are not a multiple of 32.

PiperOrigin-RevId: 273255204
This commit is contained in:
Christian Sigg 2019-10-07 03:30:34 -07:00 committed by A. Unique TensorFlower
parent 77672c9777
commit 7c765d97f9
2 changed files with 165 additions and 94 deletions

View File

@ -90,11 +90,8 @@ private:
ConversionPatternRewriter &rewriter) const {
auto type = operand->getType().cast<LLVM::LLVMType>();
// Reduce elements within each warp to produce the intermediate results.
Value *warpReduce = createWarpReduce(loc, operand, rewriter);
// Create shared memory array to store the warp reduction.
auto module = warpReduce->getDefiningOp()->getParentOfType<ModuleOp>();
auto module = operand->getDefiningOp()->getParentOfType<ModuleOp>();
assert(module && "op must belong to a module");
Value *sharedMemPtr =
createSharedMemoryArray(loc, module, type, kWarpSize, rewriter);
@ -105,21 +102,24 @@ private:
Value *isFirstLane = rewriter.create<LLVM::ICmpOp>(
loc, LLVM::ICmpPredicate::eq, laneId, zero);
Value *threadIdx = getLinearThreadIndex(loc, rewriter);
Value *blockSize = getBlockSize(loc, rewriter);
Value *activeWidth = getActiveWidth(loc, threadIdx, blockSize, rewriter);
// Reduce elements within each warp to produce the intermediate results.
Value *warpReduce =
createWarpReduce(loc, activeWidth, laneId, operand, rewriter);
// Write the intermediate results to shared memory, using the first lane of
// each warp.
createPredicatedBlock(
loc, isFirstLane,
[&] {
Value *warpId = getDivideByWarpSize(threadIdx, rewriter);
Value *storeDst = rewriter.create<LLVM::GEPOp>(
loc, type, sharedMemPtr, ArrayRef<Value *>({zero, warpId}));
rewriter.create<LLVM::StoreOp>(loc, warpReduce, storeDst);
},
rewriter);
createPredicatedBlock(loc, rewriter, isFirstLane, [&] {
Value *warpId = getDivideByWarpSize(threadIdx, rewriter);
Value *storeDst = rewriter.create<LLVM::GEPOp>(
loc, type, sharedMemPtr, ArrayRef<Value *>({zero, warpId}));
rewriter.create<LLVM::StoreOp>(loc, warpReduce, storeDst);
});
rewriter.create<NVVM::Barrier0Op>(loc);
Value *numWarps = getNumWarps(loc, rewriter);
Value *numWarps = getNumWarps(loc, blockSize, rewriter);
Value *isValidWarp = rewriter.create<LLVM::ICmpOp>(
loc, LLVM::ICmpPredicate::slt, threadIdx, numWarps);
Value *resultPtr = rewriter.create<LLVM::GEPOp>(
@ -127,95 +127,156 @@ private:
// Use the first numWarps threads to reduce the intermediate results from
// shared memory. The final result is written to shared memory again.
createPredicatedBlock(
loc, isValidWarp,
[&] {
Value *loadSrc = rewriter.create<LLVM::GEPOp>(
loc, type, sharedMemPtr, ArrayRef<Value *>({zero, threadIdx}));
Value *value = rewriter.create<LLVM::LoadOp>(loc, type, loadSrc);
Value *result = createWarpReduce(loc, value, rewriter);
rewriter.create<LLVM::StoreOp>(loc, result, resultPtr);
},
rewriter);
createPredicatedBlock(loc, rewriter, isValidWarp, [&] {
Value *loadSrc = rewriter.create<LLVM::GEPOp>(
loc, type, sharedMemPtr, ArrayRef<Value *>({zero, threadIdx}));
Value *value = rewriter.create<LLVM::LoadOp>(loc, type, loadSrc);
Value *result = createWarpReduce(loc, numWarps, laneId, value, rewriter);
rewriter.create<LLVM::StoreOp>(loc, result, resultPtr);
});
rewriter.create<NVVM::Barrier0Op>(loc);
Value *result = rewriter.create<LLVM::LoadOp>(loc, type, resultPtr);
// Load and return result from shared memory.
Value *result = rewriter.create<LLVM::LoadOp>(loc, type, resultPtr);
return result;
}
// Creates an if-block skeleton to perform conditional execution of the
// instructions generated by predicatedOpsFactory.
// Creates an if-block skeleton and calls the two factories to generate the
// ops in the `then` and `else` blocks.
//
// llvm.cond_br %condition, ^then, ^continue
// ^then:
// ... code created in `predicatedOpsFactory()`
// llvm.br ^continue
// ^continue:
// %then_operands = `thenOpsFactory()`
// llvm.br ^continue(%then_operands)
// ^else:
// %else_operands = `elseOpsFactory()`
// llvm.br ^continue(%else_operands)
// ^continue(%block_operands):
//
template <typename Func>
void createPredicatedBlock(Location loc, Value *condition,
Func &&predicatedOpsFactory,
ConversionPatternRewriter &rewriter) const {
template <typename ThenOpsFactory, typename ElseOpsFactory>
void createIf(Location loc, ConversionPatternRewriter &rewriter,
Value *condition, ThenOpsFactory &&thenOpsFactory,
ElseOpsFactory &&elseOpsFactory) const {
Block *currentBlock = rewriter.getInsertionBlock();
auto currentPoint = rewriter.getInsertionPoint();
Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
Block *continueBlock = rewriter.splitBlock(thenBlock, currentPoint);
Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());
rewriter.setInsertionPointToEnd(currentBlock);
rewriter.create<LLVM::CondBrOp>(
loc, llvm::makeArrayRef(condition),
ArrayRef<Block *>{thenBlock, continueBlock});
rewriter.create<LLVM::CondBrOp>(loc, llvm::makeArrayRef(condition),
ArrayRef<Block *>{thenBlock, elseBlock});
rewriter.setInsertionPointToEnd(thenBlock);
predicatedOpsFactory();
rewriter.create<LLVM::BrOp>(loc, ArrayRef<Value *>(),
llvm::makeArrayRef(continueBlock));
auto addBranch = [&](ArrayRef<Value *> operands) {
rewriter.create<LLVM::BrOp>(loc, ArrayRef<Value *>{},
llvm::makeArrayRef(continueBlock),
llvm::makeArrayRef(operands));
};
rewriter.setInsertionPointToStart(thenBlock);
auto thenOperands = thenOpsFactory();
addBranch(thenOperands);
rewriter.setInsertionPointToStart(elseBlock);
auto elseOperands = elseOpsFactory();
addBranch(elseOperands);
assert(thenOperands.size() == elseOperands.size());
rewriter.setInsertionPointToStart(continueBlock);
for (auto *operand : thenOperands)
continueBlock->addArgument(operand->getType());
}
// Creates an all_reduce across the warp. Creates a preamble
//
// %active_mask = llvm.mlir.constant(-1 : i32) : !llvm.i32
// %mask_and_clamp = llvm.mlir.constant(31 : i32) : !llvm.i32
//
// plus the accumulation for i = 1, 2, 4, 8, 16:
//
// %offset = llvm.mlir.constant(i : i32) : !llvm.i32
// %value = nvvm.shfl.sync.bfly
// %active_mask, %operand, %offset, %mask_and_clamp
// %operand = llvm.fadd %operand, %value
//
// Each thread returns the same result.
//
// Note: this currently only supports reducing exactly 32 values.
Value *createWarpReduce(Location loc, Value *operand,
ConversionPatternRewriter &rewriter) const {
// TODO(csigg): Generalize to partial warps and other types of accumulation.
static_assert(kWarpSize == 32, "Only warp size of 32 is supported.");
auto activeMask = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(~0u));
auto maskAndClamp = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
// Shortcut for createIf with an empty else block and no block operands.
template <typename Factory>
void createPredicatedBlock(Location loc, ConversionPatternRewriter &rewriter,
                           Value *condition,
                           Factory &&predicatedOpsFactory) const {
  // Adapt the factory to the value-producing signature expected by createIf;
  // neither branch forwards any operands to the continue block.
  auto thenFactory = [&] {
    predicatedOpsFactory();
    return ArrayRef<Value *>();
  };
  auto emptyElseFactory = [&] { return ArrayRef<Value *>(); };
  createIf(loc, rewriter, condition, thenFactory, emptyElseFactory);
}
auto resultType = operand->getType();
for (int i = 1; i < kWarpSize; i <<= 1) {
auto offset = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(i));
auto value = rewriter.create<NVVM::ShflBflyOp>(
loc, resultType, activeMask, operand, offset, maskAndClamp);
operand = rewriter.create<LLVM::FAddOp>(loc, resultType, operand, value);
}
return operand;
// Creates a reduction across the first activeWidth lanes of a warp.
// The first lane returns the result; the return values of all other lanes are
// undefined.
Value *createWarpReduce(Location loc, Value *activeWidth, Value *laneId,
Value *operand,
ConversionPatternRewriter &rewriter) const {
// TODO(csigg): Generalize to other types of accumulation.
Value *warpSize = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize));
Value *maskAndClamp = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
// A warp is partial iff fewer than kWarpSize lanes are active; the two
// branches below handle the partial and full cases separately.
Value *isPartialWarp = rewriter.create<LLVM::ICmpOp>(
loc, LLVM::ICmpPredicate::slt, activeWidth, warpSize);
auto type = operand->getType();
createIf(
loc, rewriter, isPartialWarp,
// Generate reduction over a (potentially) partial warp.
[&] {
Value *value = operand;
Value *one = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(1));
// Bit mask of active lanes: `(1 << activeWidth) - 1`.
Value *activeMask = rewriter.create<LLVM::SubOp>(
loc, int32Type,
rewriter.create<LLVM::ShlOp>(loc, int32Type, one, activeWidth),
one);
// Bound of offsets which read from a lane within the active range.
Value *offsetBound =
rewriter.create<LLVM::SubOp>(loc, activeWidth, laneId);
// Repeatedly shuffle value from 'laneId + i' and accumulate if source
// lane is within the active range. The first lane contains the final
// result, all other lanes contain some undefined partial result.
for (int i = 1; i < kWarpSize; i <<= 1) {
Value *offset = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(i));
// ShflDownOp instead of ShflBflyOp would produce a scan. ShflBflyOp
// also produces the correct reduction on lane 0 though.
Value *shfl = rewriter.create<NVVM::ShflBflyOp>(
loc, type, activeMask, value, offset, maskAndClamp);
// TODO(csigg): use the second result from the shuffle op instead.
Value *isActiveSrcLane = rewriter.create<LLVM::ICmpOp>(
loc, LLVM::ICmpPredicate::slt, offset, offsetBound);
// Skip the accumulation if the shuffle op read from a lane outside
// of the active range.
createIf(
loc, rewriter, isActiveSrcLane,
[&] {
return llvm::SmallVector<Value *, 1>{
rewriter.create<LLVM::FAddOp>(loc, type, value, shfl)};
},
[&] { return llvm::makeArrayRef(value); });
// The merged value is the block argument of the continue block
// created by the inner createIf.
value = rewriter.getInsertionBlock()->getArgument(0);
}
return llvm::SmallVector<Value *, 1>{value};
},
// Generate a reduction over the entire warp. This is a specialization
// of the above reduction with unconditional accumulation.
[&] {
Value *value = operand;
Value *activeMask = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(~0u));
for (int i = 1; i < kWarpSize; i <<= 1) {
Value *offset = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(i));
Value *shfl = rewriter.create<NVVM::ShflBflyOp>(
loc, type, activeMask, value, offset, maskAndClamp);
value = rewriter.create<LLVM::FAddOp>(loc, type, value, shfl);
}
return llvm::SmallVector<Value *, 1>{value};
});
// Both branches forward their result as an operand of the continue block
// created by createIf; read it back as that block's argument.
return rewriter.getInsertionBlock()->getArgument(0);
}
// Creates a global array stored in shared memory.
//
// llvm.mlir.global @reduce_buffer()
// {addr_space = 3 : i32} : !llvm<"[32 x float]">
//
Value *createSharedMemoryArray(Location loc, ModuleOp module,
LLVM::LLVMType elementType, int numElements,
ConversionPatternRewriter &rewriter) const {
@ -247,16 +308,6 @@ private:
return rewriter.create<LLVM::AddOp>(loc, int32Type, tmp3, idX);
}
// Returns the number of warps in the block.
Value *getNumWarps(Location loc, ConversionPatternRewriter &rewriter) const {
auto blockSize = getBlockSize(loc, rewriter);
// Ceiling division: bias the block size by (kWarpSize - 1) so that the
// truncating divide rounds up, i.e. ceil(blockSize / kWarpSize).
auto warpSizeMinusOne = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
auto biasedBlockSize = rewriter.create<LLVM::AddOp>(
loc, int32Type, blockSize, warpSizeMinusOne);
return getDivideByWarpSize(biasedBlockSize, rewriter);
}
// Returns the number of threads in the block.
Value *getBlockSize(Location loc, ConversionPatternRewriter &rewriter) const {
Value *dimX = rewriter.create<NVVM::BlockDimXOp>(loc, int32Type);
@ -266,6 +317,27 @@ private:
return rewriter.create<LLVM::MulOp>(loc, int32Type, dimXY, dimZ);
}
// Returns the number of warps in the block.
Value *getNumWarps(Location loc, Value *blockSize,
                   ConversionPatternRewriter &rewriter) const {
  // Ceiling division: add (kWarpSize - 1) before the truncating divide so
  // that a partial trailing warp is still counted.
  Value *bias = rewriter.create<LLVM::ConstantOp>(
      loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
  Value *roundedUp =
      rewriter.create<LLVM::AddOp>(loc, int32Type, blockSize, bias);
  return getDivideByWarpSize(roundedUp, rewriter);
}
// Returns the number of active threads in the warp, not clamped to 32.
Value *getActiveWidth(Location loc, Value *threadIdx, Value *blockSize,
                      ConversionPatternRewriter &rewriter) const {
  // Clearing the low log2(kWarpSize) bits of the thread index yields the
  // total thread count of all warps preceding (and excluding) this lane's
  // position within its warp.
  Value *warpStartMask = rewriter.create<LLVM::ConstantOp>(
      loc, int32Type, rewriter.getI32IntegerAttr(~(kWarpSize - 1)));
  Value *precedingThreads =
      rewriter.create<LLVM::AndOp>(loc, threadIdx, warpStartMask);
  // Remaining threads in the block all belong to this or later warps.
  return rewriter.create<LLVM::SubOp>(loc, blockSize, precedingThreads);
}
// Returns value divided by the warp size (i.e. 32).
Value *getDivideByWarpSize(Value *value,
ConversionPatternRewriter &rewriter) const {
@ -277,7 +349,6 @@ private:
LLVM::LLVMType int32Type;
// TODO(csigg): Support other warp sizes.
static constexpr int kWarpSize = 32;
};

View File

@ -1,9 +1,9 @@
// RUN: mlir-cuda-runner %s --shared-libs=%cuda_wrapper_library_dir/libcuda-runtime-wrappers%shlibext --entry-point-result=void | FileCheck %s
// CHECK: [8.128000e+03, 8.128000e+03, {{.*}}, 8.128000e+03, 8.128000e+03]
// CHECK: [5.356000e+03, 5.356000e+03, {{.*}}, 5.356000e+03, 5.356000e+03]
func @main() {
%arg = alloc() : memref<16x4x2xf32>
%dst = memref_cast %arg : memref<16x4x2xf32> to memref<?x?x?xf32>
%arg = alloc() : memref<13x4x2xf32>
%dst = memref_cast %arg : memref<13x4x2xf32> to memref<?x?x?xf32>
%zero = constant 0 : i32
%one = constant 1 : index
%sx = dim %dst, 0 : memref<?x?x?xf32>