[mlir][spirv] Convert memref.alloca to spv.Variable

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D124542
This commit is contained in:
Lei Zhang 2022-04-28 08:13:22 -04:00
parent 72959f7714
commit 8854b73606
3 changed files with 147 additions and 27 deletions

View File

@ -12,6 +12,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
@ -85,15 +86,27 @@ static Value shiftValue(Location loc, Value value, Value offset, Value mask,
offset); offset);
} }
/// Returns true if the allocations of type `t` can be lowered to SPIR-V. /// Returns true if the allocations of memref `type` generated from `allocOp`
static bool isAllocationSupported(MemRefType t) { /// can be lowered to SPIR-V.
// Currently only support workgroup local memory allocations with static static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
// shape and int or float or vector of int or float element type. if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
if (!(t.hasStaticShape() && if (SPIRVTypeConverter::getMemorySpaceForStorageClass(
SPIRVTypeConverter::getMemorySpaceForStorageClass( spirv::StorageClass::Workgroup) != type.getMemorySpaceAsInt())
spirv::StorageClass::Workgroup) == t.getMemorySpaceAsInt())) return false;
} else if (isa<memref::AllocaOp>(allocOp)) {
if (SPIRVTypeConverter::getMemorySpaceForStorageClass(
spirv::StorageClass::Function) != type.getMemorySpaceAsInt())
return false;
} else {
return false; return false;
Type elementType = t.getElementType(); }
// Currently only support static shape and int or float or vector of int or
// float element type.
if (!type.hasStaticShape())
return false;
Type elementType = type.getElementType();
if (auto vecType = elementType.dyn_cast<VectorType>()) if (auto vecType = elementType.dyn_cast<VectorType>())
elementType = vecType.getElementType(); elementType = vecType.getElementType();
return elementType.isIntOrFloat(); return elementType.isIntOrFloat();
@ -102,10 +115,10 @@ static bool isAllocationSupported(MemRefType t) {
/// Returns the scope to use for atomic operations use for emulating store /// Returns the scope to use for atomic operations use for emulating store
/// operations of unsupported integer bitwidths, based on the memref /// operations of unsupported integer bitwidths, based on the memref
/// type. Returns None on failure. /// type. Returns None on failure.
static Optional<spirv::Scope> getAtomicOpScope(MemRefType t) { static Optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
Optional<spirv::StorageClass> storageClass = Optional<spirv::StorageClass> storageClass =
SPIRVTypeConverter::getStorageClassForMemorySpace( SPIRVTypeConverter::getStorageClassForMemorySpace(
t.getMemorySpaceAsInt()); type.getMemorySpaceAsInt());
if (!storageClass) if (!storageClass)
return {}; return {};
switch (*storageClass) { switch (*storageClass) {
@ -149,6 +162,16 @@ static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
namespace { namespace {
/// Converts memref.alloca to SPIR-V Function variables.
class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
public:
using OpConversionPattern<memref::AllocaOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// Converts an allocation operation to SPIR-V. Currently only supports lowering /// Converts an allocation operation to SPIR-V. Currently only supports lowering
/// to Workgroup memory when the size is constant. Note that this pattern needs /// to Workgroup memory when the size is constant. Note that this pattern needs
/// to be applied in a pass that runs at least at spv.module scope since it wil /// to be applied in a pass that runs at least at spv.module scope since it wil
@ -215,6 +238,25 @@ public:
} // namespace } // namespace
//===----------------------------------------------------------------------===//
// AllocaOp
//===----------------------------------------------------------------------===//
LogicalResult
AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
MemRefType allocType = allocaOp.getType();
if (!isAllocationSupported(allocaOp, allocType))
return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");
// Get the SPIR-V type for the allocation.
Type spirvType = getTypeConverter()->convertType(allocType);
rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
spirv::StorageClass::Function,
/*initializer=*/nullptr);
return success();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AllocOp // AllocOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -223,8 +265,8 @@ LogicalResult
AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor, AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
MemRefType allocType = operation.getType(); MemRefType allocType = operation.getType();
if (!isAllocationSupported(allocType)) if (!isAllocationSupported(operation, allocType))
return operation.emitError("unhandled allocation type"); return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
// Get the SPIR-V type for the allocation. // Get the SPIR-V type for the allocation.
Type spirvType = getTypeConverter()->convertType(allocType); Type spirvType = getTypeConverter()->convertType(allocType);
@ -262,8 +304,8 @@ DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
OpAdaptor adaptor, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
MemRefType deallocType = operation.memref().getType().cast<MemRefType>(); MemRefType deallocType = operation.memref().getType().cast<MemRefType>();
if (!isAllocationSupported(deallocType)) if (!isAllocationSupported(operation, deallocType))
return operation.emitError("unhandled deallocation type"); return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
rewriter.eraseOp(operation); rewriter.eraseOp(operation);
return success(); return success();
} }
@ -505,8 +547,9 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
namespace mlir { namespace mlir {
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter, void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) { RewritePatternSet &patterns) {
patterns.add<AllocOpPattern, DeallocOpPattern, IntLoadOpPattern, patterns
IntStoreOpPattern, LoadOpPattern, StoreOpPattern>( .add<AllocaOpPattern, AllocOpPattern, DeallocOpPattern, IntLoadOpPattern,
typeConverter, patterns.getContext()); IntStoreOpPattern, LoadOpPattern, StoreOpPattern>(
typeConverter, patterns.getContext());
} }
} // namespace mlir } // namespace mlir

View File

@ -100,10 +100,12 @@ module attributes {
#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}> #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}>
} }
{ {
func.func @alloc_dealloc_dynamic_workgroup_mem(%arg0 : index) { // CHECK-LABEL: func @alloc_dynamic_size
// expected-error @+1 {{unhandled allocation type}} func.func @alloc_dynamic_size(%arg0 : index) -> f32 {
// CHECK: memref.alloc
%0 = memref.alloc(%arg0) : memref<4x?xf32, 3> %0 = memref.alloc(%arg0) : memref<4x?xf32, 3>
return %1 = memref.load %0[%arg0, %arg0] : memref<4x?xf32, 3>
return %1: f32
} }
} }
@ -114,10 +116,12 @@ module attributes {
#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}> #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}>
} }
{ {
func.func @alloc_dealloc_mem() { // CHECK-LABEL: func @alloc_unsupported_memory_space
// expected-error @+1 {{unhandled allocation type}} func.func @alloc_unsupported_memory_space(%arg0: index) -> f32 {
// CHECK: memref.alloc
%0 = memref.alloc() : memref<4x5xf32> %0 = memref.alloc() : memref<4x5xf32>
return %1 = memref.load %0[%arg0, %arg0] : memref<4x5xf32>
return %1: f32
} }
} }
@ -129,8 +133,9 @@ module attributes {
#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}> #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}>
} }
{ {
func.func @alloc_dealloc_dynamic_workgroup_mem(%arg0 : memref<4x?xf32, 3>) { // CHECK-LABEL: func @dealloc_dynamic_size
// expected-error @+1 {{unhandled deallocation type}} func.func @dealloc_dynamic_size(%arg0 : memref<4x?xf32, 3>) {
// CHECK: memref.dealloc
memref.dealloc %arg0 : memref<4x?xf32, 3> memref.dealloc %arg0 : memref<4x?xf32, 3>
return return
} }
@ -143,8 +148,9 @@ module attributes {
#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}> #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}>
} }
{ {
func.func @alloc_dealloc_mem(%arg0 : memref<4x5xf32>) { // CHECK-LABEL: func @dealloc_unsupported_memory_space
// expected-error @+1 {{unhandled deallocation type}} func.func @dealloc_unsupported_memory_space(%arg0 : memref<4x5xf32>) {
// CHECK: memref.dealloc
memref.dealloc %arg0 : memref<4x5xf32> memref.dealloc %arg0 : memref<4x5xf32>
return return
} }

View File

@ -0,0 +1,71 @@
// RUN: mlir-opt -split-input-file -convert-memref-to-spirv -canonicalize -verify-diagnostics %s -o - | FileCheck %s
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>} {
func.func @alloc_function_variable(%arg0 : index, %arg1 : index) {
%0 = memref.alloca() : memref<4x5xf32, 6>
%1 = memref.load %0[%arg0, %arg1] : memref<4x5xf32, 6>
memref.store %1, %0[%arg0, %arg1] : memref<4x5xf32, 6>
return
}
}
// CHECK-LABEL: func @alloc_function_variable
// CHECK: %[[VAR:.+]] = spv.Variable : !spv.ptr<!spv.struct<(!spv.array<20 x f32, stride=4>)>, Function>
// CHECK: %[[LOADPTR:.+]] = spv.AccessChain %[[VAR]]
// CHECK: %[[VAL:.+]] = spv.Load "Function" %[[LOADPTR]] : f32
// CHECK: %[[STOREPTR:.+]] = spv.AccessChain %[[VAR]]
// CHECK: spv.Store "Function" %[[STOREPTR]], %[[VAL]] : f32
// -----
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>} {
func.func @two_allocs() {
%0 = memref.alloca() : memref<4x5xf32, 6>
%1 = memref.alloca() : memref<2x3xi32, 6>
return
}
}
// CHECK-LABEL: func @two_allocs
// CHECK-DAG: spv.Variable : !spv.ptr<!spv.struct<(!spv.array<6 x i32, stride=4>)>, Function>
// CHECK-DAG: spv.Variable : !spv.ptr<!spv.struct<(!spv.array<20 x f32, stride=4>)>, Function>
// -----
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>} {
func.func @two_allocs_vector() {
%0 = memref.alloca() : memref<4xvector<4xf32>, 6>
%1 = memref.alloca() : memref<2xvector<2xi32>, 6>
return
}
}
// CHECK-LABEL: func @two_allocs_vector
// CHECK-DAG: spv.Variable : !spv.ptr<!spv.struct<(!spv.array<2 x vector<2xi32>, stride=8>)>, Function>
// CHECK-DAG: spv.Variable : !spv.ptr<!spv.struct<(!spv.array<4 x vector<4xf32>, stride=16>)>, Function>
// -----
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>} {
// CHECK-LABEL: func @alloc_dynamic_size
func.func @alloc_dynamic_size(%arg0 : index) -> f32 {
// CHECK: memref.alloca
%0 = memref.alloca(%arg0) : memref<4x?xf32, 6>
%1 = memref.load %0[%arg0, %arg0] : memref<4x?xf32, 6>
return %1: f32
}
}
// -----
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>} {
// CHECK-LABEL: func @alloc_unsupported_memory_space
func.func @alloc_unsupported_memory_space(%arg0: index) -> f32 {
// CHECK: memref.alloca
%0 = memref.alloca() : memref<4x5xf32>
%1 = memref.load %0[%arg0, %arg0] : memref<4x5xf32>
return %1: f32
}
}