[mlir][StandardToSPIRV] Emulate bitwidths not supported for load op.

Summary:
The current implementation in SPIRVTypeConverter unconditionally turns
everything into a 32-bit type if the original bitwidth does not meet the
requirements of the available extensions or capabilities. In this case, we can
load a 32-bit value and then do bit extraction to get the requested value.
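
For example, for an i8 element stored in packed i32 words, the generated
SPIR-V computes a word index, a bit offset, a shift, and a mask. A minimal
host-side sketch of that arithmetic (a hypothetical model for illustration,
not code from this patch):

#include <cstdint>
#include <vector>

// Hypothetical model: load element `i` of an i8 array stored as packed i32
// words, mirroring the SPIR-V ops this patch generates.
static int32_t emulatedLoadI8(const std::vector<int32_t> &words, int32_t i) {
  int32_t word = words[i / 4];     // spv.SDiv adjusts the access chain index.
  int32_t offset = (i % 4) * 8;    // spv.SMod + spv.IMul give the bit offset.
  int32_t value = word >> offset;  // spv.ShiftRightArithmetic.
  return value & 0xFF;             // spv.BitwiseAnd with (1 << 8) - 1.
}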

Differential Revision: https://reviews.llvm.org/D78974
Hanhan Wang 2020-04-30 19:27:08 -07:00
parent 0e8608b3c3
commit 6601b65aed
2 changed files with 242 additions and 5 deletions


@@ -97,6 +97,55 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
  return builder.getF32FloatAttr(dstVal.convertToFloat());
}
/// Returns the offset of the value in `targetBits` representation. `srcIdx` is
/// an index into a 1-D array with each element having `sourceBits`. When
/// accessing an element in the array as if it held elements of `targetBits`,
/// multiple values are loaded at the same time. The method returns the bit
/// offset at which the element at `srcIdx` sits within the loaded value. For
/// example, if `sourceBits` equals 8 and `targetBits` equals 32, the x-th
/// element is located at (x % 4) * 8, because one i32 holds four 8-bit
/// elements.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
                                  int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  IntegerType targetType = builder.getIntegerType(targetBits);
  IntegerAttr idxAttr =
      builder.getIntegerAttr(targetType, targetBits / sourceBits);
  auto idx = builder.create<spirv::ConstantOp>(loc, targetType, idxAttr);
  IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
  auto srcBitsValue =
      builder.create<spirv::ConstantOp>(loc, targetType, srcBitsAttr);
  auto m = builder.create<spirv::SModOp>(loc, srcIdx, idx);
  return builder.create<spirv::IMulOp>(loc, targetType, m, srcBitsValue);
}
/// Returns an adjusted spirv::AccessChainOp. Depending on the
/// extensions/capabilities, certain integer bitwidths `sourceBits` might not
/// be supported. During conversion, if a memref of an unsupported type is
/// used, loads/stores on this memref need to be modified to use a supported
/// higher bitwidth `targetBits` and then extract the required bits. When
/// accessing a 1-D array (spv.array or spv.rt_array), the last index is
/// modified to load the bits needed. The extraction of the actual bits needed
/// is handled separately. Note that this only works for a 1-D tensor.
static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter,
                                          spirv::AccessChainOp op,
                                          int sourceBits, int targetBits,
                                          OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  const auto loc = op.getLoc();
  IntegerType targetType = builder.getIntegerType(targetBits);
  IntegerAttr attr =
      builder.getIntegerAttr(targetType, targetBits / sourceBits);
  auto idx = builder.create<spirv::ConstantOp>(loc, targetType, attr);
  auto lastDim = op.getOperation()->getOperand(op.getNumOperands() - 1);
  auto indices = llvm::to_vector<4>(op.indices());
  // There are two elements if this is a 1-D tensor.
  assert(indices.size() == 2);
  indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
  Type t = typeConverter.convertType(op.component_ptr().getType());
  return builder.create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices);
}
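
Together with getOffsetForBitwidth() above, this helper splits a flat element
index into a word index (used in the access chain) and a bit offset (used by
the shift emitted later). A minimal sketch of that index arithmetic for i16
elements packed into i32 words (a hypothetical standalone helper, for
illustration only):

#include <cstdint>
#include <utility>

// Hypothetical helper: split flat index `i` of an i16 array stored in i32
// words into (word index, bit offset); e.g. i = 5 -> word 2, bit offset 16.
static std::pair<int32_t, int32_t> splitIndexI16(int32_t i) {
  const int32_t elemsPerWord = 32 / 16;        // targetBits / sourceBits.
  int32_t wordIdx = i / elemsPerWord;          // spv.SDiv in the access chain.
  int32_t bitOffset = (i % elemsPerWord) * 16; // spv.SMod + spv.IMul.
  return {wordIdx, bitOffset};
}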
//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
@@ -204,6 +253,16 @@ public:
                  ConversionPatternRewriter &rewriter) const override;
};
/// Converts std.load on signless integers to spv.Load, emulating bitwidths
/// that the target environment does not support.
class IntLoadOpPattern final : public SPIRVOpLowering<LoadOp> {
public:
  using SPIRVOpLowering<LoadOp>::SPIRVOpLowering;

  LogicalResult
  matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
/// Converts std.load to spv.Load on memrefs whose element type is not a
/// signless integer.
class LoadOpPattern final : public SPIRVOpLowering<LoadOp> {
public:
@@ -528,13 +587,79 @@ CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
// LoadOp
//===----------------------------------------------------------------------===//
LogicalResult
IntLoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
                                  ConversionPatternRewriter &rewriter) const {
  LoadOpOperandAdaptor loadOperands(operands);
  auto loc = loadOp.getLoc();
  auto memrefType = loadOp.memref().getType().cast<MemRefType>();
  if (!memrefType.getElementType().isSignlessInteger())
    return failure();
  spirv::AccessChainOp accessChainOp =
      spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
                           loadOperands.indices(), loc, rewriter);

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
  auto dstType = typeConverter.convertType(memrefType)
                     .cast<spirv::PointerType>()
                     .getPointeeType()
                     .cast<spirv::StructType>()
                     .getElementType(0)
                     .cast<spirv::ArrayType>()
                     .getElementType();
  int dstBits = dstType.getIntOrFloatBitWidth();
  assert(dstBits % srcBits == 0);

  // If the rewritten load op has the same bit width, use the loaded value
  // directly.
  if (srcBits == dstBits) {
    rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp,
                                               accessChainOp.getResult());
    return success();
  }

  // Assume that getElementPtr() returns a linearized access. Even for a
  // scalar, the method still returns a linearized access; if the access were
  // not linearized, the offset computation below would be wrong.
  assert(accessChainOp.indices().size() == 2);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  Value spvLoadOp = rewriter.create<spirv::LoadOp>(
      loc, dstType, adjustedPtr,
      loadOp.getAttrOfType<IntegerAttr>(
          spirv::attributeName<spirv::MemoryAccess>()),
      loadOp.getAttrOfType<IntegerAttr>("alignment"));

  // Shift the bits to the rightmost.
  // ____XXXX________ -> ____________XXXX
  Value lastDim = accessChainOp.getOperation()->getOperand(
      accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
  Value result = rewriter.create<spirv::ShiftRightArithmeticOp>(
      loc, spvLoadOp.getType(), spvLoadOp, offset);

  // Apply the mask to extract the corresponding bits.
  Value mask = rewriter.create<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);
  rewriter.replaceOp(loadOp, result);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);
  return success();
}
LogicalResult
LoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
                               ConversionPatternRewriter &rewriter) const {
  LoadOpOperandAdaptor loadOperands(operands);
  auto memrefType = loadOp.memref().getType().cast<MemRefType>();
  if (memrefType.getElementType().isSignlessInteger())
    return failure();
  auto loadPtr =
      spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
                           loadOperands.indices(), loadOp.getLoc(), rewriter);
  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
  return success();
}
@@ -642,8 +767,8 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
      BitwiseOpPattern<AndOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
      BitwiseOpPattern<OrOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
      BoolCmpIOpPattern, ConstantCompositeOpPattern, ConstantScalarOpPattern,
      CmpFOpPattern, CmpIOpPattern, IntLoadOpPattern, LoadOpPattern,
      ReturnOpPattern, SelectOpPattern, StoreOpPattern,
      TypeCastingOpPattern<SIToFPOp, spirv::ConvertSToFOp>,
      TypeCastingOpPattern<FPExtOp, spirv::FConvertOp>,
      TypeCastingOpPattern<FPTruncOp, spirv::FConvertOp>, XOrOpPattern>(


@@ -619,3 +619,115 @@ func @load_store_zero_rank_int(%arg0: memref<i32>, %arg1: memref<i32>) {
}
} // end module
// -----
// Check that access chain indices are properly adjusted if non-32-bit types are
// emulated via 32-bit types.
// TODO: Test i64 type.
module attributes {
  spv.target_env = #spv.target_env<
    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
    {max_compute_workgroup_invocations = 128 : i32,
     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
} {

// CHECK-LABEL: @load_i8
func @load_i8(%arg0: memref<i8>) {
  // CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
  // CHECK: %[[FOUR1:.+]] = spv.constant 4 : i32
  // CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32
  // CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
  // CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]]
  // CHECK: %[[FOUR2:.+]] = spv.constant 4 : i32
  // CHECK: %[[EIGHT:.+]] = spv.constant 8 : i32
  // CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR2]] : i32
  // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
  // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
  // CHECK: %[[MASK:.+]] = spv.constant 255 : i32
  // CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
  %0 = load %arg0[] : memref<i8>
  return
}

// CHECK-LABEL: @load_i16
// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i32)
func @load_i16(%arg0: memref<10xi16>, %index : index) {
  // CHECK: %[[ONE:.+]] = spv.constant 1 : i32
  // CHECK: %[[FLAT_IDX:.+]] = spv.IMul %[[ONE]], %[[ARG1]] : i32
  // CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
  // CHECK: %[[TWO1:.+]] = spv.constant 2 : i32
  // CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[FLAT_IDX]], %[[TWO1]] : i32
  // CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
  // CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]]
  // CHECK: %[[TWO2:.+]] = spv.constant 2 : i32
  // CHECK: %[[SIXTEEN:.+]] = spv.constant 16 : i32
  // CHECK: %[[IDX:.+]] = spv.SMod %[[FLAT_IDX]], %[[TWO2]] : i32
  // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[SIXTEEN]] : i32
  // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
  // CHECK: %[[MASK:.+]] = spv.constant 65535 : i32
  // CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
  %0 = load %arg0[%index] : memref<10xi16>
  return
}

// CHECK-LABEL: @load_i32
func @load_i32(%arg0: memref<i32>) {
  // CHECK-NOT: spv.SDiv
  // CHECK: spv.Load
  // CHECK-NOT: spv.ShiftRightArithmetic
  %0 = load %arg0[] : memref<i32>
  return
}

// CHECK-LABEL: @load_f32
func @load_f32(%arg0: memref<f32>) {
  // CHECK-NOT: spv.SDiv
  // CHECK: spv.Load
  // CHECK-NOT: spv.ShiftRightArithmetic
  %0 = load %arg0[] : memref<f32>
  return
}

} // end module
// -----
// Check that access chain indices are properly adjusted if non-16/32-bit types
// are emulated via 32-bit types.
module attributes {
  spv.target_env = #spv.target_env<
    #spv.vce<v1.0, [Int16, StorageBuffer16BitAccess, Shader],
    [SPV_KHR_storage_buffer_storage_class, SPV_KHR_16bit_storage]>,
    {max_compute_workgroup_invocations = 128 : i32,
     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
} {

// CHECK-LABEL: @load_i8
func @load_i8(%arg0: memref<i8>) {
  // CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
  // CHECK: %[[FOUR1:.+]] = spv.constant 4 : i32
  // CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32
  // CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
  // CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]]
  // CHECK: %[[FOUR2:.+]] = spv.constant 4 : i32
  // CHECK: %[[EIGHT:.+]] = spv.constant 8 : i32
  // CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR2]] : i32
  // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
  // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
  // CHECK: %[[MASK:.+]] = spv.constant 255 : i32
  // CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
  %0 = load %arg0[] : memref<i8>
  return
}

// CHECK-LABEL: @load_i16
func @load_i16(%arg0: memref<i16>) {
  // CHECK-NOT: spv.SDiv
  // CHECK: spv.Load
  // CHECK-NOT: spv.ShiftRightArithmetic
  %0 = load %arg0[] : memref<i16>
  return
}

} // end module