forked from OSchip/llvm-project
[spirv] Implement inliner interface
We just need to implement a few interface hooks to DialectInlinerInterface and CallOpInterface to gain the benefits of an inliner. :) Right now only supports some trivial cases: * Inlining single block with spv.Return/spv.ReturnValue * Inlining multi block with spv.Return * Inlining spv.selection/spv.loop without return ops More advanced cases will require block argument and Phi support. PiperOrigin-RevId: 275151132
This commit is contained in:
parent
1ba9bb0507
commit
0e3efb32c6
|
@ -29,6 +29,11 @@
|
|||
include "mlir/SPIRV/SPIRVBase.td"
|
||||
#endif // SPIRV_BASE
|
||||
|
||||
#ifdef MLIR_CALLINTERFACES
|
||||
#else
|
||||
include "mlir/Analysis/CallInterfaces.td"
|
||||
#endif // MLIR_CALLINTERFACES
|
||||
|
||||
// -----
|
||||
|
||||
def SPV_BranchOp : SPV_Op<"Branch", [Terminator]> {
|
||||
|
@ -151,7 +156,8 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [Terminator]> {
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_FunctionCallOp : SPV_Op<"FunctionCall", [InFunctionScope]> {
|
||||
def SPV_FunctionCallOp : SPV_Op<"FunctionCall", [
|
||||
InFunctionScope, DeclareOpInterfaceMethods<CallOpInterface>]> {
|
||||
let summary = "Call a function.";
|
||||
|
||||
let description = [{
|
||||
|
|
|
@ -264,7 +264,8 @@ def SPV_GlobalVariableOp : SPV_Op<"globalVariable", [InModuleScope]> {
|
|||
}
|
||||
|
||||
def SPV_ModuleOp : SPV_Op<"module",
|
||||
[SingleBlockImplicitTerminator<"ModuleEndOp">,
|
||||
[IsolatedFromAbove,
|
||||
SingleBlockImplicitTerminator<"ModuleEndOp">,
|
||||
NativeOpTrait<"SymbolTable">]> {
|
||||
let summary = "The top-level op that defines a SPIR-V module";
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Parser.h"
|
||||
#include "mlir/Support/StringExtras.h"
|
||||
#include "mlir/Transforms/InliningUtils.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
|
@ -34,6 +35,67 @@ namespace spirv {
|
|||
using namespace mlir;
|
||||
using namespace mlir::spirv;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// InlinerInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Returns true if the given region contains spv.Return or spv.ReturnValue ops.
|
||||
static inline bool containsReturn(Region ®ion) {
|
||||
return llvm::any_of(region, [](Block &block) {
|
||||
Operation *terminator = block.getTerminator();
|
||||
return isa<spirv::ReturnOp>(terminator) ||
|
||||
isa<spirv::ReturnValueOp>(terminator);
|
||||
});
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// This class defines the interface for inlining within the SPIR-V dialect.
|
||||
struct SPIRVInlinerInterface : public DialectInlinerInterface {
|
||||
using DialectInlinerInterface::DialectInlinerInterface;
|
||||
|
||||
/// Returns true if the given region 'src' can be inlined into the region
|
||||
/// 'dest' that is attached to an operation registered to the current dialect.
|
||||
bool isLegalToInline(Operation *op, Region *dest,
|
||||
BlockAndValueMapping &) const final {
|
||||
// TODO(antiagainst): Enable inlining structured control flows with return.
|
||||
if ((isa<spirv::SelectionOp>(op) || isa<spirv::LoopOp>(op)) &&
|
||||
containsReturn(op->getRegion(0)))
|
||||
return false;
|
||||
// TODO(antiagainst): we need to filter OpKill here to avoid inlining it to
|
||||
// a loop continue construct:
|
||||
// https://github.com/KhronosGroup/SPIRV-Headers/issues/86
|
||||
// However OpKill is fragment shader specific and we don't support it yet.
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Handle the given inlined terminator by replacing it with a new operation
|
||||
/// as necessary.
|
||||
void handleTerminator(Operation *op, Block *newDest) const final {
|
||||
if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) {
|
||||
OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest);
|
||||
op->erase();
|
||||
} else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) {
|
||||
llvm_unreachable("unimplemented spv.ReturnValue in inliner");
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle the given inlined terminator by replacing it with a new operation
|
||||
/// as necessary.
|
||||
void handleTerminator(Operation *op,
|
||||
ArrayRef<Value *> valuesToRepl) const final {
|
||||
// Only spv.ReturnValue needs to be handled here.
|
||||
auto retValOp = dyn_cast<spirv::ReturnValueOp>(op);
|
||||
if (!retValOp)
|
||||
return;
|
||||
|
||||
// Replace the values directly with the return operands.
|
||||
assert(valuesToRepl.size() == 1 &&
|
||||
"spv.ReturnValue expected to only handle one result");
|
||||
valuesToRepl.front()->replaceAllUsesWith(retValOp.value());
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SPIR-V Dialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -48,6 +110,8 @@ SPIRVDialect::SPIRVDialect(MLIRContext *context)
|
|||
#include "mlir/Dialect/SPIRV/SPIRVOps.cpp.inc"
|
||||
>();
|
||||
|
||||
addInterfaces<SPIRVInlinerInterface>();
|
||||
|
||||
// Allow unknown operations because SPIR-V is extensible.
|
||||
allowUnknownOperations();
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
|
||||
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
||||
|
||||
#include "mlir/Analysis/CallInterfaces.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
|
@ -1199,6 +1200,14 @@ static LogicalResult verify(spirv::FunctionCallOp functionCallOp) {
|
|||
return success();
|
||||
}
|
||||
|
||||
CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() {
|
||||
return getAttrOfType<SymbolRefAttr>(kCallee);
|
||||
}
|
||||
|
||||
Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
|
||||
return arguments();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.globalVariable
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -0,0 +1,182 @@
|
|||
// RUN: mlir-opt %s -split-input-file -pass-pipeline='spv.module(inline)' -mlir-disable-inline-simplify | FileCheck %s
|
||||
|
||||
spv.module "Logical" "GLSL450" {
|
||||
func @callee() {
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @calling_single_block_ret_func
|
||||
func @calling_single_block_ret_func() {
|
||||
// CHECK-NEXT: spv.Return
|
||||
spv.FunctionCall @callee() : () -> ()
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.module "Logical" "GLSL450" {
|
||||
func @callee() -> i32 {
|
||||
%0 = spv.constant 42 : i32
|
||||
spv.ReturnValue %0 : i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @calling_single_block_retval_func
|
||||
func @calling_single_block_retval_func() -> i32 {
|
||||
// CHECK-NEXT: %[[CST:.*]] = spv.constant 42
|
||||
%0 = spv.FunctionCall @callee() : () -> (i32)
|
||||
// CHECK-NEXT: spv.ReturnValue %[[CST]]
|
||||
spv.ReturnValue %0 : i32
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.module "Logical" "GLSL450" {
|
||||
spv.globalVariable @data bind(0, 0) : !spv.ptr<!spv.struct<!spv.rtarray<i32> [0]>, StorageBuffer>
|
||||
func @callee() {
|
||||
%0 = spv._address_of @data : !spv.ptr<!spv.struct<!spv.rtarray<i32> [0]>, StorageBuffer>
|
||||
%1 = spv.constant 0: i32
|
||||
%2 = spv.AccessChain %0[%1, %1] : !spv.ptr<!spv.struct<!spv.rtarray<i32> [0]>, StorageBuffer>
|
||||
spv.Branch ^next
|
||||
|
||||
^next:
|
||||
%3 = spv.constant 42: i32
|
||||
spv.Store "StorageBuffer" %2, %3 : i32
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @calling_multi_block_ret_func
|
||||
func @calling_multi_block_ret_func() {
|
||||
// CHECK-NEXT: spv._address_of
|
||||
// CHECK-NEXT: spv.constant 0
|
||||
// CHECK-NEXT: spv.AccessChain
|
||||
// CHECK-NEXT: spv.Branch ^bb1
|
||||
// CHECK-NEXT: ^bb1:
|
||||
// CHECK-NEXT: spv.constant
|
||||
// CHECK-NEXT: spv.Store
|
||||
// CHECK-NEXT: spv.Branch ^bb2
|
||||
spv.FunctionCall @callee() : () -> ()
|
||||
// CHECK-NEXT: ^bb2:
|
||||
// CHECK-NEXT: spv.Return
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: calling_multi_block_retval_func
|
||||
|
||||
// -----
|
||||
|
||||
spv.module "Logical" "GLSL450" {
|
||||
func @callee(%cond : i1) -> () {
|
||||
spv.selection {
|
||||
spv.BranchConditional %cond, ^then, ^merge
|
||||
^then:
|
||||
spv.Return
|
||||
^merge:
|
||||
spv._merge
|
||||
}
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: calling_selection_ret_func
|
||||
func @calling_selection_ret_func() {
|
||||
%0 = spv.constant true
|
||||
// CHECK: spv.FunctionCall
|
||||
spv.FunctionCall @callee(%0) : (i1) -> ()
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.module "Logical" "GLSL450" {
|
||||
func @callee(%cond : i1) -> () {
|
||||
spv.selection {
|
||||
spv.BranchConditional %cond, ^then, ^merge
|
||||
^then:
|
||||
spv.Branch ^merge
|
||||
^merge:
|
||||
spv._merge
|
||||
}
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: calling_selection_no_ret_func
|
||||
func @calling_selection_no_ret_func() {
|
||||
// CHECK-NEXT: %[[TRUE:.*]] = spv.constant true
|
||||
%0 = spv.constant true
|
||||
// CHECK-NEXT: spv.selection
|
||||
// CHECK-NEXT: spv.BranchConditional %[[TRUE]], ^bb1, ^bb2
|
||||
// CHECK-NEXT: ^bb1:
|
||||
// CHECK-NEXT: spv.Branch ^bb2
|
||||
// CHECK-NEXT: ^bb2:
|
||||
// CHECK-NEXT: spv._merge
|
||||
spv.FunctionCall @callee(%0) : (i1) -> ()
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.module "Logical" "GLSL450" {
|
||||
func @callee(%cond : i1) -> () {
|
||||
spv.loop {
|
||||
spv.Branch ^header
|
||||
^header:
|
||||
spv.BranchConditional %cond, ^body, ^merge
|
||||
^body:
|
||||
spv.Return
|
||||
^continue:
|
||||
spv.Branch ^header
|
||||
^merge:
|
||||
spv._merge
|
||||
}
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: calling_loop_ret_func
|
||||
func @calling_loop_ret_func() {
|
||||
%0 = spv.constant true
|
||||
// CHECK: spv.FunctionCall
|
||||
spv.FunctionCall @callee(%0) : (i1) -> ()
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.module "Logical" "GLSL450" {
|
||||
func @callee(%cond : i1) -> () {
|
||||
spv.loop {
|
||||
spv.Branch ^header
|
||||
^header:
|
||||
spv.BranchConditional %cond, ^body, ^merge
|
||||
^body:
|
||||
spv.Branch ^continue
|
||||
^continue:
|
||||
spv.Branch ^header
|
||||
^merge:
|
||||
spv._merge
|
||||
}
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: calling_loop_no_ret_func
|
||||
func @calling_loop_no_ret_func() {
|
||||
// CHECK-NEXT: %[[TRUE:.*]] = spv.constant true
|
||||
%0 = spv.constant true
|
||||
// CHECK-NEXT: spv.loop
|
||||
// CHECK-NEXT: spv.Branch ^bb1
|
||||
// CHECK-NEXT: ^bb1:
|
||||
// CHECK-NEXT: spv.BranchConditional %[[TRUE]], ^bb2, ^bb4
|
||||
// CHECK-NEXT: ^bb2:
|
||||
// CHECK-NEXT: spv.Branch ^bb3
|
||||
// CHECK-NEXT: ^bb3:
|
||||
// CHECK-NEXT: spv.Branch ^bb1
|
||||
// CHECK-NEXT: ^bb4:
|
||||
// CHECK-NEXT: spv._merge
|
||||
spv.FunctionCall @callee(%0) : (i1) -> ()
|
||||
spv.Return
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue