forked from OSchip/llvm-project
Switch from shfl.bfly to shfl.down.
Both work for the current use case, but the latter allows implementing prefix sums and is a little easier to understand for partial warps. PiperOrigin-RevId: 285145287
This commit is contained in:
parent
851a8516d3
commit
f68ac464d8
|
@ -90,8 +90,8 @@ def NVVM_Barrier0Op : NVVM_Op<"barrier0"> {
|
|||
let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
|
||||
}
|
||||
|
||||
def NVVM_ShflBflyOp :
|
||||
NVVM_Op<"shfl.sync.bfly">,
|
||||
def NVVM_ShflDownOp :
|
||||
NVVM_Op<"shfl.sync.down">,
|
||||
Results<(outs LLVM_Type:$res)>,
|
||||
Arguments<(ins LLVM_Type:$dst,
|
||||
LLVM_Type:$val,
|
||||
|
@ -99,12 +99,12 @@ def NVVM_ShflBflyOp :
|
|||
LLVM_Type:$mask_and_clamp,
|
||||
OptionalAttr<UnitAttr>:$return_value_and_is_valid)> {
|
||||
string llvmBuilder = [{
|
||||
auto intId = getShflBflyIntrinsicId(
|
||||
auto intId = getShflDownIntrinsicId(
|
||||
$_resultType, static_cast<bool>($return_value_and_is_valid));
|
||||
$res = createIntrinsicCall(builder,
|
||||
intId, {$dst, $val, $offset, $mask_and_clamp});
|
||||
}];
|
||||
let parser = [{ return parseNVVMShflSyncBflyOp(parser, result); }];
|
||||
let parser = [{ return parseNVVMShflSyncDownOp(parser, result); }];
|
||||
let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
|
||||
let verifier = [{
|
||||
if (!getAttrOfType<UnitAttr>("return_value_and_is_valid"))
|
||||
|
|
|
@ -337,7 +337,7 @@ private:
|
|||
for (int i = 1; i < kWarpSize; i <<= 1) {
|
||||
Value *offset = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int32Type, rewriter.getI32IntegerAttr(i));
|
||||
Value *shfl = rewriter.create<NVVM::ShflBflyOp>(
|
||||
Value *shfl = rewriter.create<NVVM::ShflDownOp>(
|
||||
loc, shflTy, activeMask, value, offset, maskAndClamp,
|
||||
returnValueAndIsValidAttr);
|
||||
Value *isActiveSrcLane = rewriter.create<LLVM::ExtractValueOp>(
|
||||
|
@ -366,7 +366,7 @@ private:
|
|||
for (int i = 1; i < kWarpSize; i <<= 1) {
|
||||
Value *offset = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int32Type, rewriter.getI32IntegerAttr(i));
|
||||
Value *shflValue = rewriter.create<NVVM::ShflBflyOp>(
|
||||
Value *shflValue = rewriter.create<NVVM::ShflDownOp>(
|
||||
loc, type, activeMask, value, offset, maskAndClamp,
|
||||
/*return_value_and_is_valid=*/UnitAttr());
|
||||
value = accumFactory(loc, value, shflValue, rewriter);
|
||||
|
|
|
@ -70,9 +70,9 @@ static LLVM::LLVMDialect *getLlvmDialect(OpAsmParser &parser) {
|
|||
}
|
||||
|
||||
// <operation> ::=
|
||||
// `llvm.nvvm.shfl.sync.bfly %dst, %val, %offset, %clamp_and_mask`
|
||||
// `llvm.nvvm.shfl.sync.down %dst, %val, %offset, %clamp_and_mask`
|
||||
// ({return_value_and_is_valid})? : result_type
|
||||
static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser &parser,
|
||||
static ParseResult parseNVVMShflSyncDownOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
SmallVector<OpAsmParser::OperandType, 8> ops;
|
||||
Type resultType;
|
||||
|
|
|
@ -44,15 +44,15 @@ static llvm::Value *createIntrinsicCall(llvm::IRBuilder<> &builder,
|
|||
return builder.CreateCall(fn, args);
|
||||
}
|
||||
|
||||
static llvm::Intrinsic::ID getShflBflyIntrinsicId(llvm::Type *resultType,
|
||||
static llvm::Intrinsic::ID getShflDownIntrinsicId(llvm::Type *resultType,
|
||||
bool withPredicate) {
|
||||
if (withPredicate) {
|
||||
resultType = cast<llvm::StructType>(resultType)->getElementType(0);
|
||||
return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
|
||||
: llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
|
||||
return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p
|
||||
: llvm::Intrinsic::nvvm_shfl_sync_down_i32p;
|
||||
}
|
||||
return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
|
||||
: llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
|
||||
return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
|
||||
: llvm::Intrinsic::nvvm_shfl_sync_down_i32;
|
||||
}
|
||||
|
||||
class ModuleTranslation : public LLVM::ModuleTranslation {
|
||||
|
|
|
@ -44,7 +44,7 @@ module attributes {gpu.kernel_module} {
|
|||
attributes { gpu.kernel } {
|
||||
%arg0 = constant 1.0 : f32
|
||||
// TODO(csigg): Check full IR expansion once lowering has settled.
|
||||
// CHECK: nvvm.shfl.sync.bfly
|
||||
// CHECK: nvvm.shfl.sync.down
|
||||
// CHECK: nvvm.barrier0
|
||||
// CHECK: llvm.fadd
|
||||
%result = "gpu.all_reduce"(%arg0) ({}) {op = "add"} : (f32) -> (f32)
|
||||
|
@ -61,7 +61,7 @@ module attributes {gpu.kernel_module} {
|
|||
attributes { gpu.kernel } {
|
||||
%arg0 = constant 1 : i32
|
||||
// TODO(csigg): Check full IR expansion once lowering has settled.
|
||||
// CHECK: nvvm.shfl.sync.bfly
|
||||
// CHECK: nvvm.shfl.sync.down
|
||||
// CHECK: nvvm.barrier0
|
||||
%result = "gpu.all_reduce"(%arg0) ({
|
||||
^bb(%lhs : i32, %rhs : i32):
|
||||
|
|
|
@ -268,7 +268,7 @@ func @null_non_llvm_type() {
|
|||
// CHECK-LABEL: @nvvm_invalid_shfl_pred_1
|
||||
func @nvvm_invalid_shfl_pred_1(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32, %arg3 : !llvm.i32) {
|
||||
// expected-error@+1 {{expected return type !llvm<"{ ?, i1 }">}}
|
||||
%0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm.i32
|
||||
%0 = nvvm.shfl.sync.down %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm.i32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -276,7 +276,7 @@ func @nvvm_invalid_shfl_pred_1(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !ll
|
|||
// CHECK-LABEL: @nvvm_invalid_shfl_pred_2
|
||||
func @nvvm_invalid_shfl_pred_2(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32, %arg3 : !llvm.i32) {
|
||||
// expected-error@+1 {{expected return type !llvm<"{ ?, i1 }">}}
|
||||
%0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32 }">
|
||||
%0 = nvvm.shfl.sync.down %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32 }">
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -284,7 +284,7 @@ func @nvvm_invalid_shfl_pred_2(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !ll
|
|||
// CHECK-LABEL: @nvvm_invalid_shfl_pred_3
|
||||
func @nvvm_invalid_shfl_pred_3(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32, %arg3 : !llvm.i32) {
|
||||
// expected-error@+1 {{expected return type !llvm<"{ ?, i1 }">}}
|
||||
%0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32, i32 }">
|
||||
%0 = nvvm.shfl.sync.down %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32, i32 }">
|
||||
}
|
||||
|
||||
// -----
|
||||
|
|
|
@ -37,20 +37,20 @@ func @llvm.nvvm.barrier0() {
|
|||
func @nvvm_shfl(
|
||||
%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32,
|
||||
%arg3 : !llvm.i32, %arg4 : !llvm.float) -> !llvm.i32 {
|
||||
// CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i32
|
||||
%0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 : !llvm.i32
|
||||
// CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.float
|
||||
%1 = nvvm.shfl.sync.bfly %arg0, %arg4, %arg1, %arg2 : !llvm.float
|
||||
// CHECK: nvvm.shfl.sync.down %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i32
|
||||
%0 = nvvm.shfl.sync.down %arg0, %arg3, %arg1, %arg2 : !llvm.i32
|
||||
// CHECK: nvvm.shfl.sync.down %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.float
|
||||
%1 = nvvm.shfl.sync.down %arg0, %arg4, %arg1, %arg2 : !llvm.float
|
||||
llvm.return %0 : !llvm.i32
|
||||
}
|
||||
|
||||
func @nvvm_shfl_pred(
|
||||
%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32,
|
||||
%arg3 : !llvm.i32, %arg4 : !llvm.float) -> !llvm<"{ i32, i1 }"> {
|
||||
// CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm<"{ i32, i1 }">
|
||||
%0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32, i1 }">
|
||||
// CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm<"{ float, i1 }">
|
||||
%1 = nvvm.shfl.sync.bfly %arg0, %arg4, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ float, i1 }">
|
||||
// CHECK: nvvm.shfl.sync.down %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm<"{ i32, i1 }">
|
||||
%0 = nvvm.shfl.sync.down %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32, i1 }">
|
||||
// CHECK: nvvm.shfl.sync.down %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm<"{ float, i1 }">
|
||||
%1 = nvvm.shfl.sync.down %arg0, %arg4, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ float, i1 }">
|
||||
llvm.return %0 : !llvm<"{ i32, i1 }">
|
||||
}
|
||||
|
||||
|
|
|
@ -41,20 +41,20 @@ llvm.func @llvm.nvvm.barrier0() {
|
|||
llvm.func @nvvm_shfl(
|
||||
%0 : !llvm.i32, %1 : !llvm.i32, %2 : !llvm.i32,
|
||||
%3 : !llvm.i32, %4 : !llvm.float) -> !llvm.i32 {
|
||||
// CHECK: call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
|
||||
%6 = nvvm.shfl.sync.bfly %0, %3, %1, %2 : !llvm.i32
|
||||
// CHECK: call float @llvm.nvvm.shfl.sync.bfly.f32(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
|
||||
%7 = nvvm.shfl.sync.bfly %0, %4, %1, %2 : !llvm.float
|
||||
// CHECK: call i32 @llvm.nvvm.shfl.sync.down.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
|
||||
%6 = nvvm.shfl.sync.down %0, %3, %1, %2 : !llvm.i32
|
||||
// CHECK: call float @llvm.nvvm.shfl.sync.down.f32(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
|
||||
%7 = nvvm.shfl.sync.down %0, %4, %1, %2 : !llvm.float
|
||||
llvm.return %6 : !llvm.i32
|
||||
}
|
||||
|
||||
llvm.func @nvvm_shfl_pred(
|
||||
%0 : !llvm.i32, %1 : !llvm.i32, %2 : !llvm.i32,
|
||||
%3 : !llvm.i32, %4 : !llvm.float) -> !llvm<"{ i32, i1 }"> {
|
||||
// CHECK: call { i32, i1 } @llvm.nvvm.shfl.sync.bfly.i32p(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
|
||||
%6 = nvvm.shfl.sync.bfly %0, %3, %1, %2 {return_value_and_is_valid} : !llvm<"{ i32, i1 }">
|
||||
// CHECK: call { float, i1 } @llvm.nvvm.shfl.sync.bfly.f32p(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
|
||||
%7 = nvvm.shfl.sync.bfly %0, %4, %1, %2 {return_value_and_is_valid} : !llvm<"{ float, i1 }">
|
||||
// CHECK: call { i32, i1 } @llvm.nvvm.shfl.sync.down.i32p(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
|
||||
%6 = nvvm.shfl.sync.down %0, %3, %1, %2 {return_value_and_is_valid} : !llvm<"{ i32, i1 }">
|
||||
// CHECK: call { float, i1 } @llvm.nvvm.shfl.sync.down.f32p(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
|
||||
%7 = nvvm.shfl.sync.down %0, %4, %1, %2 {return_value_and_is_valid} : !llvm<"{ float, i1 }">
|
||||
llvm.return %6 : !llvm<"{ i32, i1 }">
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue