Switch from shfl.bfly to shfl.down.

Both work for the current use case, but the latter allows implementing
prefix sums and is a little easier to understand for partial warps.

PiperOrigin-RevId: 285145287
This commit is contained in:
Christian Sigg 2019-12-12 01:27:27 -08:00 committed by A. Unique TensorFlower
parent 851a8516d3
commit f68ac464d8
8 changed files with 34 additions and 34 deletions

View File

@ -90,8 +90,8 @@ def NVVM_Barrier0Op : NVVM_Op<"barrier0"> {
let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
}
def NVVM_ShflBflyOp :
NVVM_Op<"shfl.sync.bfly">,
def NVVM_ShflDownOp :
NVVM_Op<"shfl.sync.down">,
Results<(outs LLVM_Type:$res)>,
Arguments<(ins LLVM_Type:$dst,
LLVM_Type:$val,
@ -99,12 +99,12 @@ def NVVM_ShflBflyOp :
LLVM_Type:$mask_and_clamp,
OptionalAttr<UnitAttr>:$return_value_and_is_valid)> {
string llvmBuilder = [{
auto intId = getShflBflyIntrinsicId(
auto intId = getShflDownIntrinsicId(
$_resultType, static_cast<bool>($return_value_and_is_valid));
$res = createIntrinsicCall(builder,
intId, {$dst, $val, $offset, $mask_and_clamp});
}];
let parser = [{ return parseNVVMShflSyncBflyOp(parser, result); }];
let parser = [{ return parseNVVMShflSyncDownOp(parser, result); }];
let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
let verifier = [{
if (!getAttrOfType<UnitAttr>("return_value_and_is_valid"))

View File

@ -337,7 +337,7 @@ private:
for (int i = 1; i < kWarpSize; i <<= 1) {
Value *offset = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(i));
Value *shfl = rewriter.create<NVVM::ShflBflyOp>(
Value *shfl = rewriter.create<NVVM::ShflDownOp>(
loc, shflTy, activeMask, value, offset, maskAndClamp,
returnValueAndIsValidAttr);
Value *isActiveSrcLane = rewriter.create<LLVM::ExtractValueOp>(
@ -366,7 +366,7 @@ private:
for (int i = 1; i < kWarpSize; i <<= 1) {
Value *offset = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(i));
Value *shflValue = rewriter.create<NVVM::ShflBflyOp>(
Value *shflValue = rewriter.create<NVVM::ShflDownOp>(
loc, type, activeMask, value, offset, maskAndClamp,
/*return_value_and_is_valid=*/UnitAttr());
value = accumFactory(loc, value, shflValue, rewriter);

View File

@ -70,9 +70,9 @@ static LLVM::LLVMDialect *getLlvmDialect(OpAsmParser &parser) {
}
// <operation> ::=
// `llvm.nvvm.shfl.sync.bfly %dst, %val, %offset, %clamp_and_mask`
// `llvm.nvvm.shfl.sync.down %dst, %val, %offset, %clamp_and_mask`
// ({return_value_and_is_valid})? : result_type
static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser &parser,
static ParseResult parseNVVMShflSyncDownOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::OperandType, 8> ops;
Type resultType;

View File

@ -44,15 +44,15 @@ static llvm::Value *createIntrinsicCall(llvm::IRBuilder<> &builder,
return builder.CreateCall(fn, args);
}
static llvm::Intrinsic::ID getShflBflyIntrinsicId(llvm::Type *resultType,
static llvm::Intrinsic::ID getShflDownIntrinsicId(llvm::Type *resultType,
bool withPredicate) {
if (withPredicate) {
resultType = cast<llvm::StructType>(resultType)->getElementType(0);
return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
: llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p
: llvm::Intrinsic::nvvm_shfl_sync_down_i32p;
}
return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
: llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
: llvm::Intrinsic::nvvm_shfl_sync_down_i32;
}
class ModuleTranslation : public LLVM::ModuleTranslation {

View File

@ -44,7 +44,7 @@ module attributes {gpu.kernel_module} {
attributes { gpu.kernel } {
%arg0 = constant 1.0 : f32
// TODO(csigg): Check full IR expansion once lowering has settled.
// CHECK: nvvm.shfl.sync.bfly
// CHECK: nvvm.shfl.sync.down
// CHECK: nvvm.barrier0
// CHECK: llvm.fadd
%result = "gpu.all_reduce"(%arg0) ({}) {op = "add"} : (f32) -> (f32)
@ -61,7 +61,7 @@ module attributes {gpu.kernel_module} {
attributes { gpu.kernel } {
%arg0 = constant 1 : i32
// TODO(csigg): Check full IR expansion once lowering has settled.
// CHECK: nvvm.shfl.sync.bfly
// CHECK: nvvm.shfl.sync.down
// CHECK: nvvm.barrier0
%result = "gpu.all_reduce"(%arg0) ({
^bb(%lhs : i32, %rhs : i32):

View File

@ -268,7 +268,7 @@ func @null_non_llvm_type() {
// CHECK-LABEL: @nvvm_invalid_shfl_pred_1
func @nvvm_invalid_shfl_pred_1(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32, %arg3 : !llvm.i32) {
// expected-error@+1 {{expected return type !llvm<"{ ?, i1 }">}}
%0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm.i32
%0 = nvvm.shfl.sync.down %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm.i32
}
// -----
@ -276,7 +276,7 @@ func @nvvm_invalid_shfl_pred_1(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !ll
// CHECK-LABEL: @nvvm_invalid_shfl_pred_2
func @nvvm_invalid_shfl_pred_2(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32, %arg3 : !llvm.i32) {
// expected-error@+1 {{expected return type !llvm<"{ ?, i1 }">}}
%0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32 }">
%0 = nvvm.shfl.sync.down %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32 }">
}
// -----
@ -284,7 +284,7 @@ func @nvvm_invalid_shfl_pred_2(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !ll
// CHECK-LABEL: @nvvm_invalid_shfl_pred_3
func @nvvm_invalid_shfl_pred_3(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32, %arg3 : !llvm.i32) {
// expected-error@+1 {{expected return type !llvm<"{ ?, i1 }">}}
%0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32, i32 }">
%0 = nvvm.shfl.sync.down %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32, i32 }">
}
// -----

View File

@ -37,20 +37,20 @@ func @llvm.nvvm.barrier0() {
func @nvvm_shfl(
%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32,
%arg3 : !llvm.i32, %arg4 : !llvm.float) -> !llvm.i32 {
// CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i32
%0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 : !llvm.i32
// CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.float
%1 = nvvm.shfl.sync.bfly %arg0, %arg4, %arg1, %arg2 : !llvm.float
// CHECK: nvvm.shfl.sync.down %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i32
%0 = nvvm.shfl.sync.down %arg0, %arg3, %arg1, %arg2 : !llvm.i32
// CHECK: nvvm.shfl.sync.down %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.float
%1 = nvvm.shfl.sync.down %arg0, %arg4, %arg1, %arg2 : !llvm.float
llvm.return %0 : !llvm.i32
}
func @nvvm_shfl_pred(
%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32,
%arg3 : !llvm.i32, %arg4 : !llvm.float) -> !llvm<"{ i32, i1 }"> {
// CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm<"{ i32, i1 }">
%0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32, i1 }">
// CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm<"{ float, i1 }">
%1 = nvvm.shfl.sync.bfly %arg0, %arg4, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ float, i1 }">
// CHECK: nvvm.shfl.sync.down %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm<"{ i32, i1 }">
%0 = nvvm.shfl.sync.down %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32, i1 }">
// CHECK: nvvm.shfl.sync.down %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm<"{ float, i1 }">
%1 = nvvm.shfl.sync.down %arg0, %arg4, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ float, i1 }">
llvm.return %0 : !llvm<"{ i32, i1 }">
}

View File

@ -41,20 +41,20 @@ llvm.func @llvm.nvvm.barrier0() {
llvm.func @nvvm_shfl(
%0 : !llvm.i32, %1 : !llvm.i32, %2 : !llvm.i32,
%3 : !llvm.i32, %4 : !llvm.float) -> !llvm.i32 {
// CHECK: call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%6 = nvvm.shfl.sync.bfly %0, %3, %1, %2 : !llvm.i32
// CHECK: call float @llvm.nvvm.shfl.sync.bfly.f32(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%7 = nvvm.shfl.sync.bfly %0, %4, %1, %2 : !llvm.float
// CHECK: call i32 @llvm.nvvm.shfl.sync.down.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%6 = nvvm.shfl.sync.down %0, %3, %1, %2 : !llvm.i32
// CHECK: call float @llvm.nvvm.shfl.sync.down.f32(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%7 = nvvm.shfl.sync.down %0, %4, %1, %2 : !llvm.float
llvm.return %6 : !llvm.i32
}
llvm.func @nvvm_shfl_pred(
%0 : !llvm.i32, %1 : !llvm.i32, %2 : !llvm.i32,
%3 : !llvm.i32, %4 : !llvm.float) -> !llvm<"{ i32, i1 }"> {
// CHECK: call { i32, i1 } @llvm.nvvm.shfl.sync.bfly.i32p(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%6 = nvvm.shfl.sync.bfly %0, %3, %1, %2 {return_value_and_is_valid} : !llvm<"{ i32, i1 }">
// CHECK: call { float, i1 } @llvm.nvvm.shfl.sync.bfly.f32p(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%7 = nvvm.shfl.sync.bfly %0, %4, %1, %2 {return_value_and_is_valid} : !llvm<"{ float, i1 }">
// CHECK: call { i32, i1 } @llvm.nvvm.shfl.sync.down.i32p(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%6 = nvvm.shfl.sync.down %0, %3, %1, %2 {return_value_and_is_valid} : !llvm<"{ i32, i1 }">
// CHECK: call { float, i1 } @llvm.nvvm.shfl.sync.down.f32p(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%7 = nvvm.shfl.sync.down %0, %4, %1, %2 {return_value_and_is_valid} : !llvm<"{ float, i1 }">
llvm.return %6 : !llvm<"{ i32, i1 }">
}