From c3e56cd12cf6d4ab3223d402370dc9236acd0f1b Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Sat, 19 Oct 2019 01:52:51 -0700 Subject: [PATCH] Get active source lane predicate from shuffle instruction. nvvm.shfl.sync.bfly optionally returns a predicate whether source lane was active. Support for this was added to clang in https://reviews.llvm.org/D68892. Add an optional 'pred' unit attribute to the instruction to return this predicate. Specify this attribute in the partial warp reduction so we don't need to manually compute the predicate. PiperOrigin-RevId: 275616564 --- .../include/mlir/Dialect/LLVMIR/LLVMDialect.h | 5 +++ mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 18 +++++++--- .../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 36 ++++++++++--------- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 3 ++ mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 32 +++++++++++------ mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp | 11 ++++++ mlir/test/Dialect/LLVMIR/invalid.mlir | 23 ++++++++++++ mlir/test/Dialect/LLVMIR/nvvm.mlir | 10 ++++++ mlir/test/Target/nvvmir.mlir | 10 ++++++ 9 files changed, 116 insertions(+), 32 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h index e6810168bc57..9b8df748edf5 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -66,6 +66,10 @@ public: /// Utilities to identify types. bool isFloatTy() { return getUnderlyingType()->isFloatTy(); } + bool isIntegerTy() { return getUnderlyingType()->isIntegerTy(); } + bool isIntegerTy(unsigned bitwidth) { + return getUnderlyingType()->isIntegerTy(bitwidth); + } /// Array type utilities. LLVMType getArrayElementType(); @@ -89,6 +93,7 @@ public: /// Struct type utilities. LLVMType getStructElementType(unsigned i); + unsigned getStructNumElements(); bool isStructTy(); /// Utilities used to generate floating point types. diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 7f011cd9d6f9..d952089e7173 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -97,16 +97,26 @@ def NVVM_ShflBflyOp : Arguments<(ins LLVM_Type:$dst, LLVM_Type:$val, LLVM_Type:$offset, - LLVM_Type:$mask_and_clamp)> { + LLVM_Type:$mask_and_clamp, + OptionalAttr:$return_value_and_is_valid)> { string llvmBuilder = [{ - auto intId = $val->getType()->isFloatTy() ? - llvm::Intrinsic::nvvm_shfl_sync_bfly_f32 : - llvm::Intrinsic::nvvm_shfl_sync_bfly_i32; + auto intId = getShflBflyIntrinsicId( + $_resultType, static_cast($return_value_and_is_valid)); $res = createIntrinsicCall(builder, intId, {$dst, $val, $offset, $mask_and_clamp}); }]; let parser = [{ return parseNVVMShflSyncBflyOp(parser, result); }]; let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }]; + let verifier = [{ + if (!getAttrOfType("return_value_and_is_valid")) + return success(); + auto type = getType().cast(); + if (!type.isStructTy() || type.getStructNumElements() != 2 || + !type.getStructElementType(1).isIntegerTy( + /*Bitwidth=*/1)) + return emitError("expected return type !llvm<\"{ ?, i1 }\">"); + return success(); + }]; } def NVVM_VoteBallotOp : diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index a1442d0c6467..462457ccca80 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -309,7 +309,7 @@ private: loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1)); Value *isPartialWarp = rewriter.create( loc, LLVM::ICmpPredicate::slt, activeWidth, warpSize); - auto type = operand->getType(); + auto type = operand->getType().cast(); createIf( loc, rewriter, isPartialWarp, @@ -323,30 +323,31 @@ private: loc, int32Type, rewriter.create(loc, int32Type, one, activeWidth), one); - // Bound of offsets which read from a lane within the active range. - Value *offsetBound = - rewriter.create(loc, activeWidth, laneId); + auto dialect = lowering.getDialect(); + auto predTy = LLVM::LLVMType::getInt1Ty(dialect); + auto shflTy = LLVM::LLVMType::getStructTy(dialect, {type, predTy}); + auto returnValueAndIsValidAttr = rewriter.getUnitAttr(); - // Repeatedly shuffle value from 'laneId + i' and accumulate if source - // lane is within the active range. The first lane contains the final - // result, all other lanes contain some undefined partial result. + // Repeatedly shuffle value from 'laneId ^ i' and accumulate if source + // lane is within the active range. All lanes contain the final + // result, but only the first lane's result is used. for (int i = 1; i < kWarpSize; i <<= 1) { Value *offset = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(i)); - // ShflDownOp instead of ShflBflyOp would produce a scan. ShflBflyOp - // also produces the correct reduction on lane 0 though. Value *shfl = rewriter.create( - loc, type, activeMask, value, offset, maskAndClamp); - // TODO(csigg): use the second result from the shuffle op instead. - Value *isActiveSrcLane = rewriter.create( - loc, LLVM::ICmpPredicate::slt, offset, offsetBound); + loc, shflTy, activeMask, value, offset, maskAndClamp, + returnValueAndIsValidAttr); + Value *isActiveSrcLane = rewriter.create( + loc, predTy, shfl, rewriter.getIndexArrayAttr(1)); // Skip the accumulation if the shuffle op read from a lane outside // of the active range. createIf( loc, rewriter, isActiveSrcLane, [&] { + Value *shflValue = rewriter.create( + loc, type, shfl, rewriter.getIndexArrayAttr(0)); return llvm::SmallVector{ - accumFactory(loc, value, shfl, rewriter)}; + accumFactory(loc, value, shflValue, rewriter)}; }, [&] { return llvm::makeArrayRef(value); }); value = rewriter.getInsertionBlock()->getArgument(0); @@ -362,9 +363,10 @@ private: for (int i = 1; i < kWarpSize; i <<= 1) { Value *offset = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(i)); - Value *shfl = rewriter.create( - loc, type, activeMask, value, offset, maskAndClamp); - value = accumFactory(loc, value, shfl, rewriter); + Value *shflValue = rewriter.create( + loc, type, activeMask, value, offset, maskAndClamp, + /*return_value_and_is_valid=*/UnitAttr()); + value = accumFactory(loc, value, shflValue, rewriter); } return llvm::SmallVector{value}; }); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 618ee231f9e6..ad1fa7705ebe 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1373,6 +1373,9 @@ bool LLVMType::isPointerTy() { return getUnderlyingType()->isPointerTy(); } LLVMType LLVMType::getStructElementType(unsigned i) { return get(getContext(), getUnderlyingType()->getStructElementType(i)); } +unsigned LLVMType::getStructNumElements() { + return getUnderlyingType()->getStructNumElements(); +} bool LLVMType::isStructTy() { return getUnderlyingType()->isStructTy(); } /// Utilities used to generate floating point types. diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index e45e37fada36..026b9757839c 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -29,6 +29,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/StandardTypes.h" #include "llvm/AsmParser/Parser.h" #include "llvm/IR/Attributes.h" @@ -70,20 +71,29 @@ static LLVM::LLVMDialect *getLlvmDialect(OpAsmParser &parser) { // ::= // `llvm.nvvm.shfl.sync.bfly %dst, %val, %offset, %clamp_and_mask` -// : result_type +// ({return_value_and_is_valid})? : result_type static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser &parser, OperationState &result) { - auto llvmDialect = getLlvmDialect(parser); - auto int32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect); - SmallVector ops; - Type type; - return failure(parser.parseOperandList(ops) || - parser.parseOptionalAttributeDict(result.attributes) || - parser.parseColonType(type) || - parser.addTypeToList(type, result.types) || - parser.resolveOperands(ops, {int32Ty, type, int32Ty, int32Ty}, - parser.getNameLoc(), result.operands)); + Type resultType; + if (parser.parseOperandList(ops) || + parser.parseOptionalAttributeDict(result.attributes) || + parser.parseColonType(resultType) || + parser.addTypeToList(resultType, result.types)) + return failure(); + + auto type = resultType.cast(); + for (auto &attr : result.attributes) { + if (attr.first != "return_value_and_is_valid") + continue; + if (type.isStructTy() && type.getStructNumElements() > 0) + type = type.getStructElementType(0); + break; + } + + auto int32Ty = LLVM::LLVMType::getInt32Ty(getLlvmDialect(parser)); + return parser.resolveOperands(ops, {int32Ty, type, int32Ty, int32Ty}, + parser.getNameLoc(), result.operands); } // ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type diff --git a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp index 13043d781059..606e91b955f5 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp @@ -44,6 +44,17 @@ static llvm::Value *createIntrinsicCall(llvm::IRBuilder<> &builder, return builder.CreateCall(fn, args); } +static llvm::Intrinsic::ID getShflBflyIntrinsicId(llvm::Type *resultType, + bool withPredicate) { + if (withPredicate) { + resultType = cast(resultType)->getElementType(0); + return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p + : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p; + } + return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32 + : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32; +} + class ModuleTranslation : public LLVM::ModuleTranslation { public: diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 65d802302168..c5f3895bf059 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -263,3 +263,26 @@ func @null_non_llvm_type() { llvm.mlir.null : !llvm.i32 } +// ----- + +// CHECK-LABEL: @nvvm_invalid_shfl_pred_1 +func @nvvm_invalid_shfl_pred_1(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32, %arg3 : !llvm.i32) { + // expected-error@+1 {{expected return type !llvm<"{ ?, i1 }">}} + %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm.i32 +} + +// ----- + +// CHECK-LABEL: @nvvm_invalid_shfl_pred_2 +func @nvvm_invalid_shfl_pred_2(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32, %arg3 : !llvm.i32) { + // expected-error@+1 {{expected return type !llvm<"{ ?, i1 }">}} + %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32 }"> +} + +// ----- + +// CHECK-LABEL: @nvvm_invalid_shfl_pred_3 +func @nvvm_invalid_shfl_pred_3(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32, %arg3 : !llvm.i32) { + // expected-error@+1 {{expected return type !llvm<"{ ?, i1 }">}} + %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32, i32 }"> +} diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir index 8ca439d4df4f..3e2d4e5bdac6 100644 --- a/mlir/test/Dialect/LLVMIR/nvvm.mlir +++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir @@ -44,6 +44,16 @@ func @nvvm_shfl( llvm.return %0 : !llvm.i32 } +func @nvvm_shfl_pred( + %arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32, + %arg3 : !llvm.i32, %arg4 : !llvm.float) -> !llvm<"{ i32, i1 }"> { + // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm<"{ i32, i1 }"> + %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32, i1 }"> + // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm<"{ float, i1 }"> + %1 = nvvm.shfl.sync.bfly %arg0, %arg4, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ float, i1 }"> + llvm.return %0 : !llvm<"{ i32, i1 }"> +} + func @nvvm_vote(%arg0 : !llvm.i32, %arg1 : !llvm.i1) -> !llvm.i32 { // CHECK: nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : !llvm.i32 %0 = nvvm.vote.ballot.sync %arg0, %arg1 : !llvm.i32 diff --git a/mlir/test/Target/nvvmir.mlir b/mlir/test/Target/nvvmir.mlir index c09b414bce29..fea83e01fb46 100644 --- a/mlir/test/Target/nvvmir.mlir +++ b/mlir/test/Target/nvvmir.mlir @@ -48,6 +48,16 @@ llvm.func @nvvm_shfl( llvm.return %6 : !llvm.i32 } +llvm.func @nvvm_shfl_pred( + %0 : !llvm.i32, %1 : !llvm.i32, %2 : !llvm.i32, + %3 : !llvm.i32, %4 : !llvm.float) -> !llvm<"{ i32, i1 }"> { + // CHECK: call { i32, i1 } @llvm.nvvm.shfl.sync.bfly.i32p(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + %6 = nvvm.shfl.sync.bfly %0, %3, %1, %2 {return_value_and_is_valid} : !llvm<"{ i32, i1 }"> + // CHECK: call { float, i1 } @llvm.nvvm.shfl.sync.bfly.f32p(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + %7 = nvvm.shfl.sync.bfly %0, %4, %1, %2 {return_value_and_is_valid} : !llvm<"{ float, i1 }"> + llvm.return %6 : !llvm<"{ i32, i1 }"> +} + llvm.func @nvvm_vote(%0 : !llvm.i32, %1 : !llvm.i1) -> !llvm.i32 { // CHECK: call i32 @llvm.nvvm.vote.ballot.sync(i32 %{{.*}}, i1 %{{.*}}) %3 = nvvm.vote.ballot.sync %0, %1 : !llvm.i32