forked from OSchip/llvm-project
[mlir][spirv] Fix storing bool with proper storage capabilities
If the source value to store is bool, and we have native storage capability support for the target bitwidth, we still cannot directly store; we need to perform casting to match the target memref element's bitwidth. Reviewed By: hanchung Differential Revision: https://reviews.llvm.org/D107114
This commit is contained in:
parent
567c8c7bfd
commit
9f5300c8be
|
@ -304,6 +304,11 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv", "ModuleOp"> {
|
|||
let summary = "Convert MemRef dialect to SPIR-V dialect";
|
||||
let constructor = "mlir::createConvertMemRefToSPIRVPass()";
|
||||
let dependentDialects = ["spirv::SPIRVDialect"];
|
||||
let options = [
|
||||
Option<"boolNumBits", "bool-num-bits",
|
||||
"int", /*default=*/"8",
|
||||
"The number of bits to store a boolean value">
|
||||
];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -119,6 +119,17 @@ static Optional<spirv::Scope> getAtomicOpScope(MemRefType t) {
|
|||
return {};
|
||||
}
|
||||
|
||||
/// Casts the given `srcBool` into an integer of `dstType`.
|
||||
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
|
||||
OpBuilder &builder) {
|
||||
assert(srcBool.getType().isInteger(1));
|
||||
if (dstType.isInteger(1))
|
||||
return srcBool;
|
||||
Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
|
||||
Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
|
||||
return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operation conversion
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -336,9 +347,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
|
|||
dstType = typeConverter.convertType(loadOp.getType());
|
||||
mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter);
|
||||
Value isOne = rewriter.create<spirv::IEqualOp>(loc, result, mask);
|
||||
Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
|
||||
Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
|
||||
result = rewriter.create<spirv::SelectOp>(loc, dstType, isOne, one, zero);
|
||||
result = castBoolToIntN(loc, isOne, dstType, rewriter);
|
||||
} else if (result.getType().getIntOrFloatBitWidth() !=
|
||||
static_cast<unsigned>(dstBits)) {
|
||||
result = rewriter.create<spirv::SConvertOp>(loc, dstType, result);
|
||||
|
@ -392,6 +401,7 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
|
|||
bool isBool = srcBits == 1;
|
||||
if (isBool)
|
||||
srcBits = typeConverter.getOptions().boolNumBits;
|
||||
|
||||
Type pointeeType = typeConverter.convertType(memrefType)
|
||||
.cast<spirv::PointerType>()
|
||||
.getPointeeType();
|
||||
|
@ -406,8 +416,11 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
|
|||
assert(dstBits % srcBits == 0);
|
||||
|
||||
if (srcBits == dstBits) {
|
||||
Value storeVal = storeOperands.value();
|
||||
if (isBool)
|
||||
storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
|
||||
rewriter.replaceOpWithNewOp<spirv::StoreOp>(
|
||||
storeOp, accessChainOp.getResult(), storeOperands.value());
|
||||
storeOp, accessChainOp.getResult(), storeVal);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -435,12 +448,8 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
|
|||
clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
|
||||
|
||||
Value storeVal = storeOperands.value();
|
||||
if (isBool) {
|
||||
Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
|
||||
Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
|
||||
storeVal =
|
||||
rewriter.create<spirv::SelectOp>(loc, dstType, storeVal, one, zero);
|
||||
}
|
||||
if (isBool)
|
||||
storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
|
||||
storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter);
|
||||
Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
|
||||
srcBits, dstBits, rewriter);
|
||||
|
|
|
@ -34,7 +34,9 @@ void ConvertMemRefToSPIRVPass::runOnOperation() {
|
|||
std::unique_ptr<ConversionTarget> target =
|
||||
SPIRVConversionTarget::get(targetAttr);
|
||||
|
||||
SPIRVTypeConverter typeConverter(targetAttr);
|
||||
SPIRVTypeConverter::Options options;
|
||||
options.boolNumBits = this->boolNumBits;
|
||||
SPIRVTypeConverter typeConverter(targetAttr, options);
|
||||
|
||||
// Use UnrealizedConversionCast as the bridge so that we don't need to pull in
|
||||
// patterns for other dialects.
|
||||
|
|
|
@ -1,9 +1,17 @@
|
|||
// RUN: mlir-opt -split-input-file -convert-memref-to-spirv %s -o - | FileCheck %s
|
||||
// RUN: mlir-opt -split-input-file -convert-memref-to-spirv="bool-num-bits=8" %s -o - | FileCheck %s
|
||||
|
||||
// Check that with proper compute and storage extensions, we don't need to
|
||||
// perform special tricks.
|
||||
|
||||
module attributes {
|
||||
spv.target_env = #spv.target_env<
|
||||
#spv.vce<v1.0, [Shader, Int8, Int16, Int64, Float16, Float64],
|
||||
[SPV_KHR_storage_buffer_storage_class]>, {}>
|
||||
#spv.vce<v1.0,
|
||||
[
|
||||
Shader, Int8, Int16, Int64, Float16, Float64,
|
||||
StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16,
|
||||
StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8
|
||||
],
|
||||
[SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class]>, {}>
|
||||
} {
|
||||
|
||||
// CHECK-LABEL: @load_store_zero_rank_float
|
||||
|
@ -57,6 +65,27 @@ func @load_store_unknown_dim(%i: index, %source: memref<?xi32>, %dest: memref<?x
|
|||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @store_i1
|
||||
// CHECK-SAME: %[[DST:.+]]: memref<4xi1>,
|
||||
// CHECK-SAME: %[[IDX:.+]]: index
|
||||
func @store_i1(%dst: memref<4xi1>, %i: index) {
|
||||
%true = constant true
|
||||
// CHECK: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1> to !spv.ptr<!spv.struct<(!spv.array<4 x i8, stride=1> [0])>, StorageBuffer>
|
||||
// CHECK: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
|
||||
// CHECK: %[[ZERO_0:.+]] = spv.Constant 0 : i32
|
||||
// CHECK: %[[ZERO_1:.+]] = spv.Constant 0 : i32
|
||||
// CHECK: %[[ONE:.+]] = spv.Constant 1 : i32
|
||||
// CHECK: %[[MUL:.+]] = spv.IMul %[[ONE]], %[[IDX_CAST]] : i32
|
||||
// CHECK: %[[ADD:.+]] = spv.IAdd %[[ZERO_1]], %[[MUL]] : i32
|
||||
// CHECK: %[[ADDR:.+]] = spv.AccessChain %[[DST_CAST]][%[[ZERO_0]], %[[ADD]]] : !spv.ptr<!spv.struct<(!spv.array<4 x i8, stride=1> [0])>, StorageBuffer>, i32, i32
|
||||
// CHECK: %[[ZERO_I8:.+]] = spv.Constant 0 : i8
|
||||
// CHECK: %[[ONE_I8:.+]] = spv.Constant 1 : i8
|
||||
// CHECK: %[[RES:.+]] = spv.Select %{{.+}}, %[[ONE_I8]], %[[ZERO_I8]] : i1, i8
|
||||
// CHECK: spv.Store "StorageBuffer" %[[ADDR]], %[[RES]] : i8
|
||||
memref.store %true, %dst[%i]: memref<4xi1>
|
||||
return
|
||||
}
|
||||
|
||||
} // end module
|
||||
|
||||
// -----
|
||||
|
@ -88,10 +117,7 @@ func @load_i1(%arg0: memref<i1>) -> i1 {
|
|||
// CHECK: %[[T4:.+]] = spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
|
||||
// Convert to i1 type.
|
||||
// CHECK: %[[ONE:.+]] = spv.Constant 1 : i32
|
||||
// CHECK: %[[ISONE:.+]] = spv.IEqual %[[T4]], %[[ONE]] : i32
|
||||
// CHECK: %[[FALSE:.+]] = spv.Constant false
|
||||
// CHECK: %[[TRUE:.+]] = spv.Constant true
|
||||
// CHECK: %[[RES:.+]] = spv.Select %[[ISONE]], %[[TRUE]], %[[FALSE]] : i1, i1
|
||||
// CHECK: %[[RES:.+]] = spv.IEqual %[[T4]], %[[ONE]] : i32
|
||||
// CHECK: return %[[RES]]
|
||||
%0 = memref.load %arg0[] : memref<i1>
|
||||
return %0 : i1
|
||||
|
|
Loading…
Reference in New Issue