Add gpu.shuffle op.

This will allow us to lower most of gpu.all_reduce (when all_reduce
doesn't exist in the target dialect) within the GPU dialect, and only do
target-specific lowering for the shuffle op.
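
For illustration only (not part of this change), assuming a fixed subgroup
width of 32 and that %val holds the current lane's value inside a kernel body,
the "add" case of all_reduce could then be expanded within the GPU dialect into
a shuffle butterfly roughly like this (names and the fixed width are
illustrative):

```
%width = constant 32 : i32
%offset16 = constant 16 : i32
%v16, %ok16 = gpu.shuffle %val, %offset16, %width xor : f32
%s16 = addf %val, %v16 : f32
%offset8 = constant 8 : i32
%v8, %ok8 = gpu.shuffle %s16, %offset8, %width xor : f32
%s8 = addf %s16, %v8 : f32
// ... continue with offsets 4, 2 and 1; afterwards every lane holds the sum.
```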

PiperOrigin-RevId: 286548256
Author: Christian Sigg (2019-12-20 02:52:21 -08:00), committed by A. Unique TensorFlower
parent 7811ad3c2b
commit 42d46b4efa
8 changed files with 213 additions and 2 deletions


@@ -26,6 +26,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionSupport.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SymbolTable.h"
namespace mlir {


@@ -536,6 +536,41 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
let verifier = [{ return ::verifyAllReduce(*this); }];
}
def GPU_ShuffleOpXor : StrEnumAttrCase<"xor">;

def GPU_ShuffleModeAttr : StrEnumAttr<"ShuffleModeAttr",
    "Indexing modes supported by gpu.shuffle.",
    [
      GPU_ShuffleOpXor,
    ]>;

def GPU_ShuffleOp : GPU_Op<"shuffle", [NoSideEffect]>,
    Arguments<(ins AnyType:$value, I32:$offset, I32:$width,
               GPU_ShuffleModeAttr:$mode)>,
    Results<(outs AnyType:$result, I1:$valid)> {
  let summary = "Shuffles values within a subgroup.";
  let description = [{
    The "shuffle" op moves values to a different invocation within the same
    subgroup.

    For example
    ```
    %1, %2 = gpu.shuffle %0, %offset, %width xor : f32
    ```
    for lane k returns the value from lane `k ^ offset` and `true` if that lane
    is smaller than %width. Otherwise it returns an unspecified value and
    `false`. A lane is the index of an invocation relative to its subgroup.

    The width specifies the number of invocations that participate in the
    shuffle. The width needs to be the same for all invocations that participate
    in the shuffle. Exactly the first `width` invocations of a subgroup need to
    execute this op in convergence.
  }];
  let verifier = [{ return ::verifyShuffleOp(*this); }];
  let printer = [{ printShuffleOp(p, *this); }];
  let parser = [{ return parseShuffleOp(parser, result); }];
}
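
// For illustration (not part of the op definition): with %offset = 1 and
// %width = 4, lane k reads the value from lane k ^ 1, so lanes 0<->1 and
// 2<->3 exchange values and all four lanes get `valid` = true (k ^ 1 < 4).
// A lane whose source lane k ^ offset is >= %width instead gets an
// unspecified value and `valid` = false.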
def GPU_BarrierOp : GPU_Op<"barrier"> {
let summary = "Synchronizes all work items of a workgroup.";
let description = [{


@@ -473,6 +473,64 @@ private:
static constexpr int kWarpSize = 32;
};
struct GPUShuffleOpLowering : public LLVMOpLowering {
  explicit GPUShuffleOpLowering(LLVMTypeConverter &lowering_)
      : LLVMOpLowering(gpu::ShuffleOp::getOperationName(),
                       lowering_.getDialect()->getContext(), lowering_) {}

  /// Lowers a shuffle to the corresponding NVVM op.
  ///
  /// Convert the `width` argument into an activeMask (a bitmask which specifies
  /// which threads participate in the shuffle) and a maskAndClamp (specifying
  /// the highest lane which participates in the shuffle).
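  /// For example, a `width` of 8 yields an activeMask of 0xff and a
  /// maskAndClamp of 7, i.e. lanes 0 through 7 take part in the shuffle.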
  ///
  ///     %one = llvm.constant(1 : i32) : !llvm.i32
  ///     %shl = llvm.shl %one, %width : !llvm.i32
  ///     %active_mask = llvm.sub %shl, %one : !llvm.i32
  ///     %mask_and_clamp = llvm.sub %width, %one : !llvm.i32
  ///     %shfl = nvvm.shfl.sync.bfly %active_mask, %value, %offset,
  ///         %mask_and_clamp : !llvm<"{ float, i1 }">
  ///     %shfl_value = llvm.extractvalue %shfl[0 : index] :
  ///         !llvm<"{ float, i1 }">
  ///     %shfl_pred = llvm.extractvalue %shfl[1 : index] :
  ///         !llvm<"{ float, i1 }">
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    gpu::ShuffleOpOperandAdaptor adaptor(operands);

    auto dialect = lowering.getDialect();
    auto valueTy = adaptor.value()->getType().cast<LLVM::LLVMType>();
    auto int32Type = LLVM::LLVMType::getInt32Ty(dialect);
    auto predTy = LLVM::LLVMType::getInt1Ty(dialect);
    auto resultTy = LLVM::LLVMType::getStructTy(dialect, {valueTy, predTy});

    Value *one = rewriter.create<LLVM::ConstantOp>(
        loc, int32Type, rewriter.getI32IntegerAttr(1));
    // Bit mask of active lanes: `(1 << activeWidth) - 1`.
    Value *activeMask = rewriter.create<LLVM::SubOp>(
        loc, int32Type,
        rewriter.create<LLVM::ShlOp>(loc, int32Type, one, adaptor.width()),
        one);
    // Clamp lane: `activeWidth - 1`
    Value *maskAndClamp =
        rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.width(), one);

    auto returnValueAndIsValidAttr = rewriter.getUnitAttr();
    Value *shfl = rewriter.create<NVVM::ShflBflyOp>(
        loc, resultTy, activeMask, adaptor.value(), adaptor.offset(),
        maskAndClamp, returnValueAndIsValidAttr);
    Value *shflValue = rewriter.create<LLVM::ExtractValueOp>(
        loc, valueTy, shfl, rewriter.getIndexArrayAttr(0));
    Value *isActiveSrcLane = rewriter.create<LLVM::ExtractValueOp>(
        loc, predTy, shfl, rewriter.getIndexArrayAttr(1));

    rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    return matchSuccess();
  }
};
struct GPUFuncOpLowering : LLVMOpLowering {
explicit GPUFuncOpLowering(LLVMTypeConverter &typeConverter)
: LLVMOpLowering(gpu::GPUFuncOp::getOperationName(),
@@ -688,8 +746,8 @@ void mlir::populateGpuToNVVMConversionPatterns(
NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
NVVM::GridDimYOp, NVVM::GridDimZOp>,
GPUAllReduceOpLowering, GPUFuncOpLowering, GPUReturnOpLowering>(
converter);
GPUAllReduceOpLowering, GPUShuffleOpLowering, GPUFuncOpLowering,
GPUReturnOpLowering>(converter);
patterns.insert<OpToFuncCallLowering<ExpOp>>(converter, "__nv_expf",
"__nv_exp");
}


@@ -165,6 +165,47 @@ static LogicalResult verifyAllReduce(gpu::AllReduceOp allReduce) {
return success();
}
static LogicalResult verifyShuffleOp(gpu::ShuffleOp shuffleOp) {
  auto type = shuffleOp.value()->getType();
  if (shuffleOp.result()->getType() != type) {
    return shuffleOp.emitOpError()
           << "requires the same type for value operand and result";
  }
  if (!type.isIntOrFloat() || type.getIntOrFloatBitWidth() != 32) {
    return shuffleOp.emitOpError()
           << "requires value operand type to be f32 or i32";
  }
  return success();
}

static void printShuffleOp(OpAsmPrinter &p, ShuffleOp op) {
  p << ShuffleOp::getOperationName() << ' ';
  p.printOperands(op.getOperands());
  p << ' ' << op.mode() << " : ";
  p.printType(op.value()->getType());
}

static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &state) {
  SmallVector<OpAsmParser::OperandType, 3> operandInfo;
  if (parser.parseOperandList(operandInfo, 3))
    return failure();
  StringRef mode;
  if (parser.parseKeyword(&mode))
    return failure();
  state.addAttribute("mode", parser.getBuilder().getStringAttr(mode));

  Type valueType;
  Type int32Type = parser.getBuilder().getIntegerType(32);
  Type int1Type = parser.getBuilder().getI1Type();
  if (parser.parseColonType(valueType) ||
      parser.resolveOperands(operandInfo, {valueType, int32Type, int32Type},
                             parser.getCurrentLocation(), state.operands) ||
      parser.addTypesToList({valueType, int1Type}, state.types))
    return failure();
  return success();
}
//===----------------------------------------------------------------------===//
// LaunchOp
//===----------------------------------------------------------------------===//


@@ -74,6 +74,31 @@ module attributes {gpu.kernel_module} {
// -----
module attributes {gpu.kernel_module} {
  // CHECK-LABEL: func @gpu_shuffle()
  func @gpu_shuffle()
      attributes { gpu.kernel } {
    // CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : !llvm.float
    %arg0 = constant 1.0 : f32
    // CHECK: %[[#OFFSET:]] = llvm.mlir.constant(4 : i32) : !llvm.i32
    %arg1 = constant 4 : i32
    // CHECK: %[[#WIDTH:]] = llvm.mlir.constant(23 : i32) : !llvm.i32
    %arg2 = constant 23 : i32
    // CHECK: %[[#ONE:]] = llvm.mlir.constant(1 : i32) : !llvm.i32
    // CHECK: %[[#SHL:]] = llvm.shl %[[#ONE]], %[[#WIDTH]] : !llvm.i32
    // CHECK: %[[#MASK:]] = llvm.sub %[[#SHL]], %[[#ONE]] : !llvm.i32
    // CHECK: %[[#CLAMP:]] = llvm.sub %[[#WIDTH]], %[[#ONE]] : !llvm.i32
    // CHECK: %[[#SHFL:]] = nvvm.shfl.sync.bfly %[[#MASK]], %[[#VALUE]], %[[#OFFSET]], %[[#CLAMP]] : !llvm<"{ float, i1 }">
    // CHECK: llvm.extractvalue %[[#SHFL]][0 : index] : !llvm<"{ float, i1 }">
    // CHECK: llvm.extractvalue %[[#SHFL]][1 : index] : !llvm<"{ float, i1 }">
    %shfl, %pred = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = "xor" } : (f32, i32, i32) -> (f32, i1)
    std.return
  }
}
// -----
module attributes {gpu.kernel_module} {
  // CHECK-LABEL: func @gpu_sync()
  func @gpu_sync()


@@ -362,6 +362,20 @@ func @reduce_incorrect_yield(%arg0 : f32) {
// -----
func @shuffle_mismatching_type(%arg0 : f32, %arg1 : i32, %arg2 : i32) {
  // expected-error@+1 {{'gpu.shuffle' op requires the same type for value operand and result}}
  %shfl, %pred = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = "xor" } : (f32, i32, i32) -> (i32, i1)
}

// -----

func @shuffle_unsupported_type(%arg0 : index, %arg1 : i32, %arg2 : i32) {
  // expected-error@+1 {{'gpu.shuffle' op requires value operand type to be f32 or i32}}
  %shfl, %pred = gpu.shuffle %arg0, %arg1, %arg2 xor : index
}

// -----

module {
  module @gpu_funcs attributes {gpu.kernel_module} {
    // expected-error @+1 {{custom op 'gpu.func' gpu.func requires named arguments}}


@@ -81,6 +81,11 @@ module attributes {gpu.container_module} {
%one = constant 1.0 : f32
%sum = "gpu.all_reduce"(%one) ({}) {op = "add"} : (f32) -> (f32)
%width = constant 7 : i32
%offset = constant 3 : i32
// CHECK: gpu.shuffle %{{.*}}, %{{.*}}, %{{.*}} xor : f32
%shfl, %pred = gpu.shuffle %arg0, %offset, %width xor : f32
"gpu.barrier"() : () -> ()
"some_op"(%bIdX, %tIdX) : (index, index) -> ()


@@ -0,0 +1,32 @@
// RUN: mlir-cuda-runner %s --shared-libs=%cuda_wrapper_library_dir/libcuda-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s
// CHECK: [4, 5, 6, 7, 0, 1, 2, 3, 12, -1, -1, -1, 8]
func @main() {
  %arg = alloc() : memref<13xf32>
  %dst = memref_cast %arg : memref<13xf32> to memref<?xf32>
  %one = constant 1 : index
  %sx = dim %dst, 0 : memref<?xf32>
  call @mcuMemHostRegisterMemRef1dFloat(%dst) : (memref<?xf32>) -> ()
  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
             threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one)
             args(%kernel_dst = %dst) : memref<?xf32> {
    %t0 = index_cast %tx : index to i32
    %val = sitofp %t0 : i32 to f32
    %width = index_cast %block_x : index to i32
    %offset = constant 4 : i32
    %shfl, %valid = gpu.shuffle %val, %offset, %width xor : f32
    cond_br %valid, ^bb1(%shfl : f32), ^bb0
  ^bb0:
    %m1 = constant -1.0 : f32
    br ^bb1(%m1 : f32)
  ^bb1(%value : f32):
    store %value, %kernel_dst[%tx] : memref<?xf32>
    gpu.return
  }
  %U = memref_cast %dst : memref<?xf32> to memref<*xf32>
  call @print_memref_f32(%U) : (memref<*xf32>) -> ()
  return
}

func @mcuMemHostRegisterMemRef1dFloat(%ptr : memref<?xf32>)
func @print_memref_f32(%ptr : memref<*xf32>)