From 60d541e1b9dc7217a0744ede6a582c46795091fc Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Sat, 25 Jan 2020 09:16:29 -0500 Subject: [PATCH] [mlir][spirv] Relax verification to allow flexible placement Thus far certain SPIR-V ops have been required to be in spv.module. While this provides strong verification to catch unexpected errors, it's quite rigid and makes progressive lowering difficult. Sometimes we would like to partially lower ops from other dialects, which may involve creating ops like global variables that should be placed in other module-like ops. So this commit relaxes the requirement of such SPIR-V ops' scope to module-like ops. Similarly for function- like ops. Differential Revision: https://reviews.llvm.org/D73415 --- mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td | 12 ++--- mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 45 ++++++++++++------- mlir/test/Dialect/SPIRV/control-flow-ops.mlir | 35 ++++++++++----- mlir/test/Dialect/SPIRV/structure-ops.mlir | 43 ++++++++++++++++-- 4 files changed, 98 insertions(+), 37 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index 3026350fa75d..21d8e658b52e 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -2997,15 +2997,15 @@ def SPV_SelectType : AnyTypeOf<[SPV_Scalar, SPV_Vector, SPV_AnyPtr]>; // SPIR-V OpTrait definitions //===----------------------------------------------------------------------===// -// Check that an op can only be used within the scope of a FuncOp. +// Check that an op can only be used within the scope of a function-like op. def InFunctionScope : PredOpTrait< - "op must appear in a 'func' block", - CPred<"($_op.getParentOfType())">>; + "op must appear in a function-like op's block", + CPred<"isNestedInFunctionLikeOp($_op.getParentOp())">>; -// Check that an op can only be used within the scope of a SPIR-V ModuleOp. +// Check that an op can only be used within the scope of a module-like op. def InModuleScope : PredOpTrait< - "op must appear in a 'spv.module' block", - CPred<"llvm::isa_and_nonnull($_op.getParentOp())">>; + "op must appear in a module-like op's block", + CPred<"isDirectInModuleLikeOp($_op.getParentOp())">>; //===----------------------------------------------------------------------===// // SPIR-V opcode specification diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 5e2fede5f344..7854328ceeb2 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -55,6 +55,24 @@ static constexpr const char kVariableAttrName[] = "variable"; // Common utility functions //===----------------------------------------------------------------------===// +/// Returns true if the given op is a function-like op or nested in a +/// function-like op without a module-like op in the middle. +static bool isNestedInFunctionLikeOp(Operation *op) { + if (!op) + return false; + if (op->hasTrait()) + return false; + if (op->hasTrait()) + return true; + return isNestedInFunctionLikeOp(op->getParentOp()); +} + +/// Returns true if the given op is an module-like op that maintains a symbol +/// table. +static bool isDirectInModuleLikeOp(Operation *op) { + return op && op->hasTrait(); +} + static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value) { auto constOp = dyn_cast_or_null(op); if (!constOp) { @@ -872,9 +890,9 @@ static void print(spirv::AddressOfOp addressOfOp, OpAsmPrinter &printer) { } static LogicalResult verify(spirv::AddressOfOp addressOfOp) { - auto moduleOp = addressOfOp.getParentOfType(); - auto varOp = - moduleOp.lookupSymbol(addressOfOp.variable()); + auto varOp = dyn_cast_or_null( + SymbolTable::lookupNearestSymbolFrom(addressOfOp.getParentOp(), + addressOfOp.variable())); if (!varOp) { return addressOfOp.emitOpError("expected spv.globalVariable symbol"); } @@ -1679,16 +1697,11 @@ static void print(spirv::FunctionCallOp functionCallOp, OpAsmPrinter &printer) { static LogicalResult verify(spirv::FunctionCallOp functionCallOp) { auto fnName = functionCallOp.callee(); - auto moduleOp = functionCallOp.getParentOfType(); - if (!moduleOp) { - return functionCallOp.emitOpError( - "must appear in a function inside 'spv.module'"); - } - - auto funcOp = moduleOp.lookupSymbol(fnName); + auto funcOp = dyn_cast_or_null(SymbolTable::lookupNearestSymbolFrom( + functionCallOp.getParentOp(), fnName)); if (!funcOp) { return functionCallOp.emitOpError("callee function '") - << fnName << "' not found in 'spv.module'"; + << fnName << "' not found in nearest symbol table"; } auto functionType = funcOp.getType(); @@ -1837,8 +1850,8 @@ static LogicalResult verify(spirv::GlobalVariableOp varOp) { if (auto init = varOp.getAttrOfType(kInitializerAttrName)) { - auto moduleOp = varOp.getParentOfType(); - auto *initOp = moduleOp.lookupSymbol(init.getValue()); + Operation *initOp = SymbolTable::lookupNearestSymbolFrom( + varOp.getParentOp(), init.getValue()); // TODO: Currently only variable initialization with specialization // constants and other variables is supported. They could be normal // constants in the module scope as well. @@ -2534,9 +2547,9 @@ static void print(spirv::ReferenceOfOp referenceOfOp, OpAsmPrinter &printer) { } static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) { - auto moduleOp = referenceOfOp.getParentOfType(); - auto specConstOp = - moduleOp.lookupSymbol(referenceOfOp.spec_const()); + auto specConstOp = dyn_cast_or_null( + SymbolTable::lookupNearestSymbolFrom(referenceOfOp.getParentOp(), + referenceOfOp.spec_const())); if (!specConstOp) { return referenceOfOp.emitOpError("expected spv.specConstant symbol"); } diff --git a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir index 63e214a08b9c..1abeafec1861 100644 --- a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir +++ b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir @@ -186,6 +186,19 @@ spv.module "Logical" "GLSL450" { // ----- +// Allow calling functions in other module-like ops +func @callee() { + spv.Return +} + +func @caller() { + // CHECK: spv.FunctionCall + spv.FunctionCall @callee() : () -> () + spv.Return +} + +// ----- + spv.module "Logical" "GLSL450" { func @f_invalid_result_type(%arg0 : i32, %arg1 : i32) -> () { // expected-error @+1 {{expected callee function to have 0 or 1 result, but provided 2}} @@ -239,7 +252,7 @@ spv.module "Logical" "GLSL450" { spv.module "Logical" "GLSL450" { func @f_foo(%arg0 : i32, %arg1 : i32) -> i32 { - // expected-error @+1 {{op callee function 'f_undefined' not found in 'spv.module'}} + // expected-error @+1 {{op callee function 'f_undefined' not found in nearest symbol table}} %0 = spv.FunctionCall @f_undefined(%arg0, %arg0) : (i32, i32) -> i32 spv.Return } @@ -247,14 +260,6 @@ spv.module "Logical" "GLSL450" { // ----- -func @f_foo(%arg0 : i32, %arg1 : i32) -> i32 { - // expected-error @+1 {{must appear in a function inside 'spv.module'}} - %0 = spv.FunctionCall @f_foo(%arg0, %arg0) : (i32, i32) -> i32 - spv.Return -} - -// ----- - //===----------------------------------------------------------------------===// // spv.loop //===----------------------------------------------------------------------===// @@ -497,8 +502,16 @@ func @in_loop(%cond : i1) -> () { // ----- +// CHECK-LABEL: in_other_func_like_op +func @in_other_func_like_op() { + // CHECK: spv.Return + spv.Return +} + +// ----- + "foo.function"() ({ - // expected-error @+1 {{op must appear in a 'func' block}} + // expected-error @+1 {{op must appear in a function-like op's block}} spv.Return }) : () -> () @@ -562,7 +575,7 @@ func @in_loop(%cond : i1) -> (i32) { "foo.function"() ({ %0 = spv.constant true - // expected-error @+1 {{op must appear in a 'func' block}} + // expected-error @+1 {{op must appear in a function-like op's block}} spv.ReturnValue %0 : i1 }) : () -> () diff --git a/mlir/test/Dialect/SPIRV/structure-ops.mlir b/mlir/test/Dialect/SPIRV/structure-ops.mlir index 8fe03f46323f..2ba1023f167b 100644 --- a/mlir/test/Dialect/SPIRV/structure-ops.mlir +++ b/mlir/test/Dialect/SPIRV/structure-ops.mlir @@ -18,6 +18,16 @@ spv.module "Logical" "GLSL450" { // ----- +// Allow taking address of global variables in other module-like ops +spv.globalVariable @var : !spv.ptr>, Input> +func @address_of() -> () { + // CHECK: spv._address_of @var + %1 = spv._address_of @var : !spv.ptr>, Input> + return +} + +// ----- + spv.module "Logical" "GLSL450" { spv.globalVariable @var1 : !spv.ptr>, Input> func @foo() -> () { @@ -174,7 +184,7 @@ spv.module "Logical" "GLSL450" { spv.module "Logical" "GLSL450" { func @do_nothing() -> () { - // expected-error @+1 {{'spv.EntryPoint' op failed to verify that op must appear in a 'spv.module' block}} + // expected-error @+1 {{op must appear in a module-like op's block}} spv.EntryPoint "GLCompute" @do_something } } @@ -229,6 +239,13 @@ spv.module "Logical" "GLSL450" { // ----- +// Allow initializers coming from other module-like ops +spv.specConstant @sc = 4.0 : f32 +// CHECK: spv.globalVariable @var initializer(@sc) +spv.globalVariable @var initializer(@sc) : !spv.ptr + +// ----- + spv.module "Logical" "GLSL450" { // CHECK: spv.globalVariable @var0 bind(1, 2) : !spv.ptr spv.globalVariable @var0 bind(1, 2) : !spv.ptr @@ -252,6 +269,14 @@ spv.module "Logical" "GLSL450" { // ----- +// Allow in other module-like ops +module { + // CHECK: spv.globalVariable + spv.globalVariable @var0 : !spv.ptr +} + +// ----- + spv.module "Logical" "GLSL450" { // expected-error @+1 {{expected spv.ptr type}} spv.globalVariable @var0 : f32 @@ -275,7 +300,7 @@ spv.module "Logical" "GLSL450" { spv.module "Logical" "GLSL450" { func @foo() { - // expected-error @+1 {{op failed to verify that op must appear in a 'spv.module' block}} + // expected-error @+1 {{op must appear in a module-like op's block}} spv.globalVariable @var0 : !spv.ptr spv.Return } @@ -418,7 +443,7 @@ spv.module "Logical" "GLSL450" { //===----------------------------------------------------------------------===// func @module_end_not_in_module() -> () { - // expected-error @+1 {{op must appear in a 'spv.module' block}} + // expected-error @+1 {{op must appear in a module-like op's block}} spv._module_end } @@ -461,6 +486,16 @@ spv.module "Logical" "GLSL450" { // ----- +// Allow taking reference of spec constant in other module-like ops +spv.specConstant @sc = 5 : i32 +func @reference_of() { + // CHECK: spv._reference_of @sc + %0 = spv._reference_of @sc : i32 + return +} + +// ----- + spv.module "Logical" "GLSL450" { func @foo() -> () { // expected-error @+1 {{expected spv.specConstant symbol}} @@ -519,7 +554,7 @@ spv.module "Logical" "GLSL450" { // ----- func @use_in_function() -> () { - // expected-error @+1 {{op must appear in a 'spv.module' block}} + // expected-error @+1 {{op must appear in a module-like op's block}} spv.specConstant @sc = false return }