[spirv] Add spv.ReturnValue

This CL adds the spv.ReturnValue op and its tests. Also adds a
InFunctionScope trait to make sure that the op stays inside
a function. To be consistent, ModuleOnly trait is changed to
InModuleScope.

PiperOrigin-RevId: 264193081
This commit is contained in:
Lei Zhang 2019-08-19 10:57:43 -07:00 committed by A. Unique TensorFlower
parent 9bf69e6a2e
commit 64abcd983d
7 changed files with 158 additions and 20 deletions

View File

@ -128,6 +128,7 @@ def SPV_OC_OpSLessThan : I32EnumAttrCase<"OpSLessThan", 177>;
def SPV_OC_OpULessThanEqual : I32EnumAttrCase<"OpULessThanEqual", 178>; def SPV_OC_OpULessThanEqual : I32EnumAttrCase<"OpULessThanEqual", 178>;
def SPV_OC_OpSLessThanEqual : I32EnumAttrCase<"OpSLessThanEqual", 179>; def SPV_OC_OpSLessThanEqual : I32EnumAttrCase<"OpSLessThanEqual", 179>;
def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>; def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>;
def SPV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>;
def SPV_OpcodeAttr : def SPV_OpcodeAttr :
I32EnumAttr<"Opcode", "valid SPIR-V instructions", [ I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
@ -146,7 +147,7 @@ def SPV_OpcodeAttr :
SPV_OC_OpFMod, SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpFMod, SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
SPV_OC_OpSLessThanEqual, SPV_OC_OpReturn SPV_OC_OpSLessThanEqual, SPV_OC_OpReturn, SPV_OC_OpReturnValue
]> { ]> {
let returnType = "::mlir::spirv::Opcode"; let returnType = "::mlir::spirv::Opcode";
let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())"; let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";
@ -778,12 +779,15 @@ def SPV_SamplerUseAttr:
// SPIR-V OpTrait definitions // SPIR-V OpTrait definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Check that an op can only be used with SPIR-V ModuleOp // Check that an op can only be used within the scope of a FuncOp.
def IsModuleOnlyPred : def InFunctionScope : PredOpTrait<
CPred<"llvm::isa_and_nonnull<spirv::ModuleOp>($_op.getParentOp())">; "op must appear in a 'func' block",
CPred<"llvm::isa_and_nonnull<FuncOp>($_op.getParentOp())">>;
def ModuleOnly : // Check that an op can only be used within the scope of a SPIR-V ModuleOp.
PredOpTrait<"op can only be used in a 'spv.module' block", IsModuleOnlyPred>; def InModuleScope : PredOpTrait<
"op must appear in a 'spv.module' block",
CPred<"llvm::isa_and_nonnull<spirv::ModuleOp>($_op.getParentOp())">>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// SPIR-V op definitions // SPIR-V op definitions

View File

@ -146,7 +146,7 @@ def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract", [NoSideEffect]> {
// ----- // -----
def SPV_ExecutionModeOp : SPV_Op<"ExecutionMode", [ModuleOnly]> { def SPV_ExecutionModeOp : SPV_Op<"ExecutionMode", [InModuleScope]> {
let summary = "Declare an execution mode for an entry point."; let summary = "Declare an execution mode for an entry point.";
let description = [{ let description = [{
@ -599,7 +599,7 @@ def SPV_LoadOp : SPV_Op<"Load", []> {
// ----- // -----
def SPV_ReturnOp : SPV_Op<"Return", [Terminator]> { def SPV_ReturnOp : SPV_Op<"Return", [InFunctionScope, Terminator]> {
let summary = "Return with no value from a function with void return type."; let summary = "Return with no value from a function with void return type.";
let description = [{ let description = [{
@ -624,6 +624,38 @@ def SPV_ReturnOp : SPV_Op<"Return", [Terminator]> {
// ----- // -----
def SPV_ReturnValueOp : SPV_Op<"ReturnValue", [InFunctionScope, Terminator]> {
let summary = "Return a value from a function.";
let description = [{
Value is the value returned, by copy, and must match the Return Type
operand of the OpTypeFunction type of the OpFunction body this return
instruction is in.
This instruction must be the last instruction in a block.
### Custom assembly form
``` {.ebnf}
return-value-op ::= `spv.ReturnValue` ssa-use `:` spirv-type
```
For example:
```
spv.ReturnValue %0 : f32
```
}];
let arguments = (ins
SPV_Type:$value
);
let results = (outs);
}
// -----
def SPV_SDivOp : SPV_ArithmeticOp<"SDiv", SPV_Integer> { def SPV_SDivOp : SPV_ArithmeticOp<"SDiv", SPV_Integer> {
let summary = "Signed-integer division of Operand 1 divided by Operand 2."; let summary = "Signed-integer division of Operand 1 divided by Operand 2.";

View File

@ -30,7 +30,7 @@
include "mlir/SPIRV/SPIRVBase.td" include "mlir/SPIRV/SPIRVBase.td"
#endif // SPIRV_BASE #endif // SPIRV_BASE
def SPV_AddressOfOp : SPV_Op<"_address_of", [NoSideEffect]> { def SPV_AddressOfOp : SPV_Op<"_address_of", [InFunctionScope, NoSideEffect]> {
let summary = "Get the address of a global variable."; let summary = "Get the address of a global variable.";
let description = [{ let description = [{
@ -66,7 +66,7 @@ def SPV_AddressOfOp : SPV_Op<"_address_of", [NoSideEffect]> {
let hasOpcode = 0; let hasOpcode = 0;
} }
def SPV_EntryPointOp : SPV_Op<"EntryPoint", [ModuleOnly]> { def SPV_EntryPointOp : SPV_Op<"EntryPoint", [InModuleScope]> {
let summary = [{ let summary = [{
Declare an entry point, its execution model, and its interface. Declare an entry point, its execution model, and its interface.
}]; }];
@ -122,7 +122,7 @@ def SPV_EntryPointOp : SPV_Op<"EntryPoint", [ModuleOnly]> {
} }
def SPV_GlobalVariableOp : SPV_Op<"globalVariable", [ModuleOnly]> { def SPV_GlobalVariableOp : SPV_Op<"globalVariable", [InModuleScope]> {
let summary = [{ let summary = [{
Allocate an object in memory at module scope. The object is Allocate an object in memory at module scope. The object is
referenced using a symbol name. referenced using a symbol name.
@ -264,7 +264,7 @@ def SPV_ModuleOp : SPV_Op<"module",
}]; }];
} }
def SPV_ModuleEndOp : SPV_Op<"_module_end", [Terminator, ModuleOnly]> { def SPV_ModuleEndOp : SPV_Op<"_module_end", [InModuleScope, Terminator]> {
let summary = "The pseudo op that ends a SPIR-V module"; let summary = "The pseudo op that ends a SPIR-V module";
let description = [{ let description = [{

View File

@ -1042,10 +1042,7 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static LogicalResult verifyReturn(spirv::ReturnOp returnOp) { static LogicalResult verifyReturn(spirv::ReturnOp returnOp) {
auto funcOp = llvm::dyn_cast<FuncOp>(returnOp.getOperation()->getParentOp()); auto funcOp = llvm::cast<FuncOp>(returnOp.getParentOp());
if (!funcOp)
return returnOp.emitOpError("must appear in a 'func' op");
auto numOutputs = funcOp.getType().getNumResults(); auto numOutputs = funcOp.getType().getNumResults();
if (numOutputs != 0) if (numOutputs != 0)
return returnOp.emitOpError("cannot be used in functions returning value") return returnOp.emitOpError("cannot be used in functions returning value")
@ -1054,6 +1051,43 @@ static LogicalResult verifyReturn(spirv::ReturnOp returnOp) {
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// spv.ReturnValue
//===----------------------------------------------------------------------===//
static ParseResult parseReturnValueOp(OpAsmParser *parser,
OperationState *state) {
OpAsmParser::OperandType retValInfo;
Type retValType;
return failure(
parser->parseOperand(retValInfo) || parser->parseColonType(retValType) ||
parser->resolveOperand(retValInfo, retValType, state->operands));
}
static void print(spirv::ReturnValueOp retValOp, OpAsmPrinter *printer) {
*printer << spirv::ReturnValueOp::getOperationName() << ' ';
printer->printOperand(retValOp.value());
*printer << " : " << retValOp.value()->getType();
}
static LogicalResult verify(spirv::ReturnValueOp retValOp) {
auto funcOp = llvm::cast<FuncOp>(retValOp.getParentOp());
auto numFnResults = funcOp.getType().getNumResults();
if (numFnResults != 1)
return retValOp.emitOpError(
"returns 1 value but enclosing function requires ")
<< numFnResults << " results";
auto operandType = retValOp.value()->getType();
auto fnResultType = funcOp.getType().getResult(0);
if (operandType != fnResultType)
return retValOp.emitOpError(" return value's type (")
<< operandType << ") mismatch with function's result type ("
<< fnResultType << ")";
return success();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// spv.StoreOp // spv.StoreOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -0,0 +1,21 @@
// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s
func @spirv_terminator() -> () {
spv.module "Logical" "GLSL450" {
// CHECK-LABEL: @ret
func @ret() -> () {
// CHECK: spv.Return
spv.Return
}
// CHECK-LABEL: @ret_val
func @ret_val() -> (i32) {
%0 = spv.Variable : !spv.ptr<i32, Function>
%1 = spv.Load "Function" %0 : i32
// CHECK: spv.ReturnValue {{.*}} : i32
spv.ReturnValue %1 : i32
}
}
return
}

View File

@ -327,7 +327,7 @@ spv.module "Logical" "VulkanKHR" {
spv.module "Logical" "VulkanKHR" { spv.module "Logical" "VulkanKHR" {
func @do_nothing() -> () { func @do_nothing() -> () {
// expected-error @+1 {{'spv.EntryPoint' op failed to verify that op can only be used in a 'spv.module' block}} // expected-error @+1 {{op must appear in a 'spv.module' block}}
spv.EntryPoint "GLCompute" @do_something spv.EntryPoint "GLCompute" @do_something
} }
} }
@ -451,7 +451,7 @@ spv.module "Logical" "VulkanKHR" {
spv.module "Logical" "VulkanKHR" { spv.module "Logical" "VulkanKHR" {
func @foo() { func @foo() {
// expected-error @+1 {{op failed to verify that op can only be used in a 'spv.module' block}} // expected-error @+1 {{op must appear in a 'spv.module' block}}
spv.globalVariable !spv.ptr<f32, Input> @var0 spv.globalVariable !spv.ptr<f32, Input> @var0
spv.Return spv.Return
} }
@ -767,7 +767,7 @@ spv.module "Logical" "VulkanKHR" {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
"foo.function"() ({ "foo.function"() ({
// expected-error @+1 {{must appear in a 'func' op}} // expected-error @+1 {{op must appear in a 'func' block}}
spv.Return spv.Return
}) : () -> () }) : () -> ()
@ -783,6 +783,41 @@ spv.module "Logical" "VulkanKHR" {
// ----- // -----
//===----------------------------------------------------------------------===//
// spv.ReturnValue
//===----------------------------------------------------------------------===//
func @ret_val() -> (i32) {
%0 = spv.constant 42 : i32
spv.ReturnValue %0 : i32
}
// -----
"foo.function"() ({
%0 = spv.constant true
// expected-error @+1 {{op must appear in a 'func' block}}
spv.ReturnValue %0 : i1
}) : () -> ()
// -----
func @value_count_mismatch() -> () {
%0 = spv.constant 42 : i32
// expected-error @+1 {{op returns 1 value but enclosing function requires 0 results}}
spv.ReturnValue %0 : i32
}
// -----
func @value_type_mismatch() -> (f32) {
%0 = spv.constant 42 : i32
// expected-error @+1 {{return value's type ('i32') mismatch with function's result type ('f32')}}
spv.ReturnValue %0 : i32
}
// -----
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// spv.SDiv // spv.SDiv
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -1,5 +1,17 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s // RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
//===----------------------------------------------------------------------===//
// spv._address_of
//===----------------------------------------------------------------------===//
spv.module "Logical" "GLSL450" {
spv.globalVariable !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Input> @var
// expected-error @+1 {{op must appear in a 'func' block}}
%1 = spv._address_of @var : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Input>
}
// -----
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// spv.constant // spv.constant
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -171,6 +183,6 @@ spv.module "Logical" "VulkanKHR" {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
func @module_end_not_in_module() -> () { func @module_end_not_in_module() -> () {
// expected-error @+1 {{can only be used in a 'spv.module' block}} // expected-error @+1 {{op must appear in a 'spv.module' block}}
spv._module_end spv._module_end
} }