forked from OSchip/llvm-project
[mlir][StandardToSPIRV] Handle i1 case for lowering memref.load/store op
This patch unconditionally converts i1 types to i8 types on memrefs. If the extensions or capabilities are not met, they will be converted to i32. Hence the logic in IntLoadPattern and IntStorePattern are also updated. Also added the implementation of SPIRVTypeConverter::getOptions(). Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D99724
This commit is contained in:
parent
04e9cd09c8
commit
c361435845
|
@ -49,9 +49,13 @@ public:
|
|||
/// values will be packed into one 32-bit value to be memory efficient.
|
||||
bool emulateNon32BitScalarTypes;
|
||||
|
||||
/// The number of bits to store a boolean value. It is eight bits by
|
||||
/// default.
|
||||
unsigned boolNumBits;
|
||||
|
||||
// Note: we need this instead of inline initializers becuase of
|
||||
// https://bugs.llvm.org/show_bug.cgi?id=36684
|
||||
Options() : emulateNon32BitScalarTypes(true) {}
|
||||
Options() : emulateNon32BitScalarTypes(true), boolNumBits(8) {}
|
||||
};
|
||||
|
||||
explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
|
||||
|
|
|
@ -991,6 +991,9 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
|
|||
loadOperands.indices(), loc, rewriter);
|
||||
|
||||
int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
|
||||
bool isBool = srcBits == 1;
|
||||
if (isBool)
|
||||
srcBits = typeConverter.getOptions().boolNumBits;
|
||||
auto dstType = typeConverter.convertType(memrefType)
|
||||
.cast<spirv::PointerType>()
|
||||
.getPointeeType()
|
||||
|
@ -1044,6 +1047,18 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
|
|||
shiftValue);
|
||||
result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
|
||||
shiftValue);
|
||||
|
||||
if (isBool) {
|
||||
dstType = typeConverter.convertType(loadOp.getType());
|
||||
mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter);
|
||||
Value isOne = rewriter.create<spirv::IEqualOp>(loc, result, mask);
|
||||
Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
|
||||
Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
|
||||
result = rewriter.create<spirv::SelectOp>(loc, dstType, isOne, one, zero);
|
||||
} else if (result.getType().getIntOrFloatBitWidth() !=
|
||||
static_cast<unsigned>(dstBits)) {
|
||||
result = rewriter.create<spirv::SConvertOp>(loc, dstType, result);
|
||||
}
|
||||
rewriter.replaceOp(loadOp, result);
|
||||
|
||||
assert(accessChainOp.use_empty());
|
||||
|
@ -1117,6 +1132,10 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
|
|||
spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(),
|
||||
storeOperands.indices(), loc, rewriter);
|
||||
int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
|
||||
|
||||
bool isBool = srcBits == 1;
|
||||
if (isBool)
|
||||
srcBits = typeConverter.getOptions().boolNumBits;
|
||||
auto dstType = typeConverter.convertType(memrefType)
|
||||
.cast<spirv::PointerType>()
|
||||
.getPointeeType()
|
||||
|
@ -1156,8 +1175,14 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
|
|||
rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
|
||||
clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
|
||||
|
||||
Value storeVal =
|
||||
shiftValue(loc, storeOperands.value(), offset, mask, dstBits, rewriter);
|
||||
Value storeVal = storeOperands.value();
|
||||
if (isBool) {
|
||||
Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
|
||||
Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
|
||||
storeVal =
|
||||
rewriter.create<spirv::SelectOp>(loc, dstType, storeVal, one, zero);
|
||||
}
|
||||
storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter);
|
||||
Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
|
||||
srcBits, dstBits, rewriter);
|
||||
Optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
|
||||
|
|
|
@ -153,6 +153,10 @@ SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) {
|
|||
#undef STORAGE_SPACE_MAP_FN
|
||||
}
|
||||
|
||||
const SPIRVTypeConverter::Options &SPIRVTypeConverter::getOptions() const {
|
||||
return options;
|
||||
}
|
||||
|
||||
#undef STORAGE_SPACE_MAP_LIST
|
||||
|
||||
// TODO: This is a utility function that should probably be exposed by the
|
||||
|
@ -342,9 +346,66 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
|
|||
return spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
|
||||
}
|
||||
|
||||
static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
|
||||
const SPIRVTypeConverter::Options &options,
|
||||
MemRefType type) {
|
||||
if (!type.hasStaticShape()) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< type << " dynamic shape on i1 is not supported yet\n");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Optional<spirv::StorageClass> storageClass =
|
||||
SPIRVTypeConverter::getStorageClassForMemorySpace(
|
||||
type.getMemorySpaceAsInt());
|
||||
if (!storageClass) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< type << " illegal: cannot convert memory space\n");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
unsigned numBoolBits = options.boolNumBits;
|
||||
if (numBoolBits != 8) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "using non-8-bit storage for bool types unimplemented");
|
||||
return nullptr;
|
||||
}
|
||||
auto elementType = IntegerType::get(type.getContext(), numBoolBits)
|
||||
.dyn_cast<spirv::ScalarType>();
|
||||
if (!elementType)
|
||||
return nullptr;
|
||||
Type arrayElemType =
|
||||
convertScalarType(targetEnv, options, elementType, storageClass);
|
||||
if (!arrayElemType)
|
||||
return nullptr;
|
||||
Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
|
||||
if (!arrayElemSize) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< type << " illegal: cannot deduce converted element size\n");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int64_t memrefSize = (type.getNumElements() * numBoolBits + 7) / 8;
|
||||
auto arrayElemCount = (memrefSize + *arrayElemSize - 1) / *arrayElemSize;
|
||||
auto arrayType =
|
||||
spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
|
||||
|
||||
// Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with
|
||||
// workgroup storage class do not need the struct to be laid out explicitly.
|
||||
auto structType = *storageClass == spirv::StorageClass::Workgroup
|
||||
? spirv::StructType::get(arrayType)
|
||||
: spirv::StructType::get(arrayType, 0);
|
||||
return spirv::PointerType::get(structType, *storageClass);
|
||||
}
|
||||
|
||||
static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
|
||||
const SPIRVTypeConverter::Options &options,
|
||||
MemRefType type) {
|
||||
if (type.getElementType().isa<IntegerType>() &&
|
||||
type.getElementTypeBitWidth() == 1) {
|
||||
return convertBoolMemrefType(targetEnv, options, type);
|
||||
}
|
||||
|
||||
Optional<spirv::StorageClass> storageClass =
|
||||
SPIRVTypeConverter::getStorageClassForMemorySpace(
|
||||
type.getMemorySpaceAsInt());
|
||||
|
|
|
@ -911,12 +911,40 @@ func @load_store_zero_rank_int(%arg0: memref<i32>, %arg1: memref<i32>) {
|
|||
|
||||
// Check that access chain indices are properly adjusted if non-32-bit types are
|
||||
// emulated via 32-bit types.
|
||||
// TODO: Test i1 and i64 types.
|
||||
// TODO: Test i64 types.
|
||||
module attributes {
|
||||
spv.target_env = #spv.target_env<
|
||||
#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}>
|
||||
} {
|
||||
|
||||
// CHECK-LABEL: @load_i1
|
||||
func @load_i1(%arg0: memref<i1>) -> i1 {
|
||||
// CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32
|
||||
// CHECK: %[[FOUR1:.+]] = spv.Constant 4 : i32
|
||||
// CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32
|
||||
// CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
|
||||
// CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]]
|
||||
// CHECK: %[[FOUR2:.+]] = spv.Constant 4 : i32
|
||||
// CHECK: %[[EIGHT:.+]] = spv.Constant 8 : i32
|
||||
// CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR2]] : i32
|
||||
// CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
|
||||
// CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
|
||||
// CHECK: %[[MASK:.+]] = spv.Constant 255 : i32
|
||||
// CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
|
||||
// CHECK: %[[T2:.+]] = spv.Constant 24 : i32
|
||||
// CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
|
||||
// CHECK: %[[T4:.+]] = spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
|
||||
// Convert to i1 type.
|
||||
// CHECK: %[[ONE:.+]] = spv.Constant 1 : i32
|
||||
// CHECK: %[[ISONE:.+]] = spv.IEqual %[[T4]], %[[ONE]] : i32
|
||||
// CHECK: %[[FALSE:.+]] = spv.Constant false
|
||||
// CHECK: %[[TRUE:.+]] = spv.Constant true
|
||||
// CHECK: %[[RES:.+]] = spv.Select %[[ISONE]], %[[TRUE]], %[[FALSE]] : i1, i1
|
||||
// CHECK: spv.ReturnValue %[[RES]] : i1
|
||||
%0 = memref.load %arg0[] : memref<i1>
|
||||
return %0 : i1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @load_i8
|
||||
func @load_i8(%arg0: memref<i8>) {
|
||||
// CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32
|
||||
|
@ -982,6 +1010,31 @@ func @load_f32(%arg0: memref<f32>) {
|
|||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @store_i1
|
||||
// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i1)
|
||||
func @store_i1(%arg0: memref<i1>, %value: i1) {
|
||||
// CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32
|
||||
// CHECK: %[[FOUR:.+]] = spv.Constant 4 : i32
|
||||
// CHECK: %[[EIGHT:.+]] = spv.Constant 8 : i32
|
||||
// CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR]] : i32
|
||||
// CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
|
||||
// CHECK: %[[MASK1:.+]] = spv.Constant 255 : i32
|
||||
// CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
|
||||
// CHECK: %[[MASK:.+]] = spv.Not %[[TMP1]] : i32
|
||||
// CHECK: %[[ZERO1:.+]] = spv.Constant 0 : i32
|
||||
// CHECK: %[[ONE1:.+]] = spv.Constant 1 : i32
|
||||
// CHECK: %[[CASTED_ARG1:.+]] = spv.Select %[[ARG1]], %[[ONE1]], %[[ZERO1]] : i1, i32
|
||||
// CHECK: %[[CLAMPED_VAL:.+]] = spv.BitwiseAnd %[[CASTED_ARG1]], %[[MASK1]] : i32
|
||||
// CHECK: %[[STORE_VAL:.+]] = spv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32
|
||||
// CHECK: %[[FOUR2:.+]] = spv.Constant 4 : i32
|
||||
// CHECK: %[[ACCESS_IDX:.+]] = spv.SDiv %[[ZERO]], %[[FOUR2]] : i32
|
||||
// CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0]][%[[ZERO]], %[[ACCESS_IDX]]]
|
||||
// CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
|
||||
// CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
|
||||
memref.store %value, %arg0[] : memref<i1>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @store_i8
|
||||
// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i32)
|
||||
func @store_i8(%arg0: memref<i8>, %value: i8) {
|
||||
|
|
|
@ -286,20 +286,6 @@ func @memref_mem_space(
|
|||
|
||||
// -----
|
||||
|
||||
// Check that boolean memref is not supported at the moment.
|
||||
module attributes {
|
||||
spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
|
||||
} {
|
||||
|
||||
// CHECK-LABEL: func @memref_type({{%.*}}: memref<3xi1>)
|
||||
func @memref_type(%arg0: memref<3xi1>) {
|
||||
return
|
||||
}
|
||||
|
||||
} // end module
|
||||
|
||||
// -----
|
||||
|
||||
// Check that using non-32-bit scalar types in interface storage classes
|
||||
// requires special capability and extension: convert them to 32-bit if not
|
||||
// satisfied.
|
||||
|
@ -307,6 +293,11 @@ module attributes {
|
|||
spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
|
||||
} {
|
||||
|
||||
// An i1 is store in 8-bit, so 5xi1 has 40 bits, which is stored in 2xi32.
|
||||
// CHECK-LABEL: spv.func @memref_1bit_type
|
||||
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<2 x i32, stride=4> [0])>, StorageBuffer>
|
||||
func @memref_1bit_type(%arg0: memref<5xi1>) { return }
|
||||
|
||||
// CHECK-LABEL: spv.func @memref_8bit_StorageBuffer
|
||||
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i32, stride=4> [0])>, StorageBuffer>
|
||||
// NOEMU-LABEL: func @memref_8bit_StorageBuffer
|
||||
|
|
Loading…
Reference in New Issue