From c3614358452e5050b5b191fd3df3fad8b2664221 Mon Sep 17 00:00:00 2001 From: Hanhan Wang Date: Thu, 8 Apr 2021 12:15:15 -0700 Subject: [PATCH] [mlir][StandardToSPIRV] Handle i1 case for lowering memref.load/store op This patch unconditionally converts i1 types to i8 types on memrefs. If the extensions or capabilities are not met, they will be converted to i32. Hence the logic in IntLoadPattern and IntStorePattern are also updated. Also added the implementation of SPIRVTypeConverter::getOptions(). Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D99724 --- .../SPIRV/Transforms/SPIRVConversion.h | 6 +- .../StandardToSPIRV/StandardToSPIRV.cpp | 29 ++++++++- .../SPIRV/Transforms/SPIRVConversion.cpp | 61 +++++++++++++++++++ .../StandardToSPIRV/std-ops-to-spirv.mlir | 55 ++++++++++++++++- .../StandardToSPIRV/std-types-to-spirv.mlir | 19 ++---- 5 files changed, 152 insertions(+), 18 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h index 911a030d4a7e..2186107851d9 100644 --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h @@ -49,9 +49,13 @@ public: /// values will be packed into one 32-bit value to be memory efficient. bool emulateNon32BitScalarTypes; + /// The number of bits to store a boolean value. It is eight bits by + /// default. + unsigned boolNumBits; + // Note: we need this instead of inline initializers becuase of // https://bugs.llvm.org/show_bug.cgi?id=36684 - Options() : emulateNon32BitScalarTypes(true) {} + Options() : emulateNon32BitScalarTypes(true), boolNumBits(8) {} }; explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp index ed66252e20ae..397d26b0499d 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -991,6 +991,9 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, loadOperands.indices(), loc, rewriter); int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); + bool isBool = srcBits == 1; + if (isBool) + srcBits = typeConverter.getOptions().boolNumBits; auto dstType = typeConverter.convertType(memrefType) .cast() .getPointeeType() @@ -1044,6 +1047,18 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, shiftValue); result = rewriter.create(loc, dstType, result, shiftValue); + + if (isBool) { + dstType = typeConverter.convertType(loadOp.getType()); + mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter); + Value isOne = rewriter.create(loc, result, mask); + Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); + Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); + result = rewriter.create(loc, dstType, isOne, one, zero); + } else if (result.getType().getIntOrFloatBitWidth() != + static_cast(dstBits)) { + result = rewriter.create(loc, dstType, result); + } rewriter.replaceOp(loadOp, result); assert(accessChainOp.use_empty()); @@ -1117,6 +1132,10 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(), storeOperands.indices(), loc, rewriter); int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); + + bool isBool = srcBits == 1; + if (isBool) + srcBits = typeConverter.getOptions().boolNumBits; auto dstType = typeConverter.convertType(memrefType) .cast() .getPointeeType() @@ -1156,8 +1175,14 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, rewriter.create(loc, dstType, mask, offset); clearBitsMask = rewriter.create(loc, dstType, clearBitsMask); - Value storeVal = - shiftValue(loc, storeOperands.value(), offset, mask, dstBits, rewriter); + Value storeVal = storeOperands.value(); + if (isBool) { + Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); + Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); + storeVal = + rewriter.create(loc, dstType, storeVal, one, zero); + } + storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter); Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp, srcBits, dstBits, rewriter); Optional scope = getAtomicOpScope(memrefType); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 4e2dc01108b2..d26299fa82a9 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -153,6 +153,10 @@ SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) { #undef STORAGE_SPACE_MAP_FN } +const SPIRVTypeConverter::Options &SPIRVTypeConverter::getOptions() const { + return options; +} + #undef STORAGE_SPACE_MAP_LIST // TODO: This is a utility function that should probably be exposed by the @@ -342,9 +346,66 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv, return spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize); } +static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv, + const SPIRVTypeConverter::Options &options, + MemRefType type) { + if (!type.hasStaticShape()) { + LLVM_DEBUG(llvm::dbgs() + << type << " dynamic shape on i1 is not supported yet\n"); + return nullptr; + } + + Optional storageClass = + SPIRVTypeConverter::getStorageClassForMemorySpace( + type.getMemorySpaceAsInt()); + if (!storageClass) { + LLVM_DEBUG(llvm::dbgs() + << type << " illegal: cannot convert memory space\n"); + return nullptr; + } + + unsigned numBoolBits = options.boolNumBits; + if (numBoolBits != 8) { + LLVM_DEBUG(llvm::dbgs() + << "using non-8-bit storage for bool types unimplemented"); + return nullptr; + } + auto elementType = IntegerType::get(type.getContext(), numBoolBits) + .dyn_cast(); + if (!elementType) + return nullptr; + Type arrayElemType = + convertScalarType(targetEnv, options, elementType, storageClass); + if (!arrayElemType) + return nullptr; + Optional arrayElemSize = getTypeNumBytes(options, arrayElemType); + if (!arrayElemSize) { + LLVM_DEBUG(llvm::dbgs() + << type << " illegal: cannot deduce converted element size\n"); + return nullptr; + } + + int64_t memrefSize = (type.getNumElements() * numBoolBits + 7) / 8; + auto arrayElemCount = (memrefSize + *arrayElemSize - 1) / *arrayElemSize; + auto arrayType = + spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize); + + // Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with + // workgroup storage class do not need the struct to be laid out explicitly. + auto structType = *storageClass == spirv::StorageClass::Workgroup + ? spirv::StructType::get(arrayType) + : spirv::StructType::get(arrayType, 0); + return spirv::PointerType::get(structType, *storageClass); +} + static Type convertMemrefType(const spirv::TargetEnv &targetEnv, const SPIRVTypeConverter::Options &options, MemRefType type) { + if (type.getElementType().isa() && + type.getElementTypeBitWidth() == 1) { + return convertBoolMemrefType(targetEnv, options, type); + } + Optional storageClass = SPIRVTypeConverter::getStorageClassForMemorySpace( type.getMemorySpaceAsInt()); diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir index d074969febff..86d390a2ce70 100644 --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -911,12 +911,40 @@ func @load_store_zero_rank_int(%arg0: memref, %arg1: memref) { // Check that access chain indices are properly adjusted if non-32-bit types are // emulated via 32-bit types. -// TODO: Test i1 and i64 types. +// TODO: Test i64 types. module attributes { spv.target_env = #spv.target_env< #spv.vce, {}> } { +// CHECK-LABEL: @load_i1 +func @load_i1(%arg0: memref) -> i1 { + // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 + // CHECK: %[[FOUR1:.+]] = spv.Constant 4 : i32 + // CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32 + // CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] + // CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]] + // CHECK: %[[FOUR2:.+]] = spv.Constant 4 : i32 + // CHECK: %[[EIGHT:.+]] = spv.Constant 8 : i32 + // CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR2]] : i32 + // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32 + // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32 + // CHECK: %[[MASK:.+]] = spv.Constant 255 : i32 + // CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 + // CHECK: %[[T2:.+]] = spv.Constant 24 : i32 + // CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32 + // CHECK: %[[T4:.+]] = spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32 + // Convert to i1 type. + // CHECK: %[[ONE:.+]] = spv.Constant 1 : i32 + // CHECK: %[[ISONE:.+]] = spv.IEqual %[[T4]], %[[ONE]] : i32 + // CHECK: %[[FALSE:.+]] = spv.Constant false + // CHECK: %[[TRUE:.+]] = spv.Constant true + // CHECK: %[[RES:.+]] = spv.Select %[[ISONE]], %[[TRUE]], %[[FALSE]] : i1, i1 + // CHECK: spv.ReturnValue %[[RES]] : i1 + %0 = memref.load %arg0[] : memref + return %0 : i1 +} + // CHECK-LABEL: @load_i8 func @load_i8(%arg0: memref) { // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 @@ -982,6 +1010,31 @@ func @load_f32(%arg0: memref) { return } +// CHECK-LABEL: @store_i1 +// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i1) +func @store_i1(%arg0: memref, %value: i1) { + // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 + // CHECK: %[[FOUR:.+]] = spv.Constant 4 : i32 + // CHECK: %[[EIGHT:.+]] = spv.Constant 8 : i32 + // CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR]] : i32 + // CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32 + // CHECK: %[[MASK1:.+]] = spv.Constant 255 : i32 + // CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32 + // CHECK: %[[MASK:.+]] = spv.Not %[[TMP1]] : i32 + // CHECK: %[[ZERO1:.+]] = spv.Constant 0 : i32 + // CHECK: %[[ONE1:.+]] = spv.Constant 1 : i32 + // CHECK: %[[CASTED_ARG1:.+]] = spv.Select %[[ARG1]], %[[ONE1]], %[[ZERO1]] : i1, i32 + // CHECK: %[[CLAMPED_VAL:.+]] = spv.BitwiseAnd %[[CASTED_ARG1]], %[[MASK1]] : i32 + // CHECK: %[[STORE_VAL:.+]] = spv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32 + // CHECK: %[[FOUR2:.+]] = spv.Constant 4 : i32 + // CHECK: %[[ACCESS_IDX:.+]] = spv.SDiv %[[ZERO]], %[[FOUR2]] : i32 + // CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0]][%[[ZERO]], %[[ACCESS_IDX]]] + // CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]] + // CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]] + memref.store %value, %arg0[] : memref + return +} + // CHECK-LABEL: @store_i8 // CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i32) func @store_i8(%arg0: memref, %value: i8) { diff --git a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir index cacc3e762c14..58513124907a 100644 --- a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir @@ -286,20 +286,6 @@ func @memref_mem_space( // ----- -// Check that boolean memref is not supported at the moment. -module attributes { - spv.target_env = #spv.target_env<#spv.vce, {}> -} { - -// CHECK-LABEL: func @memref_type({{%.*}}: memref<3xi1>) -func @memref_type(%arg0: memref<3xi1>) { - return -} - -} // end module - -// ----- - // Check that using non-32-bit scalar types in interface storage classes // requires special capability and extension: convert them to 32-bit if not // satisfied. @@ -307,6 +293,11 @@ module attributes { spv.target_env = #spv.target_env<#spv.vce, {}> } { +// An i1 is store in 8-bit, so 5xi1 has 40 bits, which is stored in 2xi32. +// CHECK-LABEL: spv.func @memref_1bit_type +// CHECK-SAME: !spv.ptr [0])>, StorageBuffer> +func @memref_1bit_type(%arg0: memref<5xi1>) { return } + // CHECK-LABEL: spv.func @memref_8bit_StorageBuffer // CHECK-SAME: !spv.ptr [0])>, StorageBuffer> // NOEMU-LABEL: func @memref_8bit_StorageBuffer