forked from OSchip/llvm-project
Get active source lane predicate from shuffle instruction.
nvvm.shfl.sync.bfly optionally returns a predicate whether source lane was active. Support for this was added to clang in https://reviews.llvm.org/D68892. Add an optional 'pred' unit attribute to the instruction to return this predicate. Specify this attribute in the partial warp reduction so we don't need to manually compute the predicate. PiperOrigin-RevId: 275616564
This commit is contained in:
parent
5f6bdd144a
commit
c3e56cd12c
|
@ -66,6 +66,10 @@ public:
|
|||
|
||||
/// Utilities to identify types.
|
||||
bool isFloatTy() { return getUnderlyingType()->isFloatTy(); }
|
||||
bool isIntegerTy() { return getUnderlyingType()->isIntegerTy(); }
|
||||
bool isIntegerTy(unsigned bitwidth) {
|
||||
return getUnderlyingType()->isIntegerTy(bitwidth);
|
||||
}
|
||||
|
||||
/// Array type utilities.
|
||||
LLVMType getArrayElementType();
|
||||
|
@ -89,6 +93,7 @@ public:
|
|||
|
||||
/// Struct type utilities.
|
||||
LLVMType getStructElementType(unsigned i);
|
||||
unsigned getStructNumElements();
|
||||
bool isStructTy();
|
||||
|
||||
/// Utilities used to generate floating point types.
|
||||
|
|
|
@ -97,16 +97,26 @@ def NVVM_ShflBflyOp :
|
|||
Arguments<(ins LLVM_Type:$dst,
|
||||
LLVM_Type:$val,
|
||||
LLVM_Type:$offset,
|
||||
LLVM_Type:$mask_and_clamp)> {
|
||||
LLVM_Type:$mask_and_clamp,
|
||||
OptionalAttr<UnitAttr>:$return_value_and_is_valid)> {
|
||||
string llvmBuilder = [{
|
||||
auto intId = $val->getType()->isFloatTy() ?
|
||||
llvm::Intrinsic::nvvm_shfl_sync_bfly_f32 :
|
||||
llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
|
||||
auto intId = getShflBflyIntrinsicId(
|
||||
$_resultType, static_cast<bool>($return_value_and_is_valid));
|
||||
$res = createIntrinsicCall(builder,
|
||||
intId, {$dst, $val, $offset, $mask_and_clamp});
|
||||
}];
|
||||
let parser = [{ return parseNVVMShflSyncBflyOp(parser, result); }];
|
||||
let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
|
||||
let verifier = [{
|
||||
if (!getAttrOfType<UnitAttr>("return_value_and_is_valid"))
|
||||
return success();
|
||||
auto type = getType().cast<LLVM::LLVMType>();
|
||||
if (!type.isStructTy() || type.getStructNumElements() != 2 ||
|
||||
!type.getStructElementType(1).isIntegerTy(
|
||||
/*Bitwidth=*/1))
|
||||
return emitError("expected return type !llvm<\"{ ?, i1 }\">");
|
||||
return success();
|
||||
}];
|
||||
}
|
||||
|
||||
def NVVM_VoteBallotOp :
|
||||
|
|
|
@ -309,7 +309,7 @@ private:
|
|||
loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
|
||||
Value *isPartialWarp = rewriter.create<LLVM::ICmpOp>(
|
||||
loc, LLVM::ICmpPredicate::slt, activeWidth, warpSize);
|
||||
auto type = operand->getType();
|
||||
auto type = operand->getType().cast<LLVM::LLVMType>();
|
||||
|
||||
createIf(
|
||||
loc, rewriter, isPartialWarp,
|
||||
|
@ -323,30 +323,31 @@ private:
|
|||
loc, int32Type,
|
||||
rewriter.create<LLVM::ShlOp>(loc, int32Type, one, activeWidth),
|
||||
one);
|
||||
// Bound of offsets which read from a lane within the active range.
|
||||
Value *offsetBound =
|
||||
rewriter.create<LLVM::SubOp>(loc, activeWidth, laneId);
|
||||
auto dialect = lowering.getDialect();
|
||||
auto predTy = LLVM::LLVMType::getInt1Ty(dialect);
|
||||
auto shflTy = LLVM::LLVMType::getStructTy(dialect, {type, predTy});
|
||||
auto returnValueAndIsValidAttr = rewriter.getUnitAttr();
|
||||
|
||||
// Repeatedly shuffle value from 'laneId + i' and accumulate if source
|
||||
// lane is within the active range. The first lane contains the final
|
||||
// result, all other lanes contain some undefined partial result.
|
||||
// Repeatedly shuffle value from 'laneId ^ i' and accumulate if source
|
||||
// lane is within the active range. All lanes contain the final
|
||||
// result, but only the first lane's result is used.
|
||||
for (int i = 1; i < kWarpSize; i <<= 1) {
|
||||
Value *offset = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int32Type, rewriter.getI32IntegerAttr(i));
|
||||
// ShflDownOp instead of ShflBflyOp would produce a scan. ShflBflyOp
|
||||
// also produces the correct reduction on lane 0 though.
|
||||
Value *shfl = rewriter.create<NVVM::ShflBflyOp>(
|
||||
loc, type, activeMask, value, offset, maskAndClamp);
|
||||
// TODO(csigg): use the second result from the shuffle op instead.
|
||||
Value *isActiveSrcLane = rewriter.create<LLVM::ICmpOp>(
|
||||
loc, LLVM::ICmpPredicate::slt, offset, offsetBound);
|
||||
loc, shflTy, activeMask, value, offset, maskAndClamp,
|
||||
returnValueAndIsValidAttr);
|
||||
Value *isActiveSrcLane = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, predTy, shfl, rewriter.getIndexArrayAttr(1));
|
||||
// Skip the accumulation if the shuffle op read from a lane outside
|
||||
// of the active range.
|
||||
createIf(
|
||||
loc, rewriter, isActiveSrcLane,
|
||||
[&] {
|
||||
Value *shflValue = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, type, shfl, rewriter.getIndexArrayAttr(0));
|
||||
return llvm::SmallVector<Value *, 1>{
|
||||
accumFactory(loc, value, shfl, rewriter)};
|
||||
accumFactory(loc, value, shflValue, rewriter)};
|
||||
},
|
||||
[&] { return llvm::makeArrayRef(value); });
|
||||
value = rewriter.getInsertionBlock()->getArgument(0);
|
||||
|
@ -362,9 +363,10 @@ private:
|
|||
for (int i = 1; i < kWarpSize; i <<= 1) {
|
||||
Value *offset = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int32Type, rewriter.getI32IntegerAttr(i));
|
||||
Value *shfl = rewriter.create<NVVM::ShflBflyOp>(
|
||||
loc, type, activeMask, value, offset, maskAndClamp);
|
||||
value = accumFactory(loc, value, shfl, rewriter);
|
||||
Value *shflValue = rewriter.create<NVVM::ShflBflyOp>(
|
||||
loc, type, activeMask, value, offset, maskAndClamp,
|
||||
/*return_value_and_is_valid=*/UnitAttr());
|
||||
value = accumFactory(loc, value, shflValue, rewriter);
|
||||
}
|
||||
return llvm::SmallVector<Value *, 1>{value};
|
||||
});
|
||||
|
|
|
@ -1373,6 +1373,9 @@ bool LLVMType::isPointerTy() { return getUnderlyingType()->isPointerTy(); }
|
|||
LLVMType LLVMType::getStructElementType(unsigned i) {
|
||||
return get(getContext(), getUnderlyingType()->getStructElementType(i));
|
||||
}
|
||||
unsigned LLVMType::getStructNumElements() {
|
||||
return getUnderlyingType()->getStructNumElements();
|
||||
}
|
||||
bool LLVMType::isStructTy() { return getUnderlyingType()->isStructTy(); }
|
||||
|
||||
/// Utilities used to generate floating point types.
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "llvm/AsmParser/Parser.h"
|
||||
#include "llvm/IR/Attributes.h"
|
||||
|
@ -70,20 +71,29 @@ static LLVM::LLVMDialect *getLlvmDialect(OpAsmParser &parser) {
|
|||
|
||||
// <operation> ::=
|
||||
// `llvm.nvvm.shfl.sync.bfly %dst, %val, %offset, %clamp_and_mask`
|
||||
// : result_type
|
||||
// ({return_value_and_is_valid})? : result_type
|
||||
static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
auto llvmDialect = getLlvmDialect(parser);
|
||||
auto int32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
|
||||
|
||||
SmallVector<OpAsmParser::OperandType, 8> ops;
|
||||
Type type;
|
||||
return failure(parser.parseOperandList(ops) ||
|
||||
Type resultType;
|
||||
if (parser.parseOperandList(ops) ||
|
||||
parser.parseOptionalAttributeDict(result.attributes) ||
|
||||
parser.parseColonType(type) ||
|
||||
parser.addTypeToList(type, result.types) ||
|
||||
parser.resolveOperands(ops, {int32Ty, type, int32Ty, int32Ty},
|
||||
parser.getNameLoc(), result.operands));
|
||||
parser.parseColonType(resultType) ||
|
||||
parser.addTypeToList(resultType, result.types))
|
||||
return failure();
|
||||
|
||||
auto type = resultType.cast<LLVM::LLVMType>();
|
||||
for (auto &attr : result.attributes) {
|
||||
if (attr.first != "return_value_and_is_valid")
|
||||
continue;
|
||||
if (type.isStructTy() && type.getStructNumElements() > 0)
|
||||
type = type.getStructElementType(0);
|
||||
break;
|
||||
}
|
||||
|
||||
auto int32Ty = LLVM::LLVMType::getInt32Ty(getLlvmDialect(parser));
|
||||
return parser.resolveOperands(ops, {int32Ty, type, int32Ty, int32Ty},
|
||||
parser.getNameLoc(), result.operands);
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
|
||||
|
|
|
@ -44,6 +44,17 @@ static llvm::Value *createIntrinsicCall(llvm::IRBuilder<> &builder,
|
|||
return builder.CreateCall(fn, args);
|
||||
}
|
||||
|
||||
static llvm::Intrinsic::ID getShflBflyIntrinsicId(llvm::Type *resultType,
|
||||
bool withPredicate) {
|
||||
if (withPredicate) {
|
||||
resultType = cast<llvm::StructType>(resultType)->getElementType(0);
|
||||
return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
|
||||
: llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
|
||||
}
|
||||
return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
|
||||
: llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
|
||||
}
|
||||
|
||||
class ModuleTranslation : public LLVM::ModuleTranslation {
|
||||
|
||||
public:
|
||||
|
|
|
@ -263,3 +263,26 @@ func @null_non_llvm_type() {
|
|||
llvm.mlir.null : !llvm.i32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @nvvm_invalid_shfl_pred_1
|
||||
func @nvvm_invalid_shfl_pred_1(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32, %arg3 : !llvm.i32) {
|
||||
// expected-error@+1 {{expected return type !llvm<"{ ?, i1 }">}}
|
||||
%0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm.i32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @nvvm_invalid_shfl_pred_2
|
||||
func @nvvm_invalid_shfl_pred_2(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32, %arg3 : !llvm.i32) {
|
||||
// expected-error@+1 {{expected return type !llvm<"{ ?, i1 }">}}
|
||||
%0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32 }">
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @nvvm_invalid_shfl_pred_3
|
||||
func @nvvm_invalid_shfl_pred_3(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32, %arg3 : !llvm.i32) {
|
||||
// expected-error@+1 {{expected return type !llvm<"{ ?, i1 }">}}
|
||||
%0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32, i32 }">
|
||||
}
|
||||
|
|
|
@ -44,6 +44,16 @@ func @nvvm_shfl(
|
|||
llvm.return %0 : !llvm.i32
|
||||
}
|
||||
|
||||
func @nvvm_shfl_pred(
|
||||
%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32,
|
||||
%arg3 : !llvm.i32, %arg4 : !llvm.float) -> !llvm<"{ i32, i1 }"> {
|
||||
// CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm<"{ i32, i1 }">
|
||||
%0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32, i1 }">
|
||||
// CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm<"{ float, i1 }">
|
||||
%1 = nvvm.shfl.sync.bfly %arg0, %arg4, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ float, i1 }">
|
||||
llvm.return %0 : !llvm<"{ i32, i1 }">
|
||||
}
|
||||
|
||||
func @nvvm_vote(%arg0 : !llvm.i32, %arg1 : !llvm.i1) -> !llvm.i32 {
|
||||
// CHECK: nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : !llvm.i32
|
||||
%0 = nvvm.vote.ballot.sync %arg0, %arg1 : !llvm.i32
|
||||
|
|
|
@ -48,6 +48,16 @@ llvm.func @nvvm_shfl(
|
|||
llvm.return %6 : !llvm.i32
|
||||
}
|
||||
|
||||
llvm.func @nvvm_shfl_pred(
|
||||
%0 : !llvm.i32, %1 : !llvm.i32, %2 : !llvm.i32,
|
||||
%3 : !llvm.i32, %4 : !llvm.float) -> !llvm<"{ i32, i1 }"> {
|
||||
// CHECK: call { i32, i1 } @llvm.nvvm.shfl.sync.bfly.i32p(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
|
||||
%6 = nvvm.shfl.sync.bfly %0, %3, %1, %2 {return_value_and_is_valid} : !llvm<"{ i32, i1 }">
|
||||
// CHECK: call { float, i1 } @llvm.nvvm.shfl.sync.bfly.f32p(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
|
||||
%7 = nvvm.shfl.sync.bfly %0, %4, %1, %2 {return_value_and_is_valid} : !llvm<"{ float, i1 }">
|
||||
llvm.return %6 : !llvm<"{ i32, i1 }">
|
||||
}
|
||||
|
||||
llvm.func @nvvm_vote(%0 : !llvm.i32, %1 : !llvm.i1) -> !llvm.i32 {
|
||||
// CHECK: call i32 @llvm.nvvm.vote.ballot.sync(i32 %{{.*}}, i1 %{{.*}})
|
||||
%3 = nvvm.vote.ballot.sync %0, %1 : !llvm.i32
|
||||
|
|
Loading…
Reference in New Issue