[MLIR][SPIRVToLLVM] Additional conversions for spirv-runner

This patch adds more op/type conversion support
necessary for `spirv-runner`:
- EntryPoint/ExecutionMode: currently removed since we assume
having only one kernel function in the kernel module.
- StorageBuffer storage class is now supported. We are not
concerned with multithreading so this is fine for now.
- Type conversion enhanced, now regular offsets and strides
for structs and arrays are supported (based on
`VulkanLayoutUtils`).
- Support of `spc.AccessChain` that is modelled with GEP op
in LLVM dialect.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D86109
This commit is contained in:
George Mitenkov 2020-08-18 18:42:23 +03:00
parent bb54bcf849
commit cc98a0fbe4
5 changed files with 141 additions and 37 deletions

View File

@ -14,6 +14,7 @@
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SPIRV/LayoutUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@ -179,6 +180,22 @@ static Value processCountOrOffset(Location loc, Value value, Type srcType,
return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
}
/// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
/// offset to LLVM struct. Otherwise, the conversion is not supported.
static Optional<Type>
convertStructTypeWithOffset(spirv::StructType type,
LLVMTypeConverter &converter) {
if (type != VulkanLayoutUtils::decorateType(type))
return llvm::None;
auto elementsVector = llvm::to_vector<8>(
llvm::map_range(type.getElementTypes(), [&](Type elementType) {
return converter.convertType(elementType).cast<LLVM::LLVMType>();
}));
return LLVM::LLVMType::getStructTy(type.getContext(), elementsVector,
/*isPacked=*/false);
}
/// Converts SPIR-V struct with no offset to packed LLVM struct.
static Type convertStructTypePacked(spirv::StructType type,
LLVMTypeConverter &converter) {
@ -223,16 +240,22 @@ static LogicalResult replaceWithLoadOrStore(Operation *op,
// Type conversion
//===----------------------------------------------------------------------===//
/// Converts SPIR-V array type to LLVM array. There is no modelling of array
/// stride at the moment.
/// Converts SPIR-V array type to LLVM array. Natural stride (according to
/// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
/// when converting ops that manipulate array types.
static Optional<Type> convertArrayType(spirv::ArrayType type,
TypeConverter &converter) {
if (type.getArrayStride() != 0)
unsigned stride = type.getArrayStride();
Type elementType = type.getElementType();
auto sizeInBytes = elementType.cast<spirv::SPIRVType>().getSizeInBytes();
if (stride != 0 &&
!(sizeInBytes.hasValue() && sizeInBytes.getValue() == stride))
return llvm::None;
auto elementType =
converter.convertType(type.getElementType()).cast<LLVM::LLVMType>();
auto llvmElementType =
converter.convertType(elementType).cast<LLVM::LLVMType>();
unsigned numElements = type.getNumElements();
return LLVM::LLVMType::getArrayTy(elementType, numElements);
return LLVM::LLVMType::getArrayTy(llvmElementType, numElements);
}
/// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
@ -257,13 +280,15 @@ static Optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
}
/// Converts SPIR-V struct to LLVM struct. There is no support of structs with
/// member decorations or with offset.
/// member decorations. Also, only natural offset is supported.
static Optional<Type> convertStructType(spirv::StructType type,
LLVMTypeConverter &converter) {
SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
type.getMemberDecorations(memberDecorations);
if (type.hasOffset() || !memberDecorations.empty())
if (!memberDecorations.empty())
return llvm::None;
if (type.hasOffset())
return convertStructTypeWithOffset(type, converter);
return convertStructTypePacked(type, converter);
}
@ -273,6 +298,31 @@ static Optional<Type> convertStructType(spirv::StructType type,
namespace {
class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
public:
using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion;
LogicalResult
matchAndRewrite(spirv::AccessChainOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto dstType = typeConverter.convertType(op.component_ptr().getType());
if (!dstType)
return failure();
// To use GEP we need to add a first 0 index to go through the pointer.
auto indices = llvm::to_vector<4>(op.indices());
Type indexType = op.indices().front().getType();
auto llvmIndexType = typeConverter.convertType(indexType);
if (!llvmIndexType)
return failure();
Value zero = rewriter.create<LLVM::ConstantOp>(
op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
indices.insert(indices.begin(), zero);
rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, op.base_ptr(),
indices);
return success();
}
};
class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
public:
using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion;
@ -545,11 +595,14 @@ public:
if (!dstType)
return failure();
// Limit conversion to the current invocation only for now.
// Limit conversion to the current invocation only or `StorageBuffer`
// required by SPIR-V runner.
// This is okay because multiple invocations are not supported yet.
auto storageClass = srcType.getStorageClass();
if (storageClass != spirv::StorageClass::Input &&
storageClass != spirv::StorageClass::Private &&
storageClass != spirv::StorageClass::Output) {
storageClass != spirv::StorageClass::Output &&
storageClass != spirv::StorageClass::StorageBuffer) {
return failure();
}
@ -757,6 +810,20 @@ public:
}
};
/// A template pattern that erases the given `SPIRVOp`.
template <typename SPIRVOp>
class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
public:
using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
LogicalResult
matchAndRewrite(SPIRVOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.eraseOp(op);
return success();
}
};
class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
public:
using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion;
@ -875,18 +942,6 @@ public:
}
};
class MergePattern : public SPIRVToLLVMConversion<spirv::MergeOp> {
public:
using SPIRVToLLVMConversion<spirv::MergeOp>::SPIRVToLLVMConversion;
LogicalResult
matchAndRewrite(spirv::MergeOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.eraseOp(op);
return success();
}
};
/// Converts `spv.selection` with `spv.BranchConditional` in its header block.
/// All blocks within selection should be reachable for conversion to succeed.
class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
@ -1266,11 +1321,18 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
ConstantScalarAndVectorPattern,
// Control Flow ops
BranchConversionPattern, BranchConditionalConversionPattern, LoopPattern,
SelectionPattern, MergePattern,
BranchConversionPattern, BranchConditionalConversionPattern,
FunctionCallPattern, LoopPattern, SelectionPattern,
ErasePattern<spirv::MergeOp>,
// Entry points and execution mode
// Module generated from SPIR-V could have other "internal" functions, so
// having entry point and execution mode metadat can be useful. For now,
// simply remove them.
// TODO: Support EntryPoint/ExecutionMode properly.
ErasePattern<spirv::EntryPointOp>, ErasePattern<spirv::ExecutionModeOp>,
// Function Call op
FunctionCallPattern,
// GLSL extended instruction set ops
DirectConversionPattern<spirv::GLSLCeilOp, LLVM::FCeilOp>,
@ -1295,8 +1357,9 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
NotPattern<spirv::LogicalNotOp>,
// Memory ops
AddressOfPattern, GlobalVariablePattern, LoadStorePattern<spirv::LoadOp>,
LoadStorePattern<spirv::StoreOp>, VariablePattern,
AccessChainPattern, AddressOfPattern, GlobalVariablePattern,
LoadStorePattern<spirv::LoadOp>, LoadStorePattern<spirv::StoreOp>,
VariablePattern,
// Miscellaneous ops
DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,

View File

@ -1,5 +1,30 @@
// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
//===----------------------------------------------------------------------===//
// spv.AccessChain
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @access_chain
func @access_chain() -> () {
// CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
%0 = spv.constant 1: i32
%1 = spv.Variable : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
// CHECK: llvm.getelementptr %{{.*}}[%[[ZERO]], %[[ONE]], %[[ONE]]] : (!llvm.ptr<struct<packed (float, array<4 x float>)>>, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm.ptr<float>
%2 = spv.AccessChain %1[%0, %0] : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>, i32, i32
return
}
// CHECK-LABEL: @access_chain_array
func @access_chain_array(%arg0 : i32) -> () {
%0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
// CHECK: llvm.getelementptr %{{.*}}[%[[ZERO]], %{{.*}}] : (!llvm.ptr<array<4 x array<4 x float>>>, !llvm.i32, !llvm.i32) -> !llvm.ptr<array<4 x float>>
%1 = spv.AccessChain %0[%arg0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>, i32
%2 = spv.Load "Function" %1 ["Volatile"] : !spv.array<4xf32>
return
}
//===----------------------------------------------------------------------===//
// spv.globalVariable and spv._address_of
//===----------------------------------------------------------------------===//

View File

@ -20,6 +20,23 @@ func @select_vector(%arg0: vector<2xi1>, %arg1: vector<2xi32>) {
return
}
//===----------------------------------------------------------------------===//
// spv.EntryPoint and spv.ExecutionMode
//===----------------------------------------------------------------------===//
// CHECK: module {
// CHECK-NEXT: llvm.func @empty
// CHECK-NEXT: llvm.return
// CHECK-NEXT: }
// CHECK-NEXT: }
spv.module Logical GLSL450 {
spv.func @empty() -> () "None" {
spv.Return
}
spv.EntryPoint "GLCompute" @empty
spv.ExecutionMode @empty "LocalSize", 1, 1, 1
}
//===----------------------------------------------------------------------===//
// spv.Undef
//===----------------------------------------------------------------------===//

View File

@ -1,21 +1,14 @@
// RUN: mlir-opt %s -convert-spirv-to-llvm -verify-diagnostics -split-input-file
// expected-error@+1 {{failed to legalize operation 'spv.func' that was explicitly marked illegal}}
spv.func @array_with_stride(%arg: !spv.array<4 x f32, stride=4>) -> () "None" {
spv.func @array_with_unnatural_stride(%arg: !spv.array<4 x f32, stride=8>) -> () "None" {
spv.Return
}
// -----
// expected-error@+1 {{failed to legalize operation 'spv.func' that was explicitly marked illegal}}
spv.func @struct_with_offset1(%arg: !spv.struct<i32[0], i32[4]>) -> () "None" {
spv.Return
}
// -----
// expected-error@+1 {{failed to legalize operation 'spv.func' that was explicitly marked illegal}}
spv.func @struct_with_offset2(%arg: !spv.struct<i32[0], i32[8]>) -> () "None" {
spv.func @struct_with_unnatural_offset(%arg: !spv.struct<i32[0], i32[8]>) -> () "None" {
spv.Return
}

View File

@ -5,7 +5,10 @@
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @array(!llvm.array<16 x float>, !llvm.array<32 x vec<4 x float>>)
func @array(!spv.array<16xf32>, !spv.array< 32 x vector<4xf32> >) -> ()
func @array(!spv.array<16 x f32>, !spv.array< 32 x vector<4xf32> >) -> ()
// CHECK-LABEL: @array_with_natural_stride(!llvm.array<16 x float>)
func @array_with_natural_stride(!spv.array<16 x f32, stride=4>) -> ()
//===----------------------------------------------------------------------===//
// Pointer type
@ -36,3 +39,6 @@ func @struct(!spv.struct<f64>) -> ()
// CHECK-LABEL: @struct_nested(!llvm.struct<packed (i32, struct<packed (i64, i32)>)>)
func @struct_nested(!spv.struct<i32, !spv.struct<i64, i32>>)
// CHECK-LABEL: @struct_with_natural_offset(!llvm.struct<(i8, i32)>)
func @struct_with_natural_offset(!spv.struct<i8[0], i32[4]>) -> ()