diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index 8f07fecb9f06..cf87bfd90cd0 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -128,6 +128,7 @@ def SPV_OC_OpSLessThan : I32EnumAttrCase<"OpSLessThan", 177>; def SPV_OC_OpULessThanEqual : I32EnumAttrCase<"OpULessThanEqual", 178>; def SPV_OC_OpSLessThanEqual : I32EnumAttrCase<"OpSLessThanEqual", 179>; def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>; +def SPV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>; def SPV_OpcodeAttr : I32EnumAttr<"Opcode", "valid SPIR-V instructions", [ @@ -146,7 +147,7 @@ def SPV_OpcodeAttr : SPV_OC_OpFMod, SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, - SPV_OC_OpSLessThanEqual, SPV_OC_OpReturn + SPV_OC_OpSLessThanEqual, SPV_OC_OpReturn, SPV_OC_OpReturnValue ]> { let returnType = "::mlir::spirv::Opcode"; let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())"; @@ -778,12 +779,15 @@ def SPV_SamplerUseAttr: // SPIR-V OpTrait definitions //===----------------------------------------------------------------------===// -// Check that an op can only be used with SPIR-V ModuleOp -def IsModuleOnlyPred : - CPred<"llvm::isa_and_nonnull($_op.getParentOp())">; +// Check that an op can only be used within the scope of a FuncOp. +def InFunctionScope : PredOpTrait< + "op must appear in a 'func' block", + CPred<"llvm::isa_and_nonnull($_op.getParentOp())">>; -def ModuleOnly : - PredOpTrait<"op can only be used in a 'spv.module' block", IsModuleOnlyPred>; +// Check that an op can only be used within the scope of a SPIR-V ModuleOp. +def InModuleScope : PredOpTrait< + "op must appear in a 'spv.module' block", + CPred<"llvm::isa_and_nonnull($_op.getParentOp())">>; //===----------------------------------------------------------------------===// // SPIR-V op definitions diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td index de496a76d26d..76bffde38df4 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td @@ -146,7 +146,7 @@ def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract", [NoSideEffect]> { // ----- -def SPV_ExecutionModeOp : SPV_Op<"ExecutionMode", [ModuleOnly]> { +def SPV_ExecutionModeOp : SPV_Op<"ExecutionMode", [InModuleScope]> { let summary = "Declare an execution mode for an entry point."; let description = [{ @@ -599,7 +599,7 @@ def SPV_LoadOp : SPV_Op<"Load", []> { // ----- -def SPV_ReturnOp : SPV_Op<"Return", [Terminator]> { +def SPV_ReturnOp : SPV_Op<"Return", [InFunctionScope, Terminator]> { let summary = "Return with no value from a function with void return type."; let description = [{ @@ -624,6 +624,38 @@ def SPV_ReturnOp : SPV_Op<"Return", [Terminator]> { // ----- +def SPV_ReturnValueOp : SPV_Op<"ReturnValue", [InFunctionScope, Terminator]> { + let summary = "Return a value from a function."; + + let description = [{ + Value is the value returned, by copy, and must match the Return Type + operand of the OpTypeFunction type of the OpFunction body this return + instruction is in. + + This instruction must be the last instruction in a block. + + ### Custom assembly form + + ``` {.ebnf} + return-value-op ::= `spv.ReturnValue` ssa-use `:` spirv-type + ``` + + For example: + + ``` + spv.ReturnValue %0 : f32 + ``` + }]; + + let arguments = (ins + SPV_Type:$value + ); + + let results = (outs); +} + +// ----- + def SPV_SDivOp : SPV_ArithmeticOp<"SDiv", SPV_Integer> { let summary = "Signed-integer division of Operand 1 divided by Operand 2."; diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td index d47563907428..292e148c86fb 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td @@ -30,7 +30,7 @@ include "mlir/SPIRV/SPIRVBase.td" #endif // SPIRV_BASE -def SPV_AddressOfOp : SPV_Op<"_address_of", [NoSideEffect]> { +def SPV_AddressOfOp : SPV_Op<"_address_of", [InFunctionScope, NoSideEffect]> { let summary = "Get the address of a global variable."; let description = [{ @@ -66,7 +66,7 @@ def SPV_AddressOfOp : SPV_Op<"_address_of", [NoSideEffect]> { let hasOpcode = 0; } -def SPV_EntryPointOp : SPV_Op<"EntryPoint", [ModuleOnly]> { +def SPV_EntryPointOp : SPV_Op<"EntryPoint", [InModuleScope]> { let summary = [{ Declare an entry point, its execution model, and its interface. }]; @@ -122,7 +122,7 @@ def SPV_EntryPointOp : SPV_Op<"EntryPoint", [ModuleOnly]> { } -def SPV_GlobalVariableOp : SPV_Op<"globalVariable", [ModuleOnly]> { +def SPV_GlobalVariableOp : SPV_Op<"globalVariable", [InModuleScope]> { let summary = [{ Allocate an object in memory at module scope. The object is referenced using a symbol name. @@ -264,7 +264,7 @@ def SPV_ModuleOp : SPV_Op<"module", }]; } -def SPV_ModuleEndOp : SPV_Op<"_module_end", [Terminator, ModuleOnly]> { +def SPV_ModuleEndOp : SPV_Op<"_module_end", [InModuleScope, Terminator]> { let summary = "The pseudo op that ends a SPIR-V module"; let description = [{ diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 9947c0254a9a..9a7f3594551e 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -1042,10 +1042,7 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) { //===----------------------------------------------------------------------===// static LogicalResult verifyReturn(spirv::ReturnOp returnOp) { - auto funcOp = llvm::dyn_cast(returnOp.getOperation()->getParentOp()); - if (!funcOp) - return returnOp.emitOpError("must appear in a 'func' op"); - + auto funcOp = llvm::cast(returnOp.getParentOp()); auto numOutputs = funcOp.getType().getNumResults(); if (numOutputs != 0) return returnOp.emitOpError("cannot be used in functions returning value") @@ -1054,6 +1051,43 @@ static LogicalResult verifyReturn(spirv::ReturnOp returnOp) { return success(); } +//===----------------------------------------------------------------------===// +// spv.ReturnValue +//===----------------------------------------------------------------------===// + +static ParseResult parseReturnValueOp(OpAsmParser *parser, + OperationState *state) { + OpAsmParser::OperandType retValInfo; + Type retValType; + return failure( + parser->parseOperand(retValInfo) || parser->parseColonType(retValType) || + parser->resolveOperand(retValInfo, retValType, state->operands)); +} + +static void print(spirv::ReturnValueOp retValOp, OpAsmPrinter *printer) { + *printer << spirv::ReturnValueOp::getOperationName() << ' '; + printer->printOperand(retValOp.value()); + *printer << " : " << retValOp.value()->getType(); +} + +static LogicalResult verify(spirv::ReturnValueOp retValOp) { + auto funcOp = llvm::cast(retValOp.getParentOp()); + auto numFnResults = funcOp.getType().getNumResults(); + if (numFnResults != 1) + return retValOp.emitOpError( + "returns 1 value but enclosing function requires ") + << numFnResults << " results"; + + auto operandType = retValOp.value()->getType(); + auto fnResultType = funcOp.getType().getResult(0); + if (operandType != fnResultType) + return retValOp.emitOpError(" return value's type (") + << operandType << ") mismatch with function's result type (" + << fnResultType << ")"; + + return success(); +} + //===----------------------------------------------------------------------===// // spv.StoreOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/Serialization/terminator.mlir b/mlir/test/Dialect/SPIRV/Serialization/terminator.mlir new file mode 100644 index 000000000000..35d2f972b555 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/Serialization/terminator.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s + +func @spirv_terminator() -> () { + spv.module "Logical" "GLSL450" { + // CHECK-LABEL: @ret + func @ret() -> () { + // CHECK: spv.Return + spv.Return + } + + // CHECK-LABEL: @ret_val + func @ret_val() -> (i32) { + %0 = spv.Variable : !spv.ptr + %1 = spv.Load "Function" %0 : i32 + // CHECK: spv.ReturnValue {{.*}} : i32 + spv.ReturnValue %1 : i32 + } + } + return +} + diff --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir index 052dc6871679..167b6d813430 100644 --- a/mlir/test/Dialect/SPIRV/ops.mlir +++ b/mlir/test/Dialect/SPIRV/ops.mlir @@ -327,7 +327,7 @@ spv.module "Logical" "VulkanKHR" { spv.module "Logical" "VulkanKHR" { func @do_nothing() -> () { - // expected-error @+1 {{'spv.EntryPoint' op failed to verify that op can only be used in a 'spv.module' block}} + // expected-error @+1 {{op must appear in a 'spv.module' block}} spv.EntryPoint "GLCompute" @do_something } } @@ -451,7 +451,7 @@ spv.module "Logical" "VulkanKHR" { spv.module "Logical" "VulkanKHR" { func @foo() { - // expected-error @+1 {{op failed to verify that op can only be used in a 'spv.module' block}} + // expected-error @+1 {{op must appear in a 'spv.module' block}} spv.globalVariable !spv.ptr @var0 spv.Return } @@ -767,7 +767,7 @@ spv.module "Logical" "VulkanKHR" { //===----------------------------------------------------------------------===// "foo.function"() ({ - // expected-error @+1 {{must appear in a 'func' op}} + // expected-error @+1 {{op must appear in a 'func' block}} spv.Return }) : () -> () @@ -783,6 +783,41 @@ spv.module "Logical" "VulkanKHR" { // ----- +//===----------------------------------------------------------------------===// +// spv.ReturnValue +//===----------------------------------------------------------------------===// + +func @ret_val() -> (i32) { + %0 = spv.constant 42 : i32 + spv.ReturnValue %0 : i32 +} + +// ----- + +"foo.function"() ({ + %0 = spv.constant true + // expected-error @+1 {{op must appear in a 'func' block}} + spv.ReturnValue %0 : i1 +}) : () -> () + +// ----- + +func @value_count_mismatch() -> () { + %0 = spv.constant 42 : i32 + // expected-error @+1 {{op returns 1 value but enclosing function requires 0 results}} + spv.ReturnValue %0 : i32 +} + +// ----- + +func @value_type_mismatch() -> (f32) { + %0 = spv.constant 42 : i32 + // expected-error @+1 {{return value's type ('i32') mismatch with function's result type ('f32')}} + spv.ReturnValue %0 : i32 +} + +// ----- + //===----------------------------------------------------------------------===// // spv.SDiv //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/structure-ops.mlir b/mlir/test/Dialect/SPIRV/structure-ops.mlir index db51b175b03b..e398be6656e3 100644 --- a/mlir/test/Dialect/SPIRV/structure-ops.mlir +++ b/mlir/test/Dialect/SPIRV/structure-ops.mlir @@ -1,5 +1,17 @@ // RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s +//===----------------------------------------------------------------------===// +// spv._address_of +//===----------------------------------------------------------------------===// + +spv.module "Logical" "GLSL450" { + spv.globalVariable !spv.ptr>, Input> @var + // expected-error @+1 {{op must appear in a 'func' block}} + %1 = spv._address_of @var : !spv.ptr>, Input> +} + +// ----- + //===----------------------------------------------------------------------===// // spv.constant //===----------------------------------------------------------------------===// @@ -171,6 +183,6 @@ spv.module "Logical" "VulkanKHR" { //===----------------------------------------------------------------------===// func @module_end_not_in_module() -> () { - // expected-error @+1 {{can only be used in a 'spv.module' block}} + // expected-error @+1 {{op must appear in a 'spv.module' block}} spv._module_end }