Get active source lane predicate from shuffle instruction.

nvvm.shfl.sync.bfly optionally returns a predicate indicating whether the source lane was active. Support for this was added to clang in https://reviews.llvm.org/D68892.

Add an optional unit attribute ('return_value_and_is_valid') to the instruction to return this predicate alongside the shuffled value. Specify this attribute in the partial warp reduction lowering so we don't need to manually compute the predicate.

PiperOrigin-RevId: 275616564
This commit is contained in:
Christian Sigg 2019-10-19 01:52:51 -07:00 committed by A. Unique TensorFlower
parent 5f6bdd144a
commit c3e56cd12c
9 changed files with 116 additions and 32 deletions

View File

@ -66,6 +66,10 @@ public:
/// Utilities to identify types.
bool isFloatTy() { return getUnderlyingType()->isFloatTy(); }
bool isIntegerTy() { return getUnderlyingType()->isIntegerTy(); }
/// Overload taking a bit width: forwards to the underlying type's
/// isIntegerTy(bitwidth) and returns its result.
bool isIntegerTy(unsigned bitwidth) {
return getUnderlyingType()->isIntegerTy(bitwidth);
}
/// Array type utilities.
LLVMType getArrayElementType();
@ -89,6 +93,7 @@ public:
/// Struct type utilities.
LLVMType getStructElementType(unsigned i);
unsigned getStructNumElements();
bool isStructTy();
/// Utilities used to generate floating point types.

View File

@ -97,16 +97,26 @@ def NVVM_ShflBflyOp :
Arguments<(ins LLVM_Type:$dst,
LLVM_Type:$val,
LLVM_Type:$offset,
LLVM_Type:$mask_and_clamp)> {
LLVM_Type:$mask_and_clamp,
OptionalAttr<UnitAttr>:$return_value_and_is_valid)> {
string llvmBuilder = [{
auto intId = $val->getType()->isFloatTy() ?
llvm::Intrinsic::nvvm_shfl_sync_bfly_f32 :
llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
auto intId = getShflBflyIntrinsicId(
$_resultType, static_cast<bool>($return_value_and_is_valid));
$res = createIntrinsicCall(builder,
intId, {$dst, $val, $offset, $mask_and_clamp});
}];
let parser = [{ return parseNVVMShflSyncBflyOp(parser, result); }];
let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
let verifier = [{
if (!getAttrOfType<UnitAttr>("return_value_and_is_valid"))
return success();
auto type = getType().cast<LLVM::LLVMType>();
if (!type.isStructTy() || type.getStructNumElements() != 2 ||
!type.getStructElementType(1).isIntegerTy(
/*Bitwidth=*/1))
return emitError("expected return type !llvm<\"{ ?, i1 }\">");
return success();
}];
}
def NVVM_VoteBallotOp :

View File

@ -309,7 +309,7 @@ private:
loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
Value *isPartialWarp = rewriter.create<LLVM::ICmpOp>(
loc, LLVM::ICmpPredicate::slt, activeWidth, warpSize);
auto type = operand->getType();
auto type = operand->getType().cast<LLVM::LLVMType>();
createIf(
loc, rewriter, isPartialWarp,
@ -323,30 +323,31 @@ private:
loc, int32Type,
rewriter.create<LLVM::ShlOp>(loc, int32Type, one, activeWidth),
one);
// Bound of offsets which read from a lane within the active range.
Value *offsetBound =
rewriter.create<LLVM::SubOp>(loc, activeWidth, laneId);
auto dialect = lowering.getDialect();
auto predTy = LLVM::LLVMType::getInt1Ty(dialect);
auto shflTy = LLVM::LLVMType::getStructTy(dialect, {type, predTy});
auto returnValueAndIsValidAttr = rewriter.getUnitAttr();
// Repeatedly shuffle value from 'laneId + i' and accumulate if source
// lane is within the active range. The first lane contains the final
// result, all other lanes contain some undefined partial result.
// Repeatedly shuffle value from 'laneId ^ i' and accumulate if source
// lane is within the active range. All lanes contain the final
// result, but only the first lane's result is used.
for (int i = 1; i < kWarpSize; i <<= 1) {
Value *offset = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(i));
// ShflDownOp instead of ShflBflyOp would produce a scan. ShflBflyOp
// also produces the correct reduction on lane 0 though.
Value *shfl = rewriter.create<NVVM::ShflBflyOp>(
loc, type, activeMask, value, offset, maskAndClamp);
// TODO(csigg): use the second result from the shuffle op instead.
Value *isActiveSrcLane = rewriter.create<LLVM::ICmpOp>(
loc, LLVM::ICmpPredicate::slt, offset, offsetBound);
loc, shflTy, activeMask, value, offset, maskAndClamp,
returnValueAndIsValidAttr);
Value *isActiveSrcLane = rewriter.create<LLVM::ExtractValueOp>(
loc, predTy, shfl, rewriter.getIndexArrayAttr(1));
// Skip the accumulation if the shuffle op read from a lane outside
// of the active range.
createIf(
loc, rewriter, isActiveSrcLane,
[&] {
Value *shflValue = rewriter.create<LLVM::ExtractValueOp>(
loc, type, shfl, rewriter.getIndexArrayAttr(0));
return llvm::SmallVector<Value *, 1>{
accumFactory(loc, value, shfl, rewriter)};
accumFactory(loc, value, shflValue, rewriter)};
},
[&] { return llvm::makeArrayRef(value); });
value = rewriter.getInsertionBlock()->getArgument(0);
@ -362,9 +363,10 @@ private:
for (int i = 1; i < kWarpSize; i <<= 1) {
Value *offset = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(i));
Value *shfl = rewriter.create<NVVM::ShflBflyOp>(
loc, type, activeMask, value, offset, maskAndClamp);
value = accumFactory(loc, value, shfl, rewriter);
Value *shflValue = rewriter.create<NVVM::ShflBflyOp>(
loc, type, activeMask, value, offset, maskAndClamp,
/*return_value_and_is_valid=*/UnitAttr());
value = accumFactory(loc, value, shflValue, rewriter);
}
return llvm::SmallVector<Value *, 1>{value};
});

View File

@ -1373,6 +1373,9 @@ bool LLVMType::isPointerTy() { return getUnderlyingType()->isPointerTy(); }
/// Returns the `i`-th element type of the underlying type, wrapped as an
/// LLVMType obtained via get() with this type's context.
/// NOTE(review): presumably only valid when the underlying type is a struct
/// and `i` is in range -- no check is performed here; confirm at call sites.
LLVMType LLVMType::getStructElementType(unsigned i) {
return get(getContext(), getUnderlyingType()->getStructElementType(i));
}
/// Returns the element count reported by the underlying type's
/// getStructNumElements().
/// NOTE(review): presumably callers should check isStructTy() first -- no
/// check is performed here.
unsigned LLVMType::getStructNumElements() {
return getUnderlyingType()->getStructNumElements();
}
/// Returns the result of isStructTy() on the underlying type.
bool LLVMType::isStructTy() { return getUnderlyingType()->isStructTy(); }
/// Utilities used to generate floating point types.

View File

@ -29,6 +29,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Attributes.h"
@ -70,20 +71,29 @@ static LLVM::LLVMDialect *getLlvmDialect(OpAsmParser &parser) {
// <operation> ::=
// `llvm.nvvm.shfl.sync.bfly %dst, %val, %offset, %clamp_and_mask`
// : result_type
// ({return_value_and_is_valid})? : result_type
static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser &parser,
OperationState &result) {
auto llvmDialect = getLlvmDialect(parser);
auto int32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
SmallVector<OpAsmParser::OperandType, 8> ops;
Type type;
return failure(parser.parseOperandList(ops) ||
Type resultType;
if (parser.parseOperandList(ops) ||
parser.parseOptionalAttributeDict(result.attributes) ||
parser.parseColonType(type) ||
parser.addTypeToList(type, result.types) ||
parser.resolveOperands(ops, {int32Ty, type, int32Ty, int32Ty},
parser.getNameLoc(), result.operands));
parser.parseColonType(resultType) ||
parser.addTypeToList(resultType, result.types))
return failure();
auto type = resultType.cast<LLVM::LLVMType>();
for (auto &attr : result.attributes) {
if (attr.first != "return_value_and_is_valid")
continue;
if (type.isStructTy() && type.getStructNumElements() > 0)
type = type.getStructElementType(0);
break;
}
auto int32Ty = LLVM::LLVMType::getInt32Ty(getLlvmDialect(parser));
return parser.resolveOperands(ops, {int32Ty, type, int32Ty, int32Ty},
parser.getNameLoc(), result.operands);
}
// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type

View File

@ -44,6 +44,17 @@ static llvm::Value *createIntrinsicCall(llvm::IRBuilder<> &builder,
return builder.CreateCall(fn, args);
}
/// Picks the NVVM butterfly-shuffle intrinsic for an op with the given result
/// type.
///
/// When `withPredicate` is set, `resultType` is cast to a struct type and its
/// first element is treated as the shuffled value type; the predicate-returning
/// variants (`*_f32p` / `*_i32p`) are selected. Otherwise `resultType` itself is
/// the value type and the plain variants are selected. In both cases a float
/// value type maps to the f32 flavor and anything else to the i32 flavor.
static llvm::Intrinsic::ID getShflBflyIntrinsicId(llvm::Type *resultType,
                                                  bool withPredicate) {
  // Unwrap the value type up front so the float/int decision is made once.
  llvm::Type *valueType =
      withPredicate ? cast<llvm::StructType>(resultType)->getElementType(0)
                    : resultType;
  const bool isFloat = valueType->isFloatTy();

  if (!withPredicate)
    return isFloat ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
                   : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;

  return isFloat ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
                 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
}
class ModuleTranslation : public LLVM::ModuleTranslation {
public:

View File

@ -263,3 +263,26 @@ func @null_non_llvm_type() {
llvm.mlir.null : !llvm.i32
}
// -----
// CHECK-LABEL: @nvvm_invalid_shfl_pred_1
func @nvvm_invalid_shfl_pred_1(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32, %arg3 : !llvm.i32) {
// expected-error@+1 {{expected return type !llvm<"{ ?, i1 }">}}
%0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm.i32
}
// -----
// CHECK-LABEL: @nvvm_invalid_shfl_pred_2
func @nvvm_invalid_shfl_pred_2(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32, %arg3 : !llvm.i32) {
// expected-error@+1 {{expected return type !llvm<"{ ?, i1 }">}}
%0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32 }">
}
// -----
// CHECK-LABEL: @nvvm_invalid_shfl_pred_3
func @nvvm_invalid_shfl_pred_3(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32, %arg3 : !llvm.i32) {
// expected-error@+1 {{expected return type !llvm<"{ ?, i1 }">}}
%0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32, i32 }">
}

View File

@ -44,6 +44,16 @@ func @nvvm_shfl(
llvm.return %0 : !llvm.i32
}
func @nvvm_shfl_pred(
%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32,
%arg3 : !llvm.i32, %arg4 : !llvm.float) -> !llvm<"{ i32, i1 }"> {
// CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm<"{ i32, i1 }">
%0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32, i1 }">
// CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm<"{ float, i1 }">
%1 = nvvm.shfl.sync.bfly %arg0, %arg4, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ float, i1 }">
llvm.return %0 : !llvm<"{ i32, i1 }">
}
func @nvvm_vote(%arg0 : !llvm.i32, %arg1 : !llvm.i1) -> !llvm.i32 {
// CHECK: nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : !llvm.i32
%0 = nvvm.vote.ballot.sync %arg0, %arg1 : !llvm.i32

View File

@ -48,6 +48,16 @@ llvm.func @nvvm_shfl(
llvm.return %6 : !llvm.i32
}
llvm.func @nvvm_shfl_pred(
%0 : !llvm.i32, %1 : !llvm.i32, %2 : !llvm.i32,
%3 : !llvm.i32, %4 : !llvm.float) -> !llvm<"{ i32, i1 }"> {
// CHECK: call { i32, i1 } @llvm.nvvm.shfl.sync.bfly.i32p(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%6 = nvvm.shfl.sync.bfly %0, %3, %1, %2 {return_value_and_is_valid} : !llvm<"{ i32, i1 }">
// CHECK: call { float, i1 } @llvm.nvvm.shfl.sync.bfly.f32p(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%7 = nvvm.shfl.sync.bfly %0, %4, %1, %2 {return_value_and_is_valid} : !llvm<"{ float, i1 }">
llvm.return %6 : !llvm<"{ i32, i1 }">
}
llvm.func @nvvm_vote(%0 : !llvm.i32, %1 : !llvm.i1) -> !llvm.i32 {
// CHECK: call i32 @llvm.nvvm.vote.ballot.sync(i32 %{{.*}}, i1 %{{.*}})
%3 = nvvm.vote.ballot.sync %0, %1 : !llvm.i32