forked from OSchip/llvm-project
[mlir][StandardToSPIRV] Emulate bitwidths not supported for load op.
Summary: The current implementation in SPIRVTypeConverter unconditionally converts an integer type to a 32-bit type whenever the original bitwidth does not meet the requirements of the available extensions or capabilities. For loads of such emulated types, we can instead load the containing 32-bit value and then perform bit extraction to recover the requested value. Differential Revision: https://reviews.llvm.org/D78974
This commit is contained in:
parent
0e8608b3c3
commit
6601b65aed
|
@ -97,6 +97,55 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
|
|||
return builder.getF32FloatAttr(dstVal.convertToFloat());
|
||||
}
|
||||
|
||||
/// Returns the offset of the value in `targetBits` representation. `srcIdx` is
|
||||
/// an index into a 1-D array with each element having `sourceBits`. When
|
||||
/// accessing an element in the array treating as having elements of
|
||||
/// `targetBits`, multiple values are loaded in the same time. The method
|
||||
/// returns the offset where the `srcIdx` locates in the value. For example, if
|
||||
/// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is
|
||||
/// located at (x % 4) * 8. Because there are four elements in one i32, and one
|
||||
/// element has 8 bits.
|
||||
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
|
||||
int targetBits, OpBuilder &builder) {
|
||||
assert(targetBits % sourceBits == 0);
|
||||
IntegerType targetType = builder.getIntegerType(targetBits);
|
||||
IntegerAttr idxAttr =
|
||||
builder.getIntegerAttr(targetType, targetBits / sourceBits);
|
||||
auto idx = builder.create<spirv::ConstantOp>(loc, targetType, idxAttr);
|
||||
IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
|
||||
auto srcBitsValue =
|
||||
builder.create<spirv::ConstantOp>(loc, targetType, srcBitsAttr);
|
||||
auto m = builder.create<spirv::SModOp>(loc, srcIdx, idx);
|
||||
return builder.create<spirv::IMulOp>(loc, targetType, m, srcBitsValue);
|
||||
}
|
||||
|
||||
/// Returns an adjusted spirv::AccessChainOp. Based on the
|
||||
/// extension/capabilities, certain integer bitwidths `sourceBits` might not be
|
||||
/// supported. During conversion if a memref of an unsupported type is used,
|
||||
/// load/stores to this memref need to be modified to use a supported higher
|
||||
/// bitwidth `targetBits` and extracting the required bits. For an accessing a
|
||||
/// 1D array (spv.array or spv.rt_array), the last index is modified to load the
|
||||
/// bits needed. The extraction of the actual bits needed are handled
|
||||
/// separately. Note that this only works for a 1-D tensor.
|
||||
static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter,
|
||||
spirv::AccessChainOp op,
|
||||
int sourceBits, int targetBits,
|
||||
OpBuilder &builder) {
|
||||
assert(targetBits % sourceBits == 0);
|
||||
const auto loc = op.getLoc();
|
||||
IntegerType targetType = builder.getIntegerType(targetBits);
|
||||
IntegerAttr attr =
|
||||
builder.getIntegerAttr(targetType, targetBits / sourceBits);
|
||||
auto idx = builder.create<spirv::ConstantOp>(loc, targetType, attr);
|
||||
auto lastDim = op.getOperation()->getOperand(op.getNumOperands() - 1);
|
||||
auto indices = llvm::to_vector<4>(op.indices());
|
||||
// There are two elements if this is a 1-D tensor.
|
||||
assert(indices.size() == 2);
|
||||
indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
|
||||
Type t = typeConverter.convertType(op.component_ptr().getType());
|
||||
return builder.create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operation conversion
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -204,6 +253,16 @@ public:
|
|||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
/// Converts std.load of signless-integer memrefs to spv.Load. When the
/// element bitwidth is not supported by the target environment and has been
/// widened by the type converter, the pattern loads the containing wider
/// word and extracts the requested bits (see matchAndRewrite).
class IntLoadOpPattern final : public SPIRVOpLowering<LoadOp> {
public:
  using SPIRVOpLowering<LoadOp>::SPIRVOpLowering;

  LogicalResult
  matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
|
||||
|
||||
/// Converts std.load to spv.Load.
|
||||
class LoadOpPattern final : public SPIRVOpLowering<LoadOp> {
|
||||
public:
|
||||
|
@ -528,13 +587,79 @@ CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
|
|||
// LoadOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult
IntLoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
                                  ConversionPatternRewriter &rewriter) const {
  LoadOpOperandAdaptor loadOperands(operands);
  auto loc = loadOp.getLoc();
  auto memrefType = loadOp.memref().getType().cast<MemRefType>();
  // This pattern only handles signless-integer element types; everything
  // else is left to LoadOpPattern.
  if (!memrefType.getElementType().isSignlessInteger())
    return failure();
  spirv::AccessChainOp accessChainOp =
      spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
                           loadOperands.indices(), loc, rewriter);

  // Bitwidth requested by the source IR vs. the bitwidth the type converter
  // selected (wider when the target lacks the capability for srcBits).
  // Drill through pointer -> struct -> array to reach the converted element
  // type.
  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
  auto dstType = typeConverter.convertType(memrefType)
                     .cast<spirv::PointerType>()
                     .getPointeeType()
                     .cast<spirv::StructType>()
                     .getElementType(0)
                     .cast<spirv::ArrayType>()
                     .getElementType();
  int dstBits = dstType.getIntOrFloatBitWidth();
  assert(dstBits % srcBits == 0);

  // If the rewritten load op has the same bitwidth, use the loaded value
  // directly; no emulation is needed.
  if (srcBits == dstBits) {
    rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp,
                                               accessChainOp.getResult());
    return success();
  }

  // Assume that getElementPtr() works linearizely. If it's a scalar, the method
  // still returns a linearized accessing. If the accessing is not linearized,
  // there will be offset issues.
  assert(accessChainOp.indices().size() == 2);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  // Load the whole dstBits-wide word containing the requested element,
  // forwarding any memory-access/alignment attributes from the original op.
  Value spvLoadOp = rewriter.create<spirv::LoadOp>(
      loc, dstType, adjustedPtr,
      loadOp.getAttrOfType<IntegerAttr>(
          spirv::attributeName<spirv::MemoryAccess>()),
      loadOp.getAttrOfType<IntegerAttr>("alignment"));

  // Shift the bits to the rightmost.
  // ____XXXX________ -> ____________XXXX
  Value lastDim = accessChainOp.getOperation()->getOperand(
      accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
  Value result = rewriter.create<spirv::ShiftRightArithmeticOp>(
      loc, spvLoadOp.getType(), spvLoadOp, offset);

  // Apply the mask to extract corresponding bits.
  Value mask = rewriter.create<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);
  rewriter.replaceOp(loadOp, result);

  // The original access chain was only rebuilt/adjusted, never consumed;
  // erase it so no dead op survives the conversion.
  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}
|
||||
|
||||
LogicalResult
LoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
                               ConversionPatternRewriter &rewriter) const {
  LoadOpOperandAdaptor loadOperands(operands);
  auto memrefType = loadOp.memref().getType().cast<MemRefType>();
  // Signless-integer loads are handled by IntLoadOpPattern (which emulates
  // unsupported bitwidths); bail out BEFORE creating any ops so a failed
  // match leaves the IR untouched, as conversion patterns require.
  if (memrefType.getElementType().isSignlessInteger())
    return failure();
  auto loadPtr =
      spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
                           loadOperands.indices(), loadOp.getLoc(), rewriter);
  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
  return success();
}
|
||||
|
@ -642,8 +767,8 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
|
|||
BitwiseOpPattern<AndOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
|
||||
BitwiseOpPattern<OrOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
|
||||
BoolCmpIOpPattern, ConstantCompositeOpPattern, ConstantScalarOpPattern,
|
||||
CmpFOpPattern, CmpIOpPattern, LoadOpPattern, ReturnOpPattern,
|
||||
SelectOpPattern, StoreOpPattern,
|
||||
CmpFOpPattern, CmpIOpPattern, IntLoadOpPattern, LoadOpPattern,
|
||||
ReturnOpPattern, SelectOpPattern, StoreOpPattern,
|
||||
TypeCastingOpPattern<SIToFPOp, spirv::ConvertSToFOp>,
|
||||
TypeCastingOpPattern<FPExtOp, spirv::FConvertOp>,
|
||||
TypeCastingOpPattern<FPTruncOp, spirv::FConvertOp>, XOrOpPattern>(
|
||||
|
|
|
@ -619,3 +619,115 @@ func @load_store_zero_rank_int(%arg0: memref<i32>, %arg1: memref<i32>) {
|
|||
}
|
||||
|
||||
} // end module
|
||||
|
||||
// -----
|
||||
|
||||
// Check that access chain indices are properly adjusted if non-32-bit types are
|
||||
// emulated via 32-bit types.
|
||||
// TODO: Test i64 type.
|
||||
// Target environment without Int8/Int16 capabilities: i8 and i16 loads must
// be emulated via i32 words.
module attributes {
  spv.target_env = #spv.target_env<
    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
    {max_compute_workgroup_invocations = 128 : i32,
     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
} {

// An i8 load is emulated as: load the containing i32 word, shift the byte
// down to the low bits, and mask to 8 bits.
// CHECK-LABEL: @load_i8
func @load_i8(%arg0: memref<i8>) {
  // CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
  // CHECK: %[[FOUR1:.+]] = spv.constant 4 : i32
  // CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32
  // CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
  // CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]]
  // CHECK: %[[FOUR2:.+]] = spv.constant 4 : i32
  // CHECK: %[[EIGHT:.+]] = spv.constant 8 : i32
  // CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR2]] : i32
  // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
  // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
  // CHECK: %[[MASK:.+]] = spv.constant 255 : i32
  // CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
  %0 = load %arg0[] : memref<i8>
  return
}

// A non-zero-rank case: the dynamic index is first linearized, then halved
// (two i16 per i32 word) to address the containing word.
// CHECK-LABEL: @load_i16
// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i32)
func @load_i16(%arg0: memref<10xi16>, %index : index) {
  // CHECK: %[[ONE:.+]] = spv.constant 1 : i32
  // CHECK: %[[FLAT_IDX:.+]] = spv.IMul %[[ONE]], %[[ARG1]] : i32
  // CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
  // CHECK: %[[TWO1:.+]] = spv.constant 2 : i32
  // CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[FLAT_IDX]], %[[TWO1]] : i32
  // CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
  // CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]]
  // CHECK: %[[TWO2:.+]] = spv.constant 2 : i32
  // CHECK: %[[SIXTEEN:.+]] = spv.constant 16 : i32
  // CHECK: %[[IDX:.+]] = spv.SMod %[[FLAT_IDX]], %[[TWO2]] : i32
  // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[SIXTEEN]] : i32
  // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
  // CHECK: %[[MASK:.+]] = spv.constant 65535 : i32
  // CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
  %0 = load %arg0[%index] : memref<10xi16>
  return
}

// i32 is natively supported: no index adjustment or bit extraction.
// CHECK-LABEL: @load_i32
func @load_i32(%arg0: memref<i32>) {
  // CHECK-NOT: spv.SDiv
  // CHECK: spv.Load
  // CHECK-NOT: spv.ShiftRightArithmetic
  %0 = load %arg0[] : memref<i32>
  return
}

// Floating-point loads never go through the integer-emulation path.
// CHECK-LABEL: @load_f32
func @load_f32(%arg0: memref<f32>) {
  // CHECK-NOT: spv.SDiv
  // CHECK: spv.Load
  // CHECK-NOT: spv.ShiftRightArithmetic
  %0 = load %arg0[] : memref<f32>
  return
}

} // end module
|
||||
|
||||
// -----
|
||||
|
||||
// Check that access chain indices are properly adjusted if non-16/32-bit types
|
||||
// are emulated via 32-bit types.
|
||||
// Target environment WITH Int16/StorageBuffer16BitAccess: i16 is natively
// supported and needs no emulation, while i8 must still be emulated via i32.
module attributes {
  spv.target_env = #spv.target_env<
    #spv.vce<v1.0, [Int16, StorageBuffer16BitAccess, Shader],
             [SPV_KHR_storage_buffer_storage_class, SPV_KHR_16bit_storage]>,
    {max_compute_workgroup_invocations = 128 : i32,
     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
} {

// i8 is still unsupported here, so the word-load + shift + mask emulation
// applies, exactly as in the no-capability module above.
// CHECK-LABEL: @load_i8
func @load_i8(%arg0: memref<i8>) {
  // CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
  // CHECK: %[[FOUR1:.+]] = spv.constant 4 : i32
  // CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32
  // CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
  // CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]]
  // CHECK: %[[FOUR2:.+]] = spv.constant 4 : i32
  // CHECK: %[[EIGHT:.+]] = spv.constant 8 : i32
  // CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR2]] : i32
  // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
  // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
  // CHECK: %[[MASK:.+]] = spv.constant 255 : i32
  // CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
  %0 = load %arg0[] : memref<i8>
  return
}

// i16 is natively supported in this environment: a plain load, no emulation.
// CHECK-LABEL: @load_i16
func @load_i16(%arg0: memref<i16>) {
  // CHECK-NOT: spv.SDiv
  // CHECK: spv.Load
  // CHECK-NOT: spv.ShiftRightArithmetic
  %0 = load %arg0[] : memref<i16>
  return
}

} // end module
|
||||
|
|
Loading…
Reference in New Issue