[mlir][spirv] Fix storing bool with proper storage capabilities

If the source value to store is bool, and we have native storage
capability support for the target bitwidth, we still cannot directly
store; we need to perform casting to match the target memref
element's bitwidth.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D107114
This commit is contained in:
Lei Zhang 2021-07-29 20:21:44 -04:00
parent 567c8c7bfd
commit 9f5300c8be
4 changed files with 60 additions and 18 deletions

View File

@ -304,6 +304,11 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv", "ModuleOp"> {
let summary = "Convert MemRef dialect to SPIR-V dialect";
let constructor = "mlir::createConvertMemRefToSPIRVPass()";
let dependentDialects = ["spirv::SPIRVDialect"];
let options = [
Option<"boolNumBits", "bool-num-bits",
"int", /*default=*/"8",
"The number of bits to store a boolean value">
];
}
//===----------------------------------------------------------------------===//

View File

@ -119,6 +119,17 @@ static Optional<spirv::Scope> getAtomicOpScope(MemRefType t) {
return {};
}
/// Casts the given `srcBool` into an integer of `dstType`.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
OpBuilder &builder) {
assert(srcBool.getType().isInteger(1));
if (dstType.isInteger(1))
return srcBool;
Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero);
}
//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
@ -336,9 +347,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
dstType = typeConverter.convertType(loadOp.getType());
mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter);
Value isOne = rewriter.create<spirv::IEqualOp>(loc, result, mask);
Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
result = rewriter.create<spirv::SelectOp>(loc, dstType, isOne, one, zero);
result = castBoolToIntN(loc, isOne, dstType, rewriter);
} else if (result.getType().getIntOrFloatBitWidth() !=
static_cast<unsigned>(dstBits)) {
result = rewriter.create<spirv::SConvertOp>(loc, dstType, result);
@ -392,6 +401,7 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
bool isBool = srcBits == 1;
if (isBool)
srcBits = typeConverter.getOptions().boolNumBits;
Type pointeeType = typeConverter.convertType(memrefType)
.cast<spirv::PointerType>()
.getPointeeType();
@ -406,8 +416,11 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
assert(dstBits % srcBits == 0);
if (srcBits == dstBits) {
Value storeVal = storeOperands.value();
if (isBool)
storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
rewriter.replaceOpWithNewOp<spirv::StoreOp>(
storeOp, accessChainOp.getResult(), storeOperands.value());
storeOp, accessChainOp.getResult(), storeVal);
return success();
}
@ -435,12 +448,8 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
Value storeVal = storeOperands.value();
if (isBool) {
Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
storeVal =
rewriter.create<spirv::SelectOp>(loc, dstType, storeVal, one, zero);
}
if (isBool)
storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter);
Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
srcBits, dstBits, rewriter);

View File

@ -34,7 +34,9 @@ void ConvertMemRefToSPIRVPass::runOnOperation() {
std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
SPIRVTypeConverter::Options options;
options.boolNumBits = this->boolNumBits;
SPIRVTypeConverter typeConverter(targetAttr, options);
// Use UnrealizedConversionCast as the bridge so that we don't need to pull in
// patterns for other dialects.

View File

@ -1,9 +1,17 @@
// RUN: mlir-opt -split-input-file -convert-memref-to-spirv %s -o - | FileCheck %s
// RUN: mlir-opt -split-input-file -convert-memref-to-spirv="bool-num-bits=8" %s -o - | FileCheck %s
// Check that with proper compute and storage extensions, we don't need to
// perform special tricks.
module attributes {
spv.target_env = #spv.target_env<
#spv.vce<v1.0, [Shader, Int8, Int16, Int64, Float16, Float64],
[SPV_KHR_storage_buffer_storage_class]>, {}>
#spv.vce<v1.0,
[
Shader, Int8, Int16, Int64, Float16, Float64,
StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16,
StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8
],
[SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class]>, {}>
} {
// CHECK-LABEL: @load_store_zero_rank_float
@ -57,6 +65,27 @@ func @load_store_unknown_dim(%i: index, %source: memref<?xi32>, %dest: memref<?x
return
}
// CHECK-LABEL: func @store_i1
// CHECK-SAME: %[[DST:.+]]: memref<4xi1>,
// CHECK-SAME: %[[IDX:.+]]: index
func @store_i1(%dst: memref<4xi1>, %i: index) {
%true = constant true
// CHECK: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1> to !spv.ptr<!spv.struct<(!spv.array<4 x i8, stride=1> [0])>, StorageBuffer>
// CHECK: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
// CHECK: %[[ZERO_0:.+]] = spv.Constant 0 : i32
// CHECK: %[[ZERO_1:.+]] = spv.Constant 0 : i32
// CHECK: %[[ONE:.+]] = spv.Constant 1 : i32
// CHECK: %[[MUL:.+]] = spv.IMul %[[ONE]], %[[IDX_CAST]] : i32
// CHECK: %[[ADD:.+]] = spv.IAdd %[[ZERO_1]], %[[MUL]] : i32
// CHECK: %[[ADDR:.+]] = spv.AccessChain %[[DST_CAST]][%[[ZERO_0]], %[[ADD]]] : !spv.ptr<!spv.struct<(!spv.array<4 x i8, stride=1> [0])>, StorageBuffer>, i32, i32
// CHECK: %[[ZERO_I8:.+]] = spv.Constant 0 : i8
// CHECK: %[[ONE_I8:.+]] = spv.Constant 1 : i8
// CHECK: %[[RES:.+]] = spv.Select %{{.+}}, %[[ONE_I8]], %[[ZERO_I8]] : i1, i8
// CHECK: spv.Store "StorageBuffer" %[[ADDR]], %[[RES]] : i8
memref.store %true, %dst[%i]: memref<4xi1>
return
}
} // end module
// -----
@ -88,10 +117,7 @@ func @load_i1(%arg0: memref<i1>) -> i1 {
// CHECK: %[[T4:.+]] = spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
// Convert to i1 type.
// CHECK: %[[ONE:.+]] = spv.Constant 1 : i32
// CHECK: %[[ISONE:.+]] = spv.IEqual %[[T4]], %[[ONE]] : i32
// CHECK: %[[FALSE:.+]] = spv.Constant false
// CHECK: %[[TRUE:.+]] = spv.Constant true
// CHECK: %[[RES:.+]] = spv.Select %[[ISONE]], %[[TRUE]], %[[FALSE]] : i1, i1
// CHECK: %[[RES:.+]] = spv.IEqual %[[T4]], %[[ONE]] : i32
// CHECK: return %[[RES]]
%0 = memref.load %arg0[] : memref<i1>
return %0 : i1