forked from OSchip/llvm-project
[MLIR][SPIRVToLLVM] Implementation of spv.BitFieldInsert pattern
This patch introduces a conversion pattern for the `spv.BitFieldInsert` op, as well as some utility functions that make the code easier to read. Since `spv.BitFieldInsert` may take both vector and integer operands, the vector case is handled specifically by broadcasting the scalar values (`count` and `offset` here) to vectors. Moreover, the types have to be converted to the same bitwidth in order to conform to LLVM dialect rules. This is done with `zext` when extending (note that `count` and `offset` are treated as unsigned) and with `trunc` in the opposite case. For the latter, truncation is safe since the op is defined only when `count`, `offset`, and their sum each do not exceed the bitwidth of the result. This introduces a natural upper bound of 64 on these values, which can be expressed in `i8`. Reviewed By: antiagainst, ftynse Differential Revision: https://reviews.llvm.org/D82639
This commit is contained in:
parent
50b25e0679
commit
03fe7eb16f
|
@ -53,7 +53,15 @@ static unsigned getBitWidth(Type type) {
|
||||||
return elementType.getIntOrFloatBitWidth();
|
return elementType.getIntOrFloatBitWidth();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates `IntegerAttribute` with all bits set for given type.
|
/// Returns the bit width of LLVMType integer or vector.
|
||||||
|
static unsigned getLLVMTypeBitWidth(LLVM::LLVMType type) {
|
||||||
|
return type.isVectorTy() ? type.getVectorElementType()
|
||||||
|
.getUnderlyingType()
|
||||||
|
->getIntegerBitWidth()
|
||||||
|
: type.getUnderlyingType()->getIntegerBitWidth();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates an `IntegerAttr` with all bits set for the given type.
|
||||||
IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
|
IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
|
||||||
if (auto vecType = type.dyn_cast<VectorType>()) {
|
if (auto vecType = type.dyn_cast<VectorType>()) {
|
||||||
auto integerType = vecType.getElementType().cast<IntegerType>();
|
auto integerType = vecType.getElementType().cast<IntegerType>();
|
||||||
|
@ -63,12 +71,132 @@ IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
|
||||||
return builder.getIntegerAttr(integerType, -1);
|
return builder.getIntegerAttr(integerType, -1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Creates `llvm.mlir.constant` with all bits set for the given type.
|
||||||
|
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
|
||||||
|
PatternRewriter &rewriter) {
|
||||||
|
if (srcType.isa<VectorType>())
|
||||||
|
return rewriter.create<LLVM::ConstantOp>(
|
||||||
|
loc, dstType,
|
||||||
|
SplatElementsAttr::get(srcType.cast<ShapedType>(),
|
||||||
|
minusOneIntegerAttribute(srcType, rewriter)));
|
||||||
|
return rewriter.create<LLVM::ConstantOp>(
|
||||||
|
loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Utility function for bit field ops:
|
||||||
|
/// - `BitFieldInsert`
|
||||||
|
/// - `BitFieldSExtract`
|
||||||
|
/// - `BitFieldUExtract`
|
||||||
|
/// Truncates or extends the value. If the bitwidth of the value is the same as
|
||||||
|
/// `dstType` bitwidth, the value remains unchanged.
|
||||||
|
static Value optionallyTruncateOrExtend(Location loc, Value value, Type dstType,
|
||||||
|
PatternRewriter &rewriter) {
|
||||||
|
auto srcType = value.getType();
|
||||||
|
auto llvmType = dstType.cast<LLVM::LLVMType>();
|
||||||
|
unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
|
||||||
|
unsigned valueBitWidth =
|
||||||
|
srcType.isa<LLVM::LLVMType>()
|
||||||
|
? getLLVMTypeBitWidth(srcType.cast<LLVM::LLVMType>())
|
||||||
|
: getBitWidth(srcType);
|
||||||
|
|
||||||
|
if (valueBitWidth < targetBitWidth)
|
||||||
|
return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
|
||||||
|
// If the bit widths of `Count` and `Offset` are greater than the bit width
|
||||||
|
// of the target type, they are truncated. Truncation is safe since `Count`
|
||||||
|
// and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
|
||||||
|
// both values can be expressed in 8 bits.
|
||||||
|
if (valueBitWidth > targetBitWidth)
|
||||||
|
return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Broadcasts the value to a vector with `numElements` elements.
|
||||||
|
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
|
||||||
|
LLVMTypeConverter &typeConverter,
|
||||||
|
ConversionPatternRewriter &rewriter) {
|
||||||
|
auto vectorType = VectorType::get(numElements, toBroadcast.getType());
|
||||||
|
auto llvmVectorType = typeConverter.convertType(vectorType);
|
||||||
|
auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
|
||||||
|
Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType);
|
||||||
|
for (unsigned i = 0; i < numElements; ++i) {
|
||||||
|
auto index = rewriter.create<LLVM::ConstantOp>(
|
||||||
|
loc, llvmI32Type, rewriter.getI32IntegerAttr(i));
|
||||||
|
broadcasted = rewriter.create<LLVM::InsertElementOp>(
|
||||||
|
loc, llvmVectorType, broadcasted, toBroadcast, index);
|
||||||
|
}
|
||||||
|
return broadcasted;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Operation conversion
|
// Operation conversion
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
class BitFieldInsertPattern
|
||||||
|
: public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
|
||||||
|
public:
|
||||||
|
using SPIRVToLLVMConversion<spirv::BitFieldInsertOp>::SPIRVToLLVMConversion;
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(spirv::BitFieldInsertOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
auto srcType = op.getType();
|
||||||
|
auto dstType = this->typeConverter.convertType(srcType);
|
||||||
|
if (!dstType)
|
||||||
|
return failure();
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
|
||||||
|
// Broadcast `Offset` and `Count` to match the type of `Base` and `Insert`.
|
||||||
|
// If `Base` is of a vector type, construct a vector that has:
|
||||||
|
// - same number of elements as `Base`
|
||||||
|
// - each element has the type that is the same as the type of `Offset` or
|
||||||
|
// `Count`
|
||||||
|
// - each element has the same value as `Offset` or `Count`
|
||||||
|
Value offset;
|
||||||
|
Value count;
|
||||||
|
if (auto vectorType = srcType.dyn_cast<VectorType>()) {
|
||||||
|
unsigned numElements = vectorType.getNumElements();
|
||||||
|
offset =
|
||||||
|
broadcast(loc, op.offset(), numElements, typeConverter, rewriter);
|
||||||
|
count = broadcast(loc, op.count(), numElements, typeConverter, rewriter);
|
||||||
|
} else {
|
||||||
|
offset = op.offset();
|
||||||
|
count = op.count();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a mask with all bits set of the same type as `srcType`
|
||||||
|
Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
|
||||||
|
|
||||||
|
// Need to cast `Offset` and `Count` if their bit width is different
|
||||||
|
// from `Base` bit width.
|
||||||
|
Value optionallyCastedCount =
|
||||||
|
optionallyTruncateOrExtend(loc, count, dstType, rewriter);
|
||||||
|
Value optionallyCastedOffset =
|
||||||
|
optionallyTruncateOrExtend(loc, offset, dstType, rewriter);
|
||||||
|
|
||||||
|
// Create a mask with bits set outside [Offset, Offset + Count - 1].
|
||||||
|
Value maskShiftedByCount = rewriter.create<LLVM::ShlOp>(
|
||||||
|
loc, dstType, minusOne, optionallyCastedCount);
|
||||||
|
Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType,
|
||||||
|
maskShiftedByCount, minusOne);
|
||||||
|
Value maskShiftedByCountAndOffset = rewriter.create<LLVM::ShlOp>(
|
||||||
|
loc, dstType, negated, optionallyCastedOffset);
|
||||||
|
Value mask = rewriter.create<LLVM::XOrOp>(
|
||||||
|
loc, dstType, maskShiftedByCountAndOffset, minusOne);
|
||||||
|
|
||||||
|
// Extract unchanged bits from the `Base` that are outside of
|
||||||
|
// [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
|
||||||
|
Value baseAndMask =
|
||||||
|
rewriter.create<LLVM::AndOp>(loc, dstType, op.base(), mask);
|
||||||
|
Value insertShiftedByOffset = rewriter.create<LLVM::ShlOp>(
|
||||||
|
loc, dstType, op.insert(), optionallyCastedOffset);
|
||||||
|
rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
|
||||||
|
insertShiftedByOffset);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/// Converts SPIR-V operations that have straightforward LLVM equivalent
|
/// Converts SPIR-V operations that have straightforward LLVM equivalent
|
||||||
/// into LLVM dialect operations.
|
/// into LLVM dialect operations.
|
||||||
template <typename SPIRVOp, typename LLVMOp>
|
template <typename SPIRVOp, typename LLVMOp>
|
||||||
|
@ -380,6 +508,7 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
|
||||||
DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
|
DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
|
||||||
|
|
||||||
// Bitwise ops
|
// Bitwise ops
|
||||||
|
BitFieldInsertPattern,
|
||||||
DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
|
DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
|
||||||
DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
|
DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
|
||||||
DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
|
DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
|
||||||
|
|
|
@ -32,6 +32,84 @@ func @bitreverse_vector(%arg0: vector<4xi32>) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// spv.BitFieldInsert
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @bitfield_insert_scalar_same_bit_width
|
||||||
|
// CHECK-SAME: %[[BASE:.*]]: !llvm.i32, %[[INSERT:.*]]: !llvm.i32, %[[OFFSET:.*]]: !llvm.i32, %[[COUNT:.*]]: !llvm.i32
|
||||||
|
func @bitfield_insert_scalar_same_bit_width(%base: i32, %insert: i32, %offset: i32, %count: i32) {
|
||||||
|
// CHECK: %[[MINUS_ONE:.*]] = llvm.mlir.constant(-1 : i32) : !llvm.i32
|
||||||
|
// CHECK: %[[T0:.*]] = llvm.shl %[[MINUS_ONE]], %[[COUNT]] : !llvm.i32
|
||||||
|
// CHECK: %[[T1:.*]] = llvm.xor %[[T0]], %[[MINUS_ONE]] : !llvm.i32
|
||||||
|
// CHECK: %[[T2:.*]] = llvm.shl %[[T1]], %[[OFFSET]] : !llvm.i32
|
||||||
|
// CHECK: %[[MASK:.*]] = llvm.xor %[[T2]], %[[MINUS_ONE]] : !llvm.i32
|
||||||
|
// CHECK: %[[NEW_BASE:.*]] = llvm.and %[[BASE]], %[[MASK]] : !llvm.i32
|
||||||
|
// CHECK: %[[SHIFTED_INSERT:.*]] = llvm.shl %[[INSERT]], %[[OFFSET]] : !llvm.i32
|
||||||
|
// CHECK: %{{.*}} = llvm.or %[[NEW_BASE]], %[[SHIFTED_INSERT]] : !llvm.i32
|
||||||
|
%0 = spv.BitFieldInsert %base, %insert, %offset, %count : i32, i32, i32
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @bitfield_insert_scalar_smaller_bit_width
|
||||||
|
// CHECK-SAME: %[[BASE:.*]]: !llvm.i64, %[[INSERT:.*]]: !llvm.i64, %[[OFFSET:.*]]: !llvm.i8, %[[COUNT:.*]]: !llvm.i8
|
||||||
|
func @bitfield_insert_scalar_smaller_bit_width(%base: i64, %insert: i64, %offset: i8, %count: i8) {
|
||||||
|
// CHECK: %[[MINUS_ONE:.*]] = llvm.mlir.constant(-1 : i64) : !llvm.i64
|
||||||
|
// CHECK: %[[EXT_COUNT:.*]] = llvm.zext %[[COUNT]] : !llvm.i8 to !llvm.i64
|
||||||
|
// CHECK: %[[EXT_OFFSET:.*]] = llvm.zext %[[OFFSET]] : !llvm.i8 to !llvm.i64
|
||||||
|
// CHECK: %[[T0:.*]] = llvm.shl %[[MINUS_ONE]], %[[EXT_COUNT]] : !llvm.i64
|
||||||
|
// CHECK: %[[T1:.*]] = llvm.xor %[[T0]], %[[MINUS_ONE]] : !llvm.i64
|
||||||
|
// CHECK: %[[T2:.*]] = llvm.shl %[[T1]], %[[EXT_OFFSET]] : !llvm.i64
|
||||||
|
// CHECK: %[[MASK:.*]] = llvm.xor %[[T2]], %[[MINUS_ONE]] : !llvm.i64
|
||||||
|
// CHECK: %[[NEW_BASE:.*]] = llvm.and %[[BASE]], %[[MASK]] : !llvm.i64
|
||||||
|
// CHECK: %[[SHIFTED_INSERT:.*]] = llvm.shl %[[INSERT]], %[[EXT_OFFSET]] : !llvm.i64
|
||||||
|
// CHECK: %{{.*}} = llvm.or %[[NEW_BASE]], %[[SHIFTED_INSERT]] : !llvm.i64
|
||||||
|
%0 = spv.BitFieldInsert %base, %insert, %offset, %count : i64, i8, i8
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @bitfield_insert_scalar_greater_bit_width
|
||||||
|
// CHECK-SAME: %[[BASE:.*]]: !llvm.i16, %[[INSERT:.*]]: !llvm.i16, %[[OFFSET:.*]]: !llvm.i32, %[[COUNT:.*]]: !llvm.i64
|
||||||
|
func @bitfield_insert_scalar_greater_bit_width(%base: i16, %insert: i16, %offset: i32, %count: i64) {
|
||||||
|
// CHECK: %[[MINUS_ONE:.*]] = llvm.mlir.constant(-1 : i16) : !llvm.i16
|
||||||
|
// CHECK: %[[TRUNC_COUNT:.*]] = llvm.trunc %[[COUNT]] : !llvm.i64 to !llvm.i16
|
||||||
|
// CHECK: %[[TRUNC_OFFSET:.*]] = llvm.trunc %[[OFFSET]] : !llvm.i32 to !llvm.i16
|
||||||
|
// CHECK: %[[T0:.*]] = llvm.shl %[[MINUS_ONE]], %[[TRUNC_COUNT]] : !llvm.i16
|
||||||
|
// CHECK: %[[T1:.*]] = llvm.xor %[[T0]], %[[MINUS_ONE]] : !llvm.i16
|
||||||
|
// CHECK: %[[T2:.*]] = llvm.shl %[[T1]], %[[TRUNC_OFFSET]] : !llvm.i16
|
||||||
|
// CHECK: %[[MASK:.*]] = llvm.xor %[[T2]], %[[MINUS_ONE]] : !llvm.i16
|
||||||
|
// CHECK: %[[NEW_BASE:.*]] = llvm.and %[[BASE]], %[[MASK]] : !llvm.i16
|
||||||
|
// CHECK: %[[SHIFTED_INSERT:.*]] = llvm.shl %[[INSERT]], %[[TRUNC_OFFSET]] : !llvm.i16
|
||||||
|
// CHECK: %{{.*}} = llvm.or %[[NEW_BASE]], %[[SHIFTED_INSERT]] : !llvm.i16
|
||||||
|
%0 = spv.BitFieldInsert %base, %insert, %offset, %count : i16, i32, i64
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @bitfield_insert_vector
|
||||||
|
// CHECK-SAME: %[[BASE:.*]]: !llvm<"<2 x i32>">, %[[INSERT:.*]]: !llvm<"<2 x i32>">, %[[OFFSET:.*]]: !llvm.i32, %[[COUNT:.*]]: !llvm.i32
|
||||||
|
func @bitfield_insert_vector(%base: vector<2xi32>, %insert: vector<2xi32>, %offset: i32, %count: i32) {
|
||||||
|
// CHECK: %[[OFFSET_V0:.*]] = llvm.mlir.undef : !llvm<"<2 x i32>">
|
||||||
|
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
|
||||||
|
// CHECK: %[[OFFSET_V1:.*]] = llvm.insertelement %[[OFFSET]], %[[OFFSET_V0]][%[[ZERO]] : !llvm.i32] : !llvm<"<2 x i32>">
|
||||||
|
// CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
|
||||||
|
// CHECK: %[[OFFSET_V2:.*]] = llvm.insertelement %[[OFFSET]], %[[OFFSET_V1]][%[[ONE]] : !llvm.i32] : !llvm<"<2 x i32>">
|
||||||
|
// CHECK: %[[COUNT_V0:.*]] = llvm.mlir.undef : !llvm<"<2 x i32>">
|
||||||
|
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
|
||||||
|
// CHECK: %[[COUNT_V1:.*]] = llvm.insertelement %[[COUNT]], %[[COUNT_V0]][%[[ZERO]] : !llvm.i32] : !llvm<"<2 x i32>">
|
||||||
|
// CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
|
||||||
|
// CHECK: %[[COUNT_V2:.*]] = llvm.insertelement %[[COUNT]], %[[COUNT_V1]][%[[ONE]] : !llvm.i32] : !llvm<"<2 x i32>">
|
||||||
|
// CHECK: %[[MINUS_ONE:.*]] = llvm.mlir.constant(dense<-1> : vector<2xi32>) : !llvm<"<2 x i32>">
|
||||||
|
// CHECK: %[[T0:.*]] = llvm.shl %[[MINUS_ONE]], %[[COUNT_V2]] : !llvm<"<2 x i32>">
|
||||||
|
// CHECK: %[[T1:.*]] = llvm.xor %[[T0]], %[[MINUS_ONE]] : !llvm<"<2 x i32>">
|
||||||
|
// CHECK: %[[T2:.*]] = llvm.shl %[[T1]], %[[OFFSET_V2]] : !llvm<"<2 x i32>">
|
||||||
|
// CHECK: %[[MASK:.*]] = llvm.xor %[[T2]], %[[MINUS_ONE]] : !llvm<"<2 x i32>">
|
||||||
|
// CHECK: %[[NEW_BASE:.*]] = llvm.and %[[BASE]], %[[MASK]] : !llvm<"<2 x i32>">
|
||||||
|
// CHECK: %[[SHIFTED_INSERT:.*]] = llvm.shl %[[INSERT]], %[[OFFSET_V2]] : !llvm<"<2 x i32>">
|
||||||
|
// CHECK: %{{.*}} = llvm.or %[[NEW_BASE]], %[[SHIFTED_INSERT]] : !llvm<"<2 x i32>">
|
||||||
|
%0 = spv.BitFieldInsert %base, %insert, %offset, %count : vector<2xi32>, i32, i32
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// spv.BitwiseAnd
|
// spv.BitwiseAnd
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
Loading…
Reference in New Issue