forked from OSchip/llvm-project
Change all_reduce lowering to support 2D and 3D blocks.
Perform the second reduce only with the first warp. This requires an additional __syncthreads(), but doesn't need special handling when the last warp is small. This simplifies support for block sizes that are not a multiple of 32. Supporting partial warp reduce will be done in a separate CL. PiperOrigin-RevId: 272168917
This commit is contained in:
parent
8503ffbe3a
commit
1129931a62
|
@ -123,30 +123,46 @@ private:
|
||||||
//
|
//
|
||||||
// First reduce the elements within a warp. The first thread of each warp
|
// First reduce the elements within a warp. The first thread of each warp
|
||||||
// writes the intermediate result to shared memory. After synchronizing the
|
// writes the intermediate result to shared memory. After synchronizing the
|
||||||
// block, each warp reduces all values from shared memory.
|
// block, the first warp reduces the values from shared memory. The result
|
||||||
|
// is broadcast to all threads through shared memory.
|
||||||
//
|
//
|
||||||
// %warp_reduce = ... (see createWarpReduce)
|
// %warp_reduce = `createWarpReduce(%operand)`
|
||||||
// %buffer = llvm.mlir.addressof @reduce_buffer : !llvm<"[32 x float]*">
|
// %shared_mem_ptr = llvm.mlir.addressof @reduce_buffer
|
||||||
// %zero = llvm.mlir.constant(0 : i32) : !llvm.i32
|
// %zero = llvm.mlir.constant(0 : i32) : !llvm.i32
|
||||||
// %lane_id = nvvm.read.ptx.sreg.laneid : !llvm.i32
|
// %lane_id = nvvm.read.ptx.sreg.laneid : !llvm.i32
|
||||||
// %is_first_lane = llvm.icmp "eq" %lane_id, %zero : !llvm.i32
|
// %is_first_lane = llvm.icmp "eq" %lane_id, %zero : !llvm.i1
|
||||||
// llvm.cond_br %is_first_lane, ^then, ^continue
|
// %thread_idx = `getLinearThreadIndex()` : !llvm.i32
|
||||||
// ^then:
|
// llvm.cond_br %is_first_lane, ^then1, ^continue1
|
||||||
// %warp_id = ... (see getWarpId)
|
// ^then1:
|
||||||
// %store_dst = llvm.getelementptr %buffer[%zero, %warp_id]
|
// %warp_id = `getWarpId()`
|
||||||
// llvm.store %store_dst, %warp_reduce : !llvm.float
|
// %store_dst = llvm.getelementptr %shared_mem_ptr[%zero, %warp_id]
|
||||||
// llvm.br ^continue
|
// llvm.store %store_dst, %warp_reduce
|
||||||
// ^continue:
|
// llvm.br ^continue1
|
||||||
|
// ^continue1:
|
||||||
// nvvm.barrier0
|
// nvvm.barrier0
|
||||||
// %load_src = llvm.getelementptr %buffer[%zero, %lane_id]
|
// %num_warps = `getNumWarps()` : !llvm.i32
|
||||||
// %value = llvm.load %load_src : !llvm.float
|
// %is_valid_warp = llvm.icmp "slt" %thread_idx, %num_warps
|
||||||
// %result = ... (see createWarpReduce)
|
// %result_ptr = llvm.getelementptr %shared_mem_ptr[%zero, %zero]
|
||||||
|
// llvm.cond_br %is_valid_warp, ^then2, ^continue2
|
||||||
|
// ^then2:
|
||||||
|
// %load_src = llvm.getelementptr %shared_mem_ptr[%zero, %thread_idx]
|
||||||
|
// %value = llvm.load %load_src
|
||||||
|
// %result = `createWarpReduce(%value)`
|
||||||
|
// llvm.store %result_ptr, %result
|
||||||
|
// llvm.br ^continue2
|
||||||
|
// ^continue2:
|
||||||
|
// nvvm.barrier0
|
||||||
|
// %result = llvm.load %result_ptr
|
||||||
|
// return %result
|
||||||
|
//
|
||||||
Value *createBlockReduce(Location loc, Value *operand,
|
Value *createBlockReduce(Location loc, Value *operand,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
auto type = operand->getType().cast<LLVM::LLVMType>();
|
auto type = operand->getType().cast<LLVM::LLVMType>();
|
||||||
|
|
||||||
|
// Reduce elements within each warp to produce the intermediate results.
|
||||||
Value *warpReduce = createWarpReduce(loc, operand, rewriter);
|
Value *warpReduce = createWarpReduce(loc, operand, rewriter);
|
||||||
|
|
||||||
|
// Create shared memory array to store the warp reduction.
|
||||||
auto module = warpReduce->getDefiningOp()->getParentOfType<ModuleOp>();
|
auto module = warpReduce->getDefiningOp()->getParentOfType<ModuleOp>();
|
||||||
assert(module && "op must belong to a module");
|
assert(module && "op must belong to a module");
|
||||||
Value *sharedMemPtr =
|
Value *sharedMemPtr =
|
||||||
|
@ -157,7 +173,59 @@ private:
|
||||||
Value *laneId = rewriter.create<NVVM::LaneIdOp>(loc, int32Type);
|
Value *laneId = rewriter.create<NVVM::LaneIdOp>(loc, int32Type);
|
||||||
Value *isFirstLane = rewriter.create<LLVM::ICmpOp>(
|
Value *isFirstLane = rewriter.create<LLVM::ICmpOp>(
|
||||||
loc, LLVM::ICmpPredicate::eq, laneId, zero);
|
loc, LLVM::ICmpPredicate::eq, laneId, zero);
|
||||||
|
Value *threadIdx = getLinearThreadIndex(loc, rewriter);
|
||||||
|
|
||||||
|
// Write the intermediate results to shared memory, using the first lane of
|
||||||
|
// each warp.
|
||||||
|
createPredicatedBlock(
|
||||||
|
loc, isFirstLane,
|
||||||
|
[&] {
|
||||||
|
Value *warpId = getDivideByWarpSize(threadIdx, rewriter);
|
||||||
|
Value *storeDst = rewriter.create<LLVM::GEPOp>(
|
||||||
|
loc, type, sharedMemPtr, ArrayRef<Value *>({zero, warpId}));
|
||||||
|
rewriter.create<LLVM::StoreOp>(loc, warpReduce, storeDst);
|
||||||
|
},
|
||||||
|
rewriter);
|
||||||
|
|
||||||
|
rewriter.create<NVVM::Barrier0Op>(loc);
|
||||||
|
Value *numWarps = getNumWarps(loc, rewriter);
|
||||||
|
Value *isValidWarp = rewriter.create<LLVM::ICmpOp>(
|
||||||
|
loc, LLVM::ICmpPredicate::slt, threadIdx, numWarps);
|
||||||
|
Value *resultPtr = rewriter.create<LLVM::GEPOp>(
|
||||||
|
loc, type, sharedMemPtr, ArrayRef<Value *>({zero, zero}));
|
||||||
|
|
||||||
|
// Use the first numWarps threads to reduce the intermediate results from
|
||||||
|
// shared memory. The final result is written to shared memory again.
|
||||||
|
createPredicatedBlock(
|
||||||
|
loc, isValidWarp,
|
||||||
|
[&] {
|
||||||
|
Value *loadSrc = rewriter.create<LLVM::GEPOp>(
|
||||||
|
loc, type, sharedMemPtr, ArrayRef<Value *>({zero, threadIdx}));
|
||||||
|
Value *value = rewriter.create<LLVM::LoadOp>(loc, type, loadSrc);
|
||||||
|
Value *result = createWarpReduce(loc, value, rewriter);
|
||||||
|
rewriter.create<LLVM::StoreOp>(loc, result, resultPtr);
|
||||||
|
},
|
||||||
|
rewriter);
|
||||||
|
|
||||||
|
rewriter.create<NVVM::Barrier0Op>(loc);
|
||||||
|
Value *result = rewriter.create<LLVM::LoadOp>(loc, type, resultPtr);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Creates an if-block skeleton to perform conditional execution of the
|
||||||
|
// instructions generated by predicatedOpsFactory.
|
||||||
|
//
|
||||||
|
// llvm.cond_br %condition, ^then, ^continue
|
||||||
|
// ^then:
|
||||||
|
// ... code created in `predicatedOpsFactory()`
|
||||||
|
// llvm.br ^continue
|
||||||
|
// ^continue:
|
||||||
|
//
|
||||||
|
template <typename Func>
|
||||||
|
void createPredicatedBlock(Location loc, Value *condition,
|
||||||
|
Func &&predicatedOpsFactory,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Block *currentBlock = rewriter.getInsertionBlock();
|
Block *currentBlock = rewriter.getInsertionBlock();
|
||||||
auto currentPoint = rewriter.getInsertionPoint();
|
auto currentPoint = rewriter.getInsertionPoint();
|
||||||
|
|
||||||
|
@ -166,25 +234,15 @@ private:
|
||||||
|
|
||||||
rewriter.setInsertionPointToEnd(currentBlock);
|
rewriter.setInsertionPointToEnd(currentBlock);
|
||||||
rewriter.create<LLVM::CondBrOp>(
|
rewriter.create<LLVM::CondBrOp>(
|
||||||
loc, llvm::makeArrayRef(isFirstLane),
|
loc, llvm::makeArrayRef(condition),
|
||||||
ArrayRef<Block *>{thenBlock, continueBlock});
|
ArrayRef<Block *>{thenBlock, continueBlock});
|
||||||
|
|
||||||
rewriter.setInsertionPointToEnd(thenBlock);
|
rewriter.setInsertionPointToEnd(thenBlock);
|
||||||
Value *warpId = getWarpId(loc, rewriter);
|
predicatedOpsFactory();
|
||||||
Value *storeDst = rewriter.create<LLVM::GEPOp>(
|
|
||||||
loc, type, sharedMemPtr, ArrayRef<Value *>({zero, warpId}));
|
|
||||||
rewriter.create<LLVM::StoreOp>(loc, warpReduce, storeDst);
|
|
||||||
rewriter.create<LLVM::BrOp>(loc, ArrayRef<Value *>(),
|
rewriter.create<LLVM::BrOp>(loc, ArrayRef<Value *>(),
|
||||||
llvm::makeArrayRef(continueBlock));
|
llvm::makeArrayRef(continueBlock));
|
||||||
|
|
||||||
rewriter.setInsertionPointToStart(continueBlock);
|
rewriter.setInsertionPointToStart(continueBlock);
|
||||||
rewriter.create<NVVM::Barrier0Op>(loc);
|
|
||||||
Value *loadSrc = rewriter.create<LLVM::GEPOp>(
|
|
||||||
loc, type, sharedMemPtr, ArrayRef<Value *>({zero, laneId}));
|
|
||||||
Value *value = rewriter.create<LLVM::LoadOp>(loc, type, loadSrc);
|
|
||||||
Value *result = createWarpReduce(loc, value, rewriter);
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Creates an all_reduce across the warp. Creates a preamble
|
// Creates an all_reduce across the warp. Creates a preamble
|
||||||
|
@ -196,8 +254,8 @@ private:
|
||||||
//
|
//
|
||||||
// %offset = llvm.mlir.constant(i : i32) : !llvm.i32
|
// %offset = llvm.mlir.constant(i : i32) : !llvm.i32
|
||||||
// %value = nvvm.shfl.sync.bfly
|
// %value = nvvm.shfl.sync.bfly
|
||||||
// %active_mask, %operand, %offset, %mask_and_clamp : !llvm.float
|
// %active_mask, %operand, %offset, %mask_and_clamp
|
||||||
// %operand = llvm.fadd %operand, %value : !llvm.float
|
// %operand = llvm.fadd %operand, %value
|
||||||
//
|
//
|
||||||
// Each thread returns the same result.
|
// Each thread returns the same result.
|
||||||
//
|
//
|
||||||
|
@ -244,23 +302,46 @@ private:
|
||||||
return rewriter.create<LLVM::AddressOfOp>(loc, globalOp);
|
return rewriter.create<LLVM::AddressOfOp>(loc, globalOp);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns the index of the warp within the block.
|
// Returns the index of the thread within the block.
|
||||||
//
|
|
||||||
// %warp_size = llvm.mlir.constant(32 : i32) : !llvm.i32
|
|
||||||
// %thread_idx = nvvm.read.ptx.sreg.tid.x : !llvm.i32
|
|
||||||
// %warp_idx = llvm.sdiv %thread_idx, %warp_size : !llvm.i32
|
|
||||||
//
|
|
||||||
Value *getWarpId(Location loc, ConversionPatternRewriter &rewriter) const {
|
|
||||||
auto warpSize = rewriter.create<LLVM::ConstantOp>(
|
|
||||||
loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize));
|
|
||||||
auto threadIdx = getLinearThreadIndex(loc, rewriter);
|
|
||||||
return rewriter.create<LLVM::SDivOp>(loc, int32Type, threadIdx, warpSize);
|
|
||||||
}
|
|
||||||
|
|
||||||
Value *getLinearThreadIndex(Location loc,
|
Value *getLinearThreadIndex(Location loc,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
// TODO(csigg): support 2- and 3-dimensional blocks.
|
Value *dimX = rewriter.create<NVVM::BlockDimXOp>(loc, int32Type);
|
||||||
return rewriter.create<NVVM::ThreadIdXOp>(loc, int32Type);
|
Value *dimY = rewriter.create<NVVM::BlockDimYOp>(loc, int32Type);
|
||||||
|
Value *idX = rewriter.create<NVVM::ThreadIdXOp>(loc, int32Type);
|
||||||
|
Value *idY = rewriter.create<NVVM::ThreadIdYOp>(loc, int32Type);
|
||||||
|
Value *idZ = rewriter.create<NVVM::ThreadIdZOp>(loc, int32Type);
|
||||||
|
Value *tmp1 = rewriter.create<LLVM::MulOp>(loc, int32Type, idZ, dimY);
|
||||||
|
Value *tmp2 = rewriter.create<LLVM::AddOp>(loc, int32Type, tmp1, idY);
|
||||||
|
Value *tmp3 = rewriter.create<LLVM::MulOp>(loc, int32Type, tmp2, dimX);
|
||||||
|
return rewriter.create<LLVM::AddOp>(loc, int32Type, tmp3, idX);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the number of warps in the block.
|
||||||
|
Value *getNumWarps(Location loc, ConversionPatternRewriter &rewriter) const {
|
||||||
|
auto blockSize = getBlockSize(loc, rewriter);
|
||||||
|
auto warpSizeMinusOne = rewriter.create<LLVM::ConstantOp>(
|
||||||
|
loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
|
||||||
|
auto biasedBlockSize = rewriter.create<LLVM::AddOp>(
|
||||||
|
loc, int32Type, blockSize, warpSizeMinusOne);
|
||||||
|
return getDivideByWarpSize(biasedBlockSize, rewriter);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the number of threads in the block.
|
||||||
|
Value *getBlockSize(Location loc, ConversionPatternRewriter &rewriter) const {
|
||||||
|
Value *dimX = rewriter.create<NVVM::BlockDimXOp>(loc, int32Type);
|
||||||
|
Value *dimY = rewriter.create<NVVM::BlockDimYOp>(loc, int32Type);
|
||||||
|
Value *dimZ = rewriter.create<NVVM::BlockDimZOp>(loc, int32Type);
|
||||||
|
Value *dimXY = rewriter.create<LLVM::MulOp>(loc, int32Type, dimX, dimY);
|
||||||
|
return rewriter.create<LLVM::MulOp>(loc, int32Type, dimXY, dimZ);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns value divided by the warp size (i.e. 32).
|
||||||
|
Value *getDivideByWarpSize(Value *value,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
auto loc = value->getLoc();
|
||||||
|
auto warpSize = rewriter.create<LLVM::ConstantOp>(
|
||||||
|
loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize));
|
||||||
|
return rewriter.create<LLVM::SDivOp>(loc, int32Type, value, warpSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
LLVM::LLVMType int32Type;
|
LLVM::LLVMType int32Type;
|
||||||
|
|
|
@ -2,24 +2,30 @@
|
||||||
|
|
||||||
// CHECK: [8.128000e+03, 8.128000e+03, {{.*}}, 8.128000e+03, 8.128000e+03]
|
// CHECK: [8.128000e+03, 8.128000e+03, {{.*}}, 8.128000e+03, 8.128000e+03]
|
||||||
func @main() {
|
func @main() {
|
||||||
%arg = alloc() : memref<128xf32>
|
%arg = alloc() : memref<16x4x2xf32>
|
||||||
%dst = memref_cast %arg : memref<128xf32> to memref<?xf32>
|
%dst = memref_cast %arg : memref<16x4x2xf32> to memref<?x?x?xf32>
|
||||||
%zero = constant 0 : i32
|
%zero = constant 0 : i32
|
||||||
%one = constant 1 : index
|
%one = constant 1 : index
|
||||||
%size = dim %dst, 0 : memref<?xf32>
|
%sx = dim %dst, 0 : memref<?x?x?xf32>
|
||||||
call @mcuMemHostRegister(%dst, %zero) : (memref<?xf32>, i32) -> ()
|
%sy = dim %dst, 1 : memref<?x?x?xf32>
|
||||||
|
%sz = dim %dst, 2 : memref<?x?x?xf32>
|
||||||
|
call @mcuMemHostRegister(%dst, %zero) : (memref<?x?x?xf32>, i32) -> ()
|
||||||
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
|
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
|
||||||
threads(%tx, %ty, %tz) in (%block_x = %size, %block_y = %one, %block_z = %one)
|
threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %sy, %block_z = %sz)
|
||||||
args(%kernel_dst = %dst) : memref<?xf32> {
|
args(%kernel_dst = %dst) : memref<?x?x?xf32> {
|
||||||
%idx = index_cast %tx : index to i32
|
%t0 = muli %tz, %block_y : index
|
||||||
%val = sitofp %idx : i32 to f32
|
%t1 = addi %ty, %t0 : index
|
||||||
|
%t2 = muli %t1, %block_x : index
|
||||||
|
%idx = addi %tx, %t2 : index
|
||||||
|
%t3 = index_cast %idx : index to i32
|
||||||
|
%val = sitofp %t3 : i32 to f32
|
||||||
%sum = "gpu.all_reduce"(%val) { op = "add" } : (f32) -> (f32)
|
%sum = "gpu.all_reduce"(%val) { op = "add" } : (f32) -> (f32)
|
||||||
store %sum, %kernel_dst[%tx] : memref<?xf32>
|
store %sum, %kernel_dst[%tx, %ty, %tz] : memref<?x?x?xf32>
|
||||||
gpu.return
|
gpu.return
|
||||||
}
|
}
|
||||||
call @mcuPrintFloat(%dst) : (memref<?xf32>) -> ()
|
call @mcuPrintFloat(%dst) : (memref<?x?x?xf32>) -> ()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func @mcuMemHostRegister(%ptr : memref<?xf32>, %flags : i32)
|
func @mcuMemHostRegister(%ptr : memref<?x?x?xf32>, %flags : i32)
|
||||||
func @mcuPrintFloat(%ptr : memref<?xf32>)
|
func @mcuPrintFloat(%ptr : memref<?x?x?xf32>)
|
||||||
|
|
Loading…
Reference in New Issue