[mlir][spirv] Relax verification to allow flexible placement

Thus far certain SPIR-V ops have been required to be in spv.module.
While this provides strong verification to catch unexpected errors,
it's quite rigid and makes progressive lowering difficult. Sometimes
we would like to partially lower ops from other dialects, which may
involve creating ops like global variables that should be placed in
other module-like ops. So this commit relaxes the requirement of
such SPIR-V ops' scope to module-like ops. Similarly for function-
like ops.

Differential Revision: https://reviews.llvm.org/D73415
This commit is contained in:
Lei Zhang 2020-01-25 09:16:29 -05:00
parent ae21e37eb4
commit 60d541e1b9
4 changed files with 98 additions and 37 deletions

View File

@ -2997,15 +2997,15 @@ def SPV_SelectType : AnyTypeOf<[SPV_Scalar, SPV_Vector, SPV_AnyPtr]>;
// SPIR-V OpTrait definitions
//===----------------------------------------------------------------------===//
// Check that an op can only be used within the scope of a FuncOp.
// Check that an op can only be used within the scope of a function-like op.
def InFunctionScope : PredOpTrait<
"op must appear in a 'func' block",
CPred<"($_op.getParentOfType<FuncOp>())">>;
"op must appear in a function-like op's block",
CPred<"isNestedInFunctionLikeOp($_op.getParentOp())">>;
// Check that an op can only be used within the scope of a SPIR-V ModuleOp.
// Check that an op can only be used within the scope of a module-like op.
def InModuleScope : PredOpTrait<
"op must appear in a 'spv.module' block",
CPred<"llvm::isa_and_nonnull<spirv::ModuleOp>($_op.getParentOp())">>;
"op must appear in a module-like op's block",
CPred<"isDirectInModuleLikeOp($_op.getParentOp())">>;
//===----------------------------------------------------------------------===//
// SPIR-V opcode specification

View File

@ -55,6 +55,24 @@ static constexpr const char kVariableAttrName[] = "variable";
// Common utility functions
//===----------------------------------------------------------------------===//
/// Returns true if the given op is a function-like op or nested in a
/// function-like op without a module-like op in the middle.
static bool isNestedInFunctionLikeOp(Operation *op) {
if (!op)
return false;
if (op->hasTrait<OpTrait::SymbolTable>())
return false;
if (op->hasTrait<OpTrait::FunctionLike>())
return true;
return isNestedInFunctionLikeOp(op->getParentOp());
}
/// Returns true if the given op is an module-like op that maintains a symbol
/// table.
static bool isDirectInModuleLikeOp(Operation *op) {
return op && op->hasTrait<OpTrait::SymbolTable>();
}
static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value) {
auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
if (!constOp) {
@ -872,9 +890,9 @@ static void print(spirv::AddressOfOp addressOfOp, OpAsmPrinter &printer) {
}
static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
auto moduleOp = addressOfOp.getParentOfType<spirv::ModuleOp>();
auto varOp =
moduleOp.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.variable());
auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
SymbolTable::lookupNearestSymbolFrom(addressOfOp.getParentOp(),
addressOfOp.variable()));
if (!varOp) {
return addressOfOp.emitOpError("expected spv.globalVariable symbol");
}
@ -1679,16 +1697,11 @@ static void print(spirv::FunctionCallOp functionCallOp, OpAsmPrinter &printer) {
static LogicalResult verify(spirv::FunctionCallOp functionCallOp) {
auto fnName = functionCallOp.callee();
auto moduleOp = functionCallOp.getParentOfType<spirv::ModuleOp>();
if (!moduleOp) {
return functionCallOp.emitOpError(
"must appear in a function inside 'spv.module'");
}
auto funcOp = moduleOp.lookupSymbol<FuncOp>(fnName);
auto funcOp = dyn_cast_or_null<FuncOp>(SymbolTable::lookupNearestSymbolFrom(
functionCallOp.getParentOp(), fnName));
if (!funcOp) {
return functionCallOp.emitOpError("callee function '")
<< fnName << "' not found in 'spv.module'";
<< fnName << "' not found in nearest symbol table";
}
auto functionType = funcOp.getType();
@ -1837,8 +1850,8 @@ static LogicalResult verify(spirv::GlobalVariableOp varOp) {
if (auto init =
varOp.getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) {
auto moduleOp = varOp.getParentOfType<spirv::ModuleOp>();
auto *initOp = moduleOp.lookupSymbol(init.getValue());
Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
varOp.getParentOp(), init.getValue());
// TODO: Currently only variable initialization with specialization
// constants and other variables is supported. They could be normal
// constants in the module scope as well.
@ -2534,9 +2547,9 @@ static void print(spirv::ReferenceOfOp referenceOfOp, OpAsmPrinter &printer) {
}
static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) {
auto moduleOp = referenceOfOp.getParentOfType<spirv::ModuleOp>();
auto specConstOp =
moduleOp.lookupSymbol<spirv::SpecConstantOp>(referenceOfOp.spec_const());
auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(
SymbolTable::lookupNearestSymbolFrom(referenceOfOp.getParentOp(),
referenceOfOp.spec_const()));
if (!specConstOp) {
return referenceOfOp.emitOpError("expected spv.specConstant symbol");
}

View File

@ -186,6 +186,19 @@ spv.module "Logical" "GLSL450" {
// -----
// Allow calling functions in other module-like ops
func @callee() {
spv.Return
}
func @caller() {
// CHECK: spv.FunctionCall
spv.FunctionCall @callee() : () -> ()
spv.Return
}
// -----
spv.module "Logical" "GLSL450" {
func @f_invalid_result_type(%arg0 : i32, %arg1 : i32) -> () {
// expected-error @+1 {{expected callee function to have 0 or 1 result, but provided 2}}
@ -239,7 +252,7 @@ spv.module "Logical" "GLSL450" {
spv.module "Logical" "GLSL450" {
func @f_foo(%arg0 : i32, %arg1 : i32) -> i32 {
// expected-error @+1 {{op callee function 'f_undefined' not found in 'spv.module'}}
// expected-error @+1 {{op callee function 'f_undefined' not found in nearest symbol table}}
%0 = spv.FunctionCall @f_undefined(%arg0, %arg0) : (i32, i32) -> i32
spv.Return
}
@ -247,14 +260,6 @@ spv.module "Logical" "GLSL450" {
// -----
func @f_foo(%arg0 : i32, %arg1 : i32) -> i32 {
// expected-error @+1 {{must appear in a function inside 'spv.module'}}
%0 = spv.FunctionCall @f_foo(%arg0, %arg0) : (i32, i32) -> i32
spv.Return
}
// -----
//===----------------------------------------------------------------------===//
// spv.loop
//===----------------------------------------------------------------------===//
@ -497,8 +502,16 @@ func @in_loop(%cond : i1) -> () {
// -----
// CHECK-LABEL: in_other_func_like_op
func @in_other_func_like_op() {
// CHECK: spv.Return
spv.Return
}
// -----
"foo.function"() ({
// expected-error @+1 {{op must appear in a 'func' block}}
// expected-error @+1 {{op must appear in a function-like op's block}}
spv.Return
}) : () -> ()
@ -562,7 +575,7 @@ func @in_loop(%cond : i1) -> (i32) {
"foo.function"() ({
%0 = spv.constant true
// expected-error @+1 {{op must appear in a 'func' block}}
// expected-error @+1 {{op must appear in a function-like op's block}}
spv.ReturnValue %0 : i1
}) : () -> ()

View File

@ -18,6 +18,16 @@ spv.module "Logical" "GLSL450" {
// -----
// Allow taking address of global variables in other module-like ops
spv.globalVariable @var : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Input>
func @address_of() -> () {
// CHECK: spv._address_of @var
%1 = spv._address_of @var : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Input>
return
}
// -----
spv.module "Logical" "GLSL450" {
spv.globalVariable @var1 : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Input>
func @foo() -> () {
@ -174,7 +184,7 @@ spv.module "Logical" "GLSL450" {
spv.module "Logical" "GLSL450" {
func @do_nothing() -> () {
// expected-error @+1 {{'spv.EntryPoint' op failed to verify that op must appear in a 'spv.module' block}}
// expected-error @+1 {{op must appear in a module-like op's block}}
spv.EntryPoint "GLCompute" @do_something
}
}
@ -229,6 +239,13 @@ spv.module "Logical" "GLSL450" {
// -----
// Allow initializers coming from other module-like ops
spv.specConstant @sc = 4.0 : f32
// CHECK: spv.globalVariable @var initializer(@sc)
spv.globalVariable @var initializer(@sc) : !spv.ptr<f32, Private>
// -----
spv.module "Logical" "GLSL450" {
// CHECK: spv.globalVariable @var0 bind(1, 2) : !spv.ptr<f32, Uniform>
spv.globalVariable @var0 bind(1, 2) : !spv.ptr<f32, Uniform>
@ -252,6 +269,14 @@ spv.module "Logical" "GLSL450" {
// -----
// Allow in other module-like ops
module {
// CHECK: spv.globalVariable
spv.globalVariable @var0 : !spv.ptr<f32, Input>
}
// -----
spv.module "Logical" "GLSL450" {
// expected-error @+1 {{expected spv.ptr type}}
spv.globalVariable @var0 : f32
@ -275,7 +300,7 @@ spv.module "Logical" "GLSL450" {
spv.module "Logical" "GLSL450" {
func @foo() {
// expected-error @+1 {{op failed to verify that op must appear in a 'spv.module' block}}
// expected-error @+1 {{op must appear in a module-like op's block}}
spv.globalVariable @var0 : !spv.ptr<f32, Input>
spv.Return
}
@ -418,7 +443,7 @@ spv.module "Logical" "GLSL450" {
//===----------------------------------------------------------------------===//
func @module_end_not_in_module() -> () {
// expected-error @+1 {{op must appear in a 'spv.module' block}}
// expected-error @+1 {{op must appear in a module-like op's block}}
spv._module_end
}
@ -461,6 +486,16 @@ spv.module "Logical" "GLSL450" {
// -----
// Allow taking reference of spec constant in other module-like ops
spv.specConstant @sc = 5 : i32
func @reference_of() {
// CHECK: spv._reference_of @sc
%0 = spv._reference_of @sc : i32
return
}
// -----
spv.module "Logical" "GLSL450" {
func @foo() -> () {
// expected-error @+1 {{expected spv.specConstant symbol}}
@ -519,7 +554,7 @@ spv.module "Logical" "GLSL450" {
// -----
func @use_in_function() -> () {
// expected-error @+1 {{op must appear in a 'spv.module' block}}
// expected-error @+1 {{op must appear in a module-like op's block}}
spv.specConstant @sc = false
return
}