forked from OSchip/llvm-project
[MLIR][SPIRVToLLVM] Additional conversions for spirv-runner
This patch adds more op/type conversion support necessary for `spirv-runner`: - EntryPoint/ExecutionMode: currently removed since we assume having only one kernel function in the kernel module. - StorageBuffer storage class is now supported. We are not concerned with multithreading so this is fine for now. - Type conversion enhanced, now regular offsets and strides for structs and arrays are supported (based on `VulkanLayoutUtils`). - Support of `spc.AccessChain` that is modelled with GEP op in LLVM dialect. Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D86109
This commit is contained in:
parent
bb54bcf849
commit
cc98a0fbe4
|
@ -14,6 +14,7 @@
|
|||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/SPIRV/LayoutUtils.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
|
@ -179,6 +180,22 @@ static Value processCountOrOffset(Location loc, Value value, Type srcType,
|
|||
return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
|
||||
}
|
||||
|
||||
/// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
|
||||
/// offset to LLVM struct. Otherwise, the conversion is not supported.
|
||||
static Optional<Type>
|
||||
convertStructTypeWithOffset(spirv::StructType type,
|
||||
LLVMTypeConverter &converter) {
|
||||
if (type != VulkanLayoutUtils::decorateType(type))
|
||||
return llvm::None;
|
||||
|
||||
auto elementsVector = llvm::to_vector<8>(
|
||||
llvm::map_range(type.getElementTypes(), [&](Type elementType) {
|
||||
return converter.convertType(elementType).cast<LLVM::LLVMType>();
|
||||
}));
|
||||
return LLVM::LLVMType::getStructTy(type.getContext(), elementsVector,
|
||||
/*isPacked=*/false);
|
||||
}
|
||||
|
||||
/// Converts SPIR-V struct with no offset to packed LLVM struct.
|
||||
static Type convertStructTypePacked(spirv::StructType type,
|
||||
LLVMTypeConverter &converter) {
|
||||
|
@ -223,16 +240,22 @@ static LogicalResult replaceWithLoadOrStore(Operation *op,
|
|||
// Type conversion
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Converts SPIR-V array type to LLVM array. There is no modelling of array
|
||||
/// stride at the moment.
|
||||
/// Converts SPIR-V array type to LLVM array. Natural stride (according to
|
||||
/// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
|
||||
/// when converting ops that manipulate array types.
|
||||
static Optional<Type> convertArrayType(spirv::ArrayType type,
|
||||
TypeConverter &converter) {
|
||||
if (type.getArrayStride() != 0)
|
||||
unsigned stride = type.getArrayStride();
|
||||
Type elementType = type.getElementType();
|
||||
auto sizeInBytes = elementType.cast<spirv::SPIRVType>().getSizeInBytes();
|
||||
if (stride != 0 &&
|
||||
!(sizeInBytes.hasValue() && sizeInBytes.getValue() == stride))
|
||||
return llvm::None;
|
||||
auto elementType =
|
||||
converter.convertType(type.getElementType()).cast<LLVM::LLVMType>();
|
||||
|
||||
auto llvmElementType =
|
||||
converter.convertType(elementType).cast<LLVM::LLVMType>();
|
||||
unsigned numElements = type.getNumElements();
|
||||
return LLVM::LLVMType::getArrayTy(elementType, numElements);
|
||||
return LLVM::LLVMType::getArrayTy(llvmElementType, numElements);
|
||||
}
|
||||
|
||||
/// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
|
||||
|
@ -257,13 +280,15 @@ static Optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
|
|||
}
|
||||
|
||||
/// Converts SPIR-V struct to LLVM struct. There is no support of structs with
|
||||
/// member decorations or with offset.
|
||||
/// member decorations. Also, only natural offset is supported.
|
||||
static Optional<Type> convertStructType(spirv::StructType type,
|
||||
LLVMTypeConverter &converter) {
|
||||
SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
|
||||
type.getMemberDecorations(memberDecorations);
|
||||
if (type.hasOffset() || !memberDecorations.empty())
|
||||
if (!memberDecorations.empty())
|
||||
return llvm::None;
|
||||
if (type.hasOffset())
|
||||
return convertStructTypeWithOffset(type, converter);
|
||||
return convertStructTypePacked(type, converter);
|
||||
}
|
||||
|
||||
|
@ -273,6 +298,31 @@ static Optional<Type> convertStructType(spirv::StructType type,
|
|||
|
||||
namespace {
|
||||
|
||||
class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
|
||||
public:
|
||||
using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(spirv::AccessChainOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto dstType = typeConverter.convertType(op.component_ptr().getType());
|
||||
if (!dstType)
|
||||
return failure();
|
||||
// To use GEP we need to add a first 0 index to go through the pointer.
|
||||
auto indices = llvm::to_vector<4>(op.indices());
|
||||
Type indexType = op.indices().front().getType();
|
||||
auto llvmIndexType = typeConverter.convertType(indexType);
|
||||
if (!llvmIndexType)
|
||||
return failure();
|
||||
Value zero = rewriter.create<LLVM::ConstantOp>(
|
||||
op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
|
||||
indices.insert(indices.begin(), zero);
|
||||
rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, op.base_ptr(),
|
||||
indices);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
|
||||
public:
|
||||
using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion;
|
||||
|
@ -545,11 +595,14 @@ public:
|
|||
if (!dstType)
|
||||
return failure();
|
||||
|
||||
// Limit conversion to the current invocation only for now.
|
||||
// Limit conversion to the current invocation only or `StorageBuffer`
|
||||
// required by SPIR-V runner.
|
||||
// This is okay because multiple invocations are not supported yet.
|
||||
auto storageClass = srcType.getStorageClass();
|
||||
if (storageClass != spirv::StorageClass::Input &&
|
||||
storageClass != spirv::StorageClass::Private &&
|
||||
storageClass != spirv::StorageClass::Output) {
|
||||
storageClass != spirv::StorageClass::Output &&
|
||||
storageClass != spirv::StorageClass::StorageBuffer) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
|
@ -757,6 +810,20 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
/// A template pattern that erases the given `SPIRVOp`.
|
||||
template <typename SPIRVOp>
|
||||
class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
|
||||
public:
|
||||
using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(SPIRVOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
|
||||
public:
|
||||
using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion;
|
||||
|
@ -875,18 +942,6 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
class MergePattern : public SPIRVToLLVMConversion<spirv::MergeOp> {
|
||||
public:
|
||||
using SPIRVToLLVMConversion<spirv::MergeOp>::SPIRVToLLVMConversion;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(spirv::MergeOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Converts `spv.selection` with `spv.BranchConditional` in its header block.
|
||||
/// All blocks within selection should be reachable for conversion to succeed.
|
||||
class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
|
||||
|
@ -1266,11 +1321,18 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
|
|||
ConstantScalarAndVectorPattern,
|
||||
|
||||
// Control Flow ops
|
||||
BranchConversionPattern, BranchConditionalConversionPattern, LoopPattern,
|
||||
SelectionPattern, MergePattern,
|
||||
BranchConversionPattern, BranchConditionalConversionPattern,
|
||||
FunctionCallPattern, LoopPattern, SelectionPattern,
|
||||
ErasePattern<spirv::MergeOp>,
|
||||
|
||||
// Entry points and execution mode
|
||||
// Module generated from SPIR-V could have other "internal" functions, so
|
||||
// having entry point and execution mode metadat can be useful. For now,
|
||||
// simply remove them.
|
||||
// TODO: Support EntryPoint/ExecutionMode properly.
|
||||
ErasePattern<spirv::EntryPointOp>, ErasePattern<spirv::ExecutionModeOp>,
|
||||
|
||||
// Function Call op
|
||||
FunctionCallPattern,
|
||||
|
||||
// GLSL extended instruction set ops
|
||||
DirectConversionPattern<spirv::GLSLCeilOp, LLVM::FCeilOp>,
|
||||
|
@ -1295,8 +1357,9 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
|
|||
NotPattern<spirv::LogicalNotOp>,
|
||||
|
||||
// Memory ops
|
||||
AddressOfPattern, GlobalVariablePattern, LoadStorePattern<spirv::LoadOp>,
|
||||
LoadStorePattern<spirv::StoreOp>, VariablePattern,
|
||||
AccessChainPattern, AddressOfPattern, GlobalVariablePattern,
|
||||
LoadStorePattern<spirv::LoadOp>, LoadStorePattern<spirv::StoreOp>,
|
||||
VariablePattern,
|
||||
|
||||
// Miscellaneous ops
|
||||
DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
|
||||
|
|
|
@ -1,5 +1,30 @@
|
|||
// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.AccessChain
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: @access_chain
|
||||
func @access_chain() -> () {
|
||||
// CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
|
||||
%0 = spv.constant 1: i32
|
||||
%1 = spv.Variable : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
|
||||
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
|
||||
// CHECK: llvm.getelementptr %{{.*}}[%[[ZERO]], %[[ONE]], %[[ONE]]] : (!llvm.ptr<struct<packed (float, array<4 x float>)>>, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm.ptr<float>
|
||||
%2 = spv.AccessChain %1[%0, %0] : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>, i32, i32
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @access_chain_array
|
||||
func @access_chain_array(%arg0 : i32) -> () {
|
||||
%0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
|
||||
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
|
||||
// CHECK: llvm.getelementptr %{{.*}}[%[[ZERO]], %{{.*}}] : (!llvm.ptr<array<4 x array<4 x float>>>, !llvm.i32, !llvm.i32) -> !llvm.ptr<array<4 x float>>
|
||||
%1 = spv.AccessChain %0[%arg0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>, i32
|
||||
%2 = spv.Load "Function" %1 ["Volatile"] : !spv.array<4xf32>
|
||||
return
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.globalVariable and spv._address_of
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -20,6 +20,23 @@ func @select_vector(%arg0: vector<2xi1>, %arg1: vector<2xi32>) {
|
|||
return
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.EntryPoint and spv.ExecutionMode
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: llvm.func @empty
|
||||
// CHECK-NEXT: llvm.return
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
spv.module Logical GLSL450 {
|
||||
spv.func @empty() -> () "None" {
|
||||
spv.Return
|
||||
}
|
||||
spv.EntryPoint "GLCompute" @empty
|
||||
spv.ExecutionMode @empty "LocalSize", 1, 1, 1
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.Undef
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1,21 +1,14 @@
|
|||
// RUN: mlir-opt %s -convert-spirv-to-llvm -verify-diagnostics -split-input-file
|
||||
|
||||
// expected-error@+1 {{failed to legalize operation 'spv.func' that was explicitly marked illegal}}
|
||||
spv.func @array_with_stride(%arg: !spv.array<4 x f32, stride=4>) -> () "None" {
|
||||
spv.func @array_with_unnatural_stride(%arg: !spv.array<4 x f32, stride=8>) -> () "None" {
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@+1 {{failed to legalize operation 'spv.func' that was explicitly marked illegal}}
|
||||
spv.func @struct_with_offset1(%arg: !spv.struct<i32[0], i32[4]>) -> () "None" {
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@+1 {{failed to legalize operation 'spv.func' that was explicitly marked illegal}}
|
||||
spv.func @struct_with_offset2(%arg: !spv.struct<i32[0], i32[8]>) -> () "None" {
|
||||
spv.func @struct_with_unnatural_offset(%arg: !spv.struct<i32[0], i32[8]>) -> () "None" {
|
||||
spv.Return
|
||||
}
|
||||
|
||||
|
|
|
@ -5,7 +5,10 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: @array(!llvm.array<16 x float>, !llvm.array<32 x vec<4 x float>>)
|
||||
func @array(!spv.array<16xf32>, !spv.array< 32 x vector<4xf32> >) -> ()
|
||||
func @array(!spv.array<16 x f32>, !spv.array< 32 x vector<4xf32> >) -> ()
|
||||
|
||||
// CHECK-LABEL: @array_with_natural_stride(!llvm.array<16 x float>)
|
||||
func @array_with_natural_stride(!spv.array<16 x f32, stride=4>) -> ()
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pointer type
|
||||
|
@ -36,3 +39,6 @@ func @struct(!spv.struct<f64>) -> ()
|
|||
|
||||
// CHECK-LABEL: @struct_nested(!llvm.struct<packed (i32, struct<packed (i64, i32)>)>)
|
||||
func @struct_nested(!spv.struct<i32, !spv.struct<i64, i32>>)
|
||||
|
||||
// CHECK-LABEL: @struct_with_natural_offset(!llvm.struct<(i8, i32)>)
|
||||
func @struct_with_natural_offset(!spv.struct<i8[0], i32[4]>) -> ()
|
||||
|
|
Loading…
Reference in New Issue