[mlir][StandardToSPIRV] Handle i1 case for lowering memref.load/store op

This patch unconditionally converts i1 types to i8 types on memrefs. If the
required extensions or capabilities are not available, they are converted to
i32 instead. Hence the logic in IntLoadPattern and IntStorePattern is also
updated.

Also added the implementation of SPIRVTypeConverter::getOptions().

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D99724
This commit is contained in:
Hanhan Wang 2021-04-08 12:15:15 -07:00
parent 04e9cd09c8
commit c361435845
5 changed files with 152 additions and 18 deletions

View File

@ -49,9 +49,13 @@ public:
/// values will be packed into one 32-bit value to be memory efficient.
bool emulateNon32BitScalarTypes;
/// The number of bits to store a boolean value. It is eight bits by
/// default.
unsigned boolNumBits;
// Note: we need this instead of inline initializers because of
// https://bugs.llvm.org/show_bug.cgi?id=36684
Options() : emulateNon32BitScalarTypes(true) {}
Options() : emulateNon32BitScalarTypes(true), boolNumBits(8) {}
};
explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,

View File

@ -991,6 +991,9 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
loadOperands.indices(), loc, rewriter);
int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
bool isBool = srcBits == 1;
if (isBool)
srcBits = typeConverter.getOptions().boolNumBits;
auto dstType = typeConverter.convertType(memrefType)
.cast<spirv::PointerType>()
.getPointeeType()
@ -1044,6 +1047,18 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
shiftValue);
result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
shiftValue);
if (isBool) {
dstType = typeConverter.convertType(loadOp.getType());
mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter);
Value isOne = rewriter.create<spirv::IEqualOp>(loc, result, mask);
Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
result = rewriter.create<spirv::SelectOp>(loc, dstType, isOne, one, zero);
} else if (result.getType().getIntOrFloatBitWidth() !=
static_cast<unsigned>(dstBits)) {
result = rewriter.create<spirv::SConvertOp>(loc, dstType, result);
}
rewriter.replaceOp(loadOp, result);
assert(accessChainOp.use_empty());
@ -1117,6 +1132,10 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(),
storeOperands.indices(), loc, rewriter);
int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
bool isBool = srcBits == 1;
if (isBool)
srcBits = typeConverter.getOptions().boolNumBits;
auto dstType = typeConverter.convertType(memrefType)
.cast<spirv::PointerType>()
.getPointeeType()
@ -1156,8 +1175,14 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
Value storeVal =
shiftValue(loc, storeOperands.value(), offset, mask, dstBits, rewriter);
Value storeVal = storeOperands.value();
if (isBool) {
Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
storeVal =
rewriter.create<spirv::SelectOp>(loc, dstType, storeVal, one, zero);
}
storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter);
Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
srcBits, dstBits, rewriter);
Optional<spirv::Scope> scope = getAtomicOpScope(memrefType);

View File

@ -153,6 +153,10 @@ SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) {
#undef STORAGE_SPACE_MAP_FN
}
/// Returns a read-only reference to the conversion options held by this type
/// converter.
const SPIRVTypeConverter::Options &SPIRVTypeConverter::getOptions() const {
return options;
}
#undef STORAGE_SPACE_MAP_LIST
// TODO: This is a utility function that should probably be exposed by the
@ -342,9 +346,66 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
return spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
}
/// Converts a statically shaped memref of i1 elements into a SPIR-V pointer
/// type.
///
/// Each boolean occupies `options.boolNumBits` bits of storage; only 8-bit
/// storage is implemented. The 8-bit integer element type is run through
/// convertScalarType, so the final array element type is whatever that
/// conversion yields for the target environment and storage class. The array
/// is wrapped in a struct and returned as a pointer in the memref's storage
/// class.
///
/// Returns a null type when:
///   - the memref does not have a static shape,
///   - its memory space does not map to a SPIR-V storage class,
///   - `options.boolNumBits` is not 8,
///   - the storage integer type is not a SPIR-V scalar type or cannot be
///     converted, or
///   - the byte size of the converted element type cannot be deduced.
static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
                                  const SPIRVTypeConverter::Options &options,
                                  MemRefType type) {
  if (!type.hasStaticShape()) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " dynamic shape on i1 is not supported yet\n");
    return nullptr;
  }

  Optional<spirv::StorageClass> storageClass =
      SPIRVTypeConverter::getStorageClassForMemorySpace(
          type.getMemorySpaceAsInt());
  if (!storageClass) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot convert memory space\n");
    return nullptr;
  }

  unsigned numBoolBits = options.boolNumBits;
  if (numBoolBits != 8) {
    // Terminate the message with a newline like every other debug message in
    // this function so that subsequent debug output starts on its own line.
    LLVM_DEBUG(llvm::dbgs()
               << "using non-8-bit storage for bool types unimplemented\n");
    return nullptr;
  }

  // Booleans are stored as `numBoolBits`-wide integers.
  auto elementType = IntegerType::get(type.getContext(), numBoolBits)
                         .dyn_cast<spirv::ScalarType>();
  if (!elementType)
    return nullptr;

  Type arrayElemType =
      convertScalarType(targetEnv, options, elementType, storageClass);
  if (!arrayElemType)
    return nullptr;

  Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
  if (!arrayElemSize) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce converted element size\n");
    return nullptr;
  }

  // Total storage in bytes (rounded up), then the number of converted array
  // elements needed to hold that many bytes (also rounded up).
  int64_t memrefSize = (type.getNumElements() * numBoolBits + 7) / 8;
  auto arrayElemCount = (memrefSize + *arrayElemSize - 1) / *arrayElemSize;
  auto arrayType =
      spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);

  // Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with
  // workgroup storage class do not need the struct to be laid out explicitly.
  auto structType = *storageClass == spirv::StorageClass::Workgroup
                        ? spirv::StructType::get(arrayType)
                        : spirv::StructType::get(arrayType, 0);
  return spirv::PointerType::get(structType, *storageClass);
}
static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
const SPIRVTypeConverter::Options &options,
MemRefType type) {
if (type.getElementType().isa<IntegerType>() &&
type.getElementTypeBitWidth() == 1) {
return convertBoolMemrefType(targetEnv, options, type);
}
Optional<spirv::StorageClass> storageClass =
SPIRVTypeConverter::getStorageClassForMemorySpace(
type.getMemorySpaceAsInt());

View File

@ -911,12 +911,40 @@ func @load_store_zero_rank_int(%arg0: memref<i32>, %arg1: memref<i32>) {
// Check that access chain indices are properly adjusted if non-32-bit types are
// emulated via 32-bit types.
// TODO: Test i1 and i64 types.
// TODO: Test i64 types.
module attributes {
spv.target_env = #spv.target_env<
#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}>
} {
// Loads an i1 from a zero-rank memref. The checks below expect the boolean to
// be read out of a 32-bit word: the containing word is loaded, the 8-bit slot
// is shifted down, masked and sign-extended, and the result is compared
// against 1 to select a true/false i1 value.
// CHECK-LABEL: @load_i1
func @load_i1(%arg0: memref<i1>) -> i1 {
// CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32
// CHECK: %[[FOUR1:.+]] = spv.Constant 4 : i32
// CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32
// CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
// CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]]
// CHECK: %[[FOUR2:.+]] = spv.Constant 4 : i32
// CHECK: %[[EIGHT:.+]] = spv.Constant 8 : i32
// CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR2]] : i32
// CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
// CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
// CHECK: %[[MASK:.+]] = spv.Constant 255 : i32
// CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
// CHECK: %[[T2:.+]] = spv.Constant 24 : i32
// CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
// CHECK: %[[T4:.+]] = spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
// Convert to i1 type.
// CHECK: %[[ONE:.+]] = spv.Constant 1 : i32
// CHECK: %[[ISONE:.+]] = spv.IEqual %[[T4]], %[[ONE]] : i32
// CHECK: %[[FALSE:.+]] = spv.Constant false
// CHECK: %[[TRUE:.+]] = spv.Constant true
// CHECK: %[[RES:.+]] = spv.Select %[[ISONE]], %[[TRUE]], %[[FALSE]] : i1, i1
// CHECK: spv.ReturnValue %[[RES]] : i1
%0 = memref.load %arg0[] : memref<i1>
return %0 : i1
}
// CHECK-LABEL: @load_i8
func @load_i8(%arg0: memref<i8>) {
// CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32
@ -982,6 +1010,31 @@ func @load_f32(%arg0: memref<f32>) {
return
}
// Stores an i1 into a zero-rank memref. The checks below expect the boolean to
// be widened to i32 with spv.Select, masked and shifted into its 8-bit slot,
// and then written with an AtomicAnd (clearing the slot) followed by an
// AtomicOr (setting the new value).
// CHECK-LABEL: @store_i1
// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i1)
func @store_i1(%arg0: memref<i1>, %value: i1) {
// CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32
// CHECK: %[[FOUR:.+]] = spv.Constant 4 : i32
// CHECK: %[[EIGHT:.+]] = spv.Constant 8 : i32
// CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR]] : i32
// CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
// CHECK: %[[MASK1:.+]] = spv.Constant 255 : i32
// CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
// CHECK: %[[MASK:.+]] = spv.Not %[[TMP1]] : i32
// CHECK: %[[ZERO1:.+]] = spv.Constant 0 : i32
// CHECK: %[[ONE1:.+]] = spv.Constant 1 : i32
// CHECK: %[[CASTED_ARG1:.+]] = spv.Select %[[ARG1]], %[[ONE1]], %[[ZERO1]] : i1, i32
// CHECK: %[[CLAMPED_VAL:.+]] = spv.BitwiseAnd %[[CASTED_ARG1]], %[[MASK1]] : i32
// CHECK: %[[STORE_VAL:.+]] = spv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32
// CHECK: %[[FOUR2:.+]] = spv.Constant 4 : i32
// CHECK: %[[ACCESS_IDX:.+]] = spv.SDiv %[[ZERO]], %[[FOUR2]] : i32
// CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0]][%[[ZERO]], %[[ACCESS_IDX]]]
// CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
// CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
memref.store %value, %arg0[] : memref<i1>
return
}
// CHECK-LABEL: @store_i8
// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i32)
func @store_i8(%arg0: memref<i8>, %value: i8) {

View File

@ -286,20 +286,6 @@ func @memref_mem_space(
// -----
// Check that boolean memref is not supported at the moment.
module attributes {
spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
} {
// CHECK-LABEL: func @memref_type({{%.*}}: memref<3xi1>)
func @memref_type(%arg0: memref<3xi1>) {
return
}
} // end module
// -----
// Check that using non-32-bit scalar types in interface storage classes
// requires special capability and extension: convert them to 32-bit if not
// satisfied.
@ -307,6 +293,11 @@ module attributes {
spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
} {
// An i1 is stored in 8 bits, so 5xi1 takes 40 bits, which are stored in 2xi32.
// CHECK-LABEL: spv.func @memref_1bit_type
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<2 x i32, stride=4> [0])>, StorageBuffer>
func @memref_1bit_type(%arg0: memref<5xi1>) { return }
// CHECK-LABEL: spv.func @memref_8bit_StorageBuffer
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i32, stride=4> [0])>, StorageBuffer>
// NOEMU-LABEL: func @memref_8bit_StorageBuffer