forked from OSchip/llvm-project
[mlir][spirv] Convert memref.alloca to spv.Variable
Reviewed By: hanchung Differential Revision: https://reviews.llvm.org/D124542
This commit is contained in:
parent
72959f7714
commit
8854b73606
|
@ -12,6 +12,7 @@
|
|||
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
|
||||
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
@ -85,15 +86,27 @@ static Value shiftValue(Location loc, Value value, Value offset, Value mask,
|
|||
offset);
|
||||
}
|
||||
|
||||
/// Returns true if the allocations of type `t` can be lowered to SPIR-V.
|
||||
static bool isAllocationSupported(MemRefType t) {
|
||||
// Currently only support workgroup local memory allocations with static
|
||||
// shape and int or float or vector of int or float element type.
|
||||
if (!(t.hasStaticShape() &&
|
||||
SPIRVTypeConverter::getMemorySpaceForStorageClass(
|
||||
spirv::StorageClass::Workgroup) == t.getMemorySpaceAsInt()))
|
||||
/// Returns true if the allocations of memref `type` generated from `allocOp`
|
||||
/// can be lowered to SPIR-V.
|
||||
static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
|
||||
if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
|
||||
if (SPIRVTypeConverter::getMemorySpaceForStorageClass(
|
||||
spirv::StorageClass::Workgroup) != type.getMemorySpaceAsInt())
|
||||
return false;
|
||||
} else if (isa<memref::AllocaOp>(allocOp)) {
|
||||
if (SPIRVTypeConverter::getMemorySpaceForStorageClass(
|
||||
spirv::StorageClass::Function) != type.getMemorySpaceAsInt())
|
||||
return false;
|
||||
} else {
|
||||
return false;
|
||||
Type elementType = t.getElementType();
|
||||
}
|
||||
|
||||
// Currently only support static shape and int or float or vector of int or
|
||||
// float element type.
|
||||
if (!type.hasStaticShape())
|
||||
return false;
|
||||
|
||||
Type elementType = type.getElementType();
|
||||
if (auto vecType = elementType.dyn_cast<VectorType>())
|
||||
elementType = vecType.getElementType();
|
||||
return elementType.isIntOrFloat();
|
||||
|
@ -102,10 +115,10 @@ static bool isAllocationSupported(MemRefType t) {
|
|||
/// Returns the scope to use for atomic operations use for emulating store
|
||||
/// operations of unsupported integer bitwidths, based on the memref
|
||||
/// type. Returns None on failure.
|
||||
static Optional<spirv::Scope> getAtomicOpScope(MemRefType t) {
|
||||
static Optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
|
||||
Optional<spirv::StorageClass> storageClass =
|
||||
SPIRVTypeConverter::getStorageClassForMemorySpace(
|
||||
t.getMemorySpaceAsInt());
|
||||
type.getMemorySpaceAsInt());
|
||||
if (!storageClass)
|
||||
return {};
|
||||
switch (*storageClass) {
|
||||
|
@ -149,6 +162,16 @@ static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
|
|||
|
||||
namespace {
|
||||
|
||||
/// Converts memref.alloca to SPIR-V Function variables.
|
||||
class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
|
||||
public:
|
||||
using OpConversionPattern<memref::AllocaOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
/// Converts an allocation operation to SPIR-V. Currently only supports lowering
|
||||
/// to Workgroup memory when the size is constant. Note that this pattern needs
|
||||
/// to be applied in a pass that runs at least at spv.module scope since it wil
|
||||
|
@ -215,6 +238,25 @@ public:
|
|||
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AllocaOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult
|
||||
AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
MemRefType allocType = allocaOp.getType();
|
||||
if (!isAllocationSupported(allocaOp, allocType))
|
||||
return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");
|
||||
|
||||
// Get the SPIR-V type for the allocation.
|
||||
Type spirvType = getTypeConverter()->convertType(allocType);
|
||||
rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
|
||||
spirv::StorageClass::Function,
|
||||
/*initializer=*/nullptr);
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AllocOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -223,8 +265,8 @@ LogicalResult
|
|||
AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
MemRefType allocType = operation.getType();
|
||||
if (!isAllocationSupported(allocType))
|
||||
return operation.emitError("unhandled allocation type");
|
||||
if (!isAllocationSupported(operation, allocType))
|
||||
return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
|
||||
|
||||
// Get the SPIR-V type for the allocation.
|
||||
Type spirvType = getTypeConverter()->convertType(allocType);
|
||||
|
@ -262,8 +304,8 @@ DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
|
|||
OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
MemRefType deallocType = operation.memref().getType().cast<MemRefType>();
|
||||
if (!isAllocationSupported(deallocType))
|
||||
return operation.emitError("unhandled deallocation type");
|
||||
if (!isAllocationSupported(operation, deallocType))
|
||||
return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
|
||||
rewriter.eraseOp(operation);
|
||||
return success();
|
||||
}
|
||||
|
@ -505,8 +547,9 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
|
|||
namespace mlir {
|
||||
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<AllocOpPattern, DeallocOpPattern, IntLoadOpPattern,
|
||||
IntStoreOpPattern, LoadOpPattern, StoreOpPattern>(
|
||||
typeConverter, patterns.getContext());
|
||||
patterns
|
||||
.add<AllocaOpPattern, AllocOpPattern, DeallocOpPattern, IntLoadOpPattern,
|
||||
IntStoreOpPattern, LoadOpPattern, StoreOpPattern>(
|
||||
typeConverter, patterns.getContext());
|
||||
}
|
||||
} // namespace mlir
|
||||
|
|
|
@ -100,10 +100,12 @@ module attributes {
|
|||
#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}>
|
||||
}
|
||||
{
|
||||
func.func @alloc_dealloc_dynamic_workgroup_mem(%arg0 : index) {
|
||||
// expected-error @+1 {{unhandled allocation type}}
|
||||
// CHECK-LABEL: func @alloc_dynamic_size
|
||||
func.func @alloc_dynamic_size(%arg0 : index) -> f32 {
|
||||
// CHECK: memref.alloc
|
||||
%0 = memref.alloc(%arg0) : memref<4x?xf32, 3>
|
||||
return
|
||||
%1 = memref.load %0[%arg0, %arg0] : memref<4x?xf32, 3>
|
||||
return %1: f32
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -114,10 +116,12 @@ module attributes {
|
|||
#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}>
|
||||
}
|
||||
{
|
||||
func.func @alloc_dealloc_mem() {
|
||||
// expected-error @+1 {{unhandled allocation type}}
|
||||
// CHECK-LABEL: func @alloc_unsupported_memory_space
|
||||
func.func @alloc_unsupported_memory_space(%arg0: index) -> f32 {
|
||||
// CHECK: memref.alloc
|
||||
%0 = memref.alloc() : memref<4x5xf32>
|
||||
return
|
||||
%1 = memref.load %0[%arg0, %arg0] : memref<4x5xf32>
|
||||
return %1: f32
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -129,8 +133,9 @@ module attributes {
|
|||
#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}>
|
||||
}
|
||||
{
|
||||
func.func @alloc_dealloc_dynamic_workgroup_mem(%arg0 : memref<4x?xf32, 3>) {
|
||||
// expected-error @+1 {{unhandled deallocation type}}
|
||||
// CHECK-LABEL: func @dealloc_dynamic_size
|
||||
func.func @dealloc_dynamic_size(%arg0 : memref<4x?xf32, 3>) {
|
||||
// CHECK: memref.dealloc
|
||||
memref.dealloc %arg0 : memref<4x?xf32, 3>
|
||||
return
|
||||
}
|
||||
|
@ -143,8 +148,9 @@ module attributes {
|
|||
#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}>
|
||||
}
|
||||
{
|
||||
func.func @alloc_dealloc_mem(%arg0 : memref<4x5xf32>) {
|
||||
// expected-error @+1 {{unhandled deallocation type}}
|
||||
// CHECK-LABEL: func @dealloc_unsupported_memory_space
|
||||
func.func @dealloc_unsupported_memory_space(%arg0 : memref<4x5xf32>) {
|
||||
// CHECK: memref.dealloc
|
||||
memref.dealloc %arg0 : memref<4x5xf32>
|
||||
return
|
||||
}
|
||||
|
|
|
@ -0,0 +1,71 @@
|
|||
// RUN: mlir-opt -split-input-file -convert-memref-to-spirv -canonicalize -verify-diagnostics %s -o - | FileCheck %s
|
||||
|
||||
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>} {
|
||||
func.func @alloc_function_variable(%arg0 : index, %arg1 : index) {
|
||||
%0 = memref.alloca() : memref<4x5xf32, 6>
|
||||
%1 = memref.load %0[%arg0, %arg1] : memref<4x5xf32, 6>
|
||||
memref.store %1, %0[%arg0, %arg1] : memref<4x5xf32, 6>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @alloc_function_variable
|
||||
// CHECK: %[[VAR:.+]] = spv.Variable : !spv.ptr<!spv.struct<(!spv.array<20 x f32, stride=4>)>, Function>
|
||||
// CHECK: %[[LOADPTR:.+]] = spv.AccessChain %[[VAR]]
|
||||
// CHECK: %[[VAL:.+]] = spv.Load "Function" %[[LOADPTR]] : f32
|
||||
// CHECK: %[[STOREPTR:.+]] = spv.AccessChain %[[VAR]]
|
||||
// CHECK: spv.Store "Function" %[[STOREPTR]], %[[VAL]] : f32
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>} {
|
||||
func.func @two_allocs() {
|
||||
%0 = memref.alloca() : memref<4x5xf32, 6>
|
||||
%1 = memref.alloca() : memref<2x3xi32, 6>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @two_allocs
|
||||
// CHECK-DAG: spv.Variable : !spv.ptr<!spv.struct<(!spv.array<6 x i32, stride=4>)>, Function>
|
||||
// CHECK-DAG: spv.Variable : !spv.ptr<!spv.struct<(!spv.array<20 x f32, stride=4>)>, Function>
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>} {
|
||||
func.func @two_allocs_vector() {
|
||||
%0 = memref.alloca() : memref<4xvector<4xf32>, 6>
|
||||
%1 = memref.alloca() : memref<2xvector<2xi32>, 6>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @two_allocs_vector
|
||||
// CHECK-DAG: spv.Variable : !spv.ptr<!spv.struct<(!spv.array<2 x vector<2xi32>, stride=8>)>, Function>
|
||||
// CHECK-DAG: spv.Variable : !spv.ptr<!spv.struct<(!spv.array<4 x vector<4xf32>, stride=16>)>, Function>
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>} {
|
||||
// CHECK-LABEL: func @alloc_dynamic_size
|
||||
func.func @alloc_dynamic_size(%arg0 : index) -> f32 {
|
||||
// CHECK: memref.alloca
|
||||
%0 = memref.alloca(%arg0) : memref<4x?xf32, 6>
|
||||
%1 = memref.load %0[%arg0, %arg0] : memref<4x?xf32, 6>
|
||||
return %1: f32
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>} {
|
||||
// CHECK-LABEL: func @alloc_unsupported_memory_space
|
||||
func.func @alloc_unsupported_memory_space(%arg0: index) -> f32 {
|
||||
// CHECK: memref.alloca
|
||||
%0 = memref.alloca() : memref<4x5xf32>
|
||||
%1 = memref.load %0[%arg0, %arg0] : memref<4x5xf32>
|
||||
return %1: f32
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue