Print parens around the return type of a function if it is also a function type

Existing type syntax contains the following productions:

    function-type ::= type-list-parens `->` type-list
    type-list ::= type | type-list-parens
    type ::= <..> | function-type

Due to these rules, when the parser sees `->` followed by `(`, it cannot
disambiguate if `(` starts a parenthesized list of function result types, or a
parenthesized list of operands of another function type, returned from the
current function.  We would need an unknown amount of lookahead to try to find
the `->` at the right level of function nesting to differentiate between type
lists and singular function types.

Instead, require the result type of the function that is a function type itself
to be always parenthesized, at the syntax level.  Update the spec and the
parser to correspond to the production rule names used in the spec (although it
would have worked without modifications).  Fix the function type parsing bug in
the process, as it used to accept the non-parenthesized list of types for
arguments, disallowed by the spec.

PiperOrigin-RevId: 232528361
This commit is contained in:
Alex Zinenko 2019-02-05 11:47:02 -08:00 committed by jpienaar
parent 1b1f293a5d
commit 40d5d09f9d
8 changed files with 127 additions and 60 deletions

View File

@ -504,18 +504,18 @@ have application-specific semantics. For example, MLIR supports a set of
[dialect-specific types](#dialect-specific-types). [dialect-specific types](#dialect-specific-types).
``` {.ebnf} ``` {.ebnf}
type ::= integer-type type ::= non-function-type
| index-type
| float-type
| vector-type
| tensor-type
| memref-type
| function-type | function-type
| dialect-type
| type-alias
// MLIR doesn't have a tuple type but functions can return multiple values. non-function-type ::= integer-type
type-list ::= type-list-parens | type | index-type
| float-type
| vector-type
| tensor-type
| memref-type
| dialect-type
| type-alias
type-list-no-parens ::= type (`,` type)* type-list-no-parens ::= type (`,` type)*
type-list-parens ::= `(` `)` type-list-parens ::= `(` `)`
| `(` type-list-no-parens `)` | `(` type-list-no-parens `)`
@ -559,7 +559,11 @@ Builtin types consist of only the types needed for the validity of the IR.
Syntax: Syntax:
``` {.ebnf} ``` {.ebnf}
function-type ::= type-list-parens `->` type-list // MLIR doesn't have a tuple type but functions can return multiple values.
function-result-type ::= type-list-parens
| non-function-type
function-type ::= type-list-parens `->` function-result-type
``` ```
MLIR supports first-class functions: the MLIR supports first-class functions: the
@ -897,7 +901,7 @@ associated attributes according to the following grammar:
``` {.ebnf} ``` {.ebnf}
function ::= `func` function-signature function-attributes? function-body? function ::= `func` function-signature function-attributes? function-body?
function-signature ::= function-id `(` argument-list `)` (`->` type-list)? function-signature ::= function-id `(` argument-list `)` (`->` function-result-type)?
argument-list ::= named-argument (`,` named-argument)* | /*empty*/ argument-list ::= named-argument (`,` named-argument)* | /*empty*/
argument-list ::= type (`,` type)* | /*empty*/ named-argument ::= ssa-id `:` argument-list ::= type (`,` type)* | /*empty*/ named-argument ::= ssa-id `:`
type type

View File

@ -746,7 +746,7 @@ void ModulePrinter::printType(Type type) {
interleaveComma(func.getInputs(), [&](Type type) { printType(type); }); interleaveComma(func.getInputs(), [&](Type type) { printType(type); });
os << ") -> "; os << ") -> ";
auto results = func.getResults(); auto results = func.getResults();
if (results.size() == 1) if (results.size() == 1 && !results[0].isa<FunctionType>())
os << results[0]; os << results[0];
else { else {
os << '('; os << '(';
@ -1313,10 +1313,17 @@ void FunctionPrinter::printFunctionSignature() {
switch (fnType.getResults().size()) { switch (fnType.getResults().size()) {
case 0: case 0:
break; break;
case 1: case 1: {
os << " -> "; os << " -> ";
printType(fnType.getResults()[0]); auto resultType = fnType.getResults()[0];
bool resultIsFunc = resultType.isa<FunctionType>();
if (resultIsFunc)
os << '(';
printType(resultType);
if (resultIsFunc)
os << ')';
break; break;
}
default: default:
os << " -> ("; os << " -> (";
interleaveComma(fnType.getResults(), interleaveComma(fnType.getResults(),
@ -1482,7 +1489,8 @@ void FunctionPrinter::printGenericOp(const Instruction *op) {
[&](const Value *value) { printType(value->getType()); }); [&](const Value *value) { printType(value->getType()); });
os << ") -> "; os << ") -> ";
if (op->getNumResults() == 1) { if (op->getNumResults() == 1 &&
!op->getResult(0)->getType().isa<FunctionType>()) {
printType(op->getResult(0)->getType()); printType(op->getResult(0)->getType());
} else { } else {
os << '('; os << '(';

View File

@ -187,9 +187,11 @@ public:
Type parseTensorType(); Type parseTensorType();
Type parseMemRefType(); Type parseMemRefType();
Type parseFunctionType(); Type parseFunctionType();
Type parseNonFunctionType();
Type parseType(); Type parseType();
ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements); ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements);
ParseResult parseTypeList(SmallVectorImpl<Type> &elements); ParseResult parseTypeListParens(SmallVectorImpl<Type> &elements);
ParseResult parseFunctionResultTypes(SmallVectorImpl<Type> &elements);
// Attribute parsing. // Attribute parsing.
Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc, Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
@ -308,32 +310,29 @@ ParseResult Parser::parseCommaSeparatedListUntil(
// Type Parsing // Type Parsing
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Parse an arbitrary type. /// Parse any type except the function type.
/// ///
/// type ::= integer-type /// non-function-type ::= integer-type
/// | index-type /// | index-type
/// | float-type /// | float-type
/// | extended-type /// | extended-type
/// | vector-type /// | vector-type
/// | tensor-type /// | tensor-type
/// | memref-type /// | memref-type
/// | function-type
/// ///
/// index-type ::= `index` /// index-type ::= `index`
/// float-type ::= `f16` | `bf16` | `f32` | `f64` /// float-type ::= `f16` | `bf16` | `f32` | `f64`
/// ///
Type Parser::parseType() { Type Parser::parseNonFunctionType() {
switch (getToken().getKind()) { switch (getToken().getKind()) {
default: default:
return (emitError("expected type"), nullptr); return (emitError("expected non-function type"), nullptr);
case Token::kw_memref: case Token::kw_memref:
return parseMemRefType(); return parseMemRefType();
case Token::kw_tensor: case Token::kw_tensor:
return parseTensorType(); return parseTensorType();
case Token::kw_vector: case Token::kw_vector:
return parseVectorType(); return parseVectorType();
case Token::l_paren:
return parseFunctionType();
// integer-type // integer-type
case Token::inttype: { case Token::inttype: {
auto width = getToken().getIntTypeBitwidth(); auto width = getToken().getIntTypeBitwidth();
@ -369,6 +368,17 @@ Type Parser::parseType() {
} }
} }
/// Parse an arbitrary type.
///
/// type ::= function-type
/// | non-function-type
///
Type Parser::parseType() {
if (getToken().is(Token::l_paren))
return parseFunctionType();
return parseNonFunctionType();
}
/// Parse a vector type. /// Parse a vector type.
/// ///
/// vector-type ::= `vector` `<` const-dimension-list primitive-type `>` /// vector-type ::= `vector` `<` const-dimension-list primitive-type `>`
@ -640,9 +650,9 @@ Type Parser::parseFunctionType() {
assert(getToken().is(Token::l_paren)); assert(getToken().is(Token::l_paren));
SmallVector<Type, 4> arguments, results; SmallVector<Type, 4> arguments, results;
if (parseTypeList(arguments) || if (parseTypeListParens(arguments) ||
parseToken(Token::arrow, "expected '->' in function type") || parseToken(Token::arrow, "expected '->' in function type") ||
parseTypeList(results)) parseFunctionResultTypes(results))
return nullptr; return nullptr;
return builder.getFunctionType(arguments, results); return builder.getFunctionType(arguments, results);
@ -663,27 +673,38 @@ ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) {
return parseCommaSeparatedList(parseElt); return parseCommaSeparatedList(parseElt);
} }
/// Parse a "type list", which is a singular type, or a parenthesized list of /// Parse a parenthesized list of types.
/// types.
/// ///
/// type-list ::= type-list-parens | type
/// type-list-parens ::= `(` `)` /// type-list-parens ::= `(` `)`
/// | `(` type-list-no-parens `)` /// | `(` type-list-no-parens `)`
/// ///
ParseResult Parser::parseTypeList(SmallVectorImpl<Type> &elements) { ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) {
auto parseElt = [&]() -> ParseResult { if (parseToken(Token::l_paren, "expected '('"))
auto elt = parseType();
elements.push_back(elt);
return elt ? ParseSuccess : ParseFailure;
};
// If there is no parens, then it must be a singular type.
if (!consumeIf(Token::l_paren))
return parseElt();
if (parseCommaSeparatedListUntil(Token::r_paren, parseElt))
return ParseFailure; return ParseFailure;
// Handle empty lists.
if (getToken().is(Token::r_paren))
return consumeToken(), ParseSuccess;
if (parseTypeListNoParens(elements) ||
parseToken(Token::r_paren, "expected ')'"))
return ParseFailure;
return ParseSuccess;
}
/// Parse a function result type.
///
/// function-result-type ::= type-list-parens
/// | non-function-type
///
ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) {
if (getToken().is(Token::l_paren))
return parseTypeListParens(elements);
Type t = parseNonFunctionType();
if (!t)
return ParseFailure;
elements.push_back(t);
return ParseSuccess; return ParseSuccess;
} }
@ -3489,7 +3510,7 @@ ModuleParser::parseFunctionSignature(StringRef &name, FunctionType &type,
// Parse the return type if present. // Parse the return type if present.
SmallVector<Type, 4> results; SmallVector<Type, 4> results;
if (consumeIf(Token::arrow)) { if (consumeIf(Token::arrow)) {
if (parseTypeList(results)) if (parseFunctionResultTypes(results))
return ParseFailure; return ParseFailure;
} }
type = builder.getFunctionType(argTypes, results); type = builder.getFunctionType(argTypes, results);

View File

@ -1,8 +1,8 @@
// RUN: mlir-opt %s | FileCheck %s // RUN: mlir-opt %s | FileCheck %s
// Verify the printed output can be parsed. // Verify the printed output can be parsed.
// RUN: mlir-opt %s | mlir-opt | FileCheck %s // RUN: mlir-opt %s | mlir-opt | FileCheck %s
// TODO(b/123888077): The following fails due to constant with function pointer. // Verify the generic form can be parsed.
// Disabled: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s // RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
// CHECK: #map0 = (d0) -> (d0 + 1) // CHECK: #map0 = (d0) -> (d0 + 1)

View File

@ -67,7 +67,7 @@ func @location_fused_missing_greater() {
func @location_fused_missing_metadata() { func @location_fused_missing_metadata() {
^bb: ^bb:
// expected-error@+1 {{expected type}} // expected-error@+1 {{expected non-function type}}
return loc(fused<) // expected-error {{expected valid attribute metadata}} return loc(fused<) // expected-error {{expected valid attribute metadata}}
} }

View File

@ -209,7 +209,7 @@ func @func_with_ops(i32, i64) {
// Comparisons must have the "predicate" attribute. // Comparisons must have the "predicate" attribute.
func @func_with_ops(i32, i32) { func @func_with_ops(i32, i32) {
^bb0(%a : i32, %b : i32): ^bb0(%a : i32, %b : i32):
%r = cmpi %a, %b : i32 // expected-error {{expected type}} %r = cmpi %a, %b : i32 // expected-error {{expected non-function type}}
} }
// ----- // -----

View File

@ -3,12 +3,12 @@
// Check different error cases. // Check different error cases.
// ----- // -----
func @illegaltype(i) // expected-error {{expected type}} func @illegaltype(i) // expected-error {{expected non-function type}}
// ----- // -----
func @illegaltype() { func @illegaltype() {
%0 = constant splat<<vector 4 x f32>, 0> : vector<4 x f32> // expected-error {{expected type}} %0 = constant splat<<vector 4 x f32>, 0> : vector<4 x f32> // expected-error {{expected non-function type}}
} }
// ----- // -----
@ -227,7 +227,7 @@ func @incomplete_for() {
// ----- // -----
func @nonconstant_step(%1 : i32) { func @nonconstant_step(%1 : i32) {
for %2 = 1 to 5 step %1 { // expected-error {{expected type}} for %2 = 1 to 5 step %1 { // expected-error {{expected non-function type}}
// ----- // -----
@ -326,7 +326,7 @@ func @d() {return} // expected-error {{custom op 'func' is unknown}}
// ----- // -----
func @malformed_type(%a : intt) { // expected-error {{expected type}} func @malformed_type(%a : intt) { // expected-error {{expected non-function type}}
} }
// ----- // -----
@ -392,7 +392,7 @@ func @condbr_badtype() {
^bb0: ^bb0:
%c = "foo"() : () -> i1 %c = "foo"() : () -> i1
%a = "foo"() : () -> i32 %a = "foo"() : () -> i32
cond_br %c, ^bb0(%a, %a : i32, ^bb0) // expected-error {{expected type}} cond_br %c, ^bb0(%a, %a : i32, ^bb0) // expected-error {{expected non-function type}}
} }
// ----- // -----
@ -506,6 +506,18 @@ func @undefined_function() {
// ----- // -----
func @invalid_result_type() -> () -> () // expected-error {{expected a top level entity}}
// -----
func @func() -> (() -> ())
func @referer() {
%f = constant @func : () -> () -> () // expected-error {{reference to function with mismatched type}}
return
}
// -----
#map1 = (i)[j] -> (i+j) #map1 = (i)[j] -> (i+j)
func @bound_symbol_mismatch(%N : index) { func @bound_symbol_mismatch(%N : index) {
@ -538,7 +550,7 @@ func @large_bound() {
// ----- // -----
func @max_in_upper_bound(%N : index) { func @max_in_upper_bound(%N : index) {
for %i = 1 to max (i)->(N, 100) { //expected-error {{expected type}} for %i = 1 to max (i)->(N, 100) { //expected-error {{expected non-function type}}
} }
return return
} }
@ -595,7 +607,7 @@ func @invalid_if_operands3(%N : index) {
// expected-error@+1 {{expected '"' in string literal}} // expected-error@+1 {{expected '"' in string literal}}
"J// ----- "J// -----
func @calls(%arg0: i32) { func @calls(%arg0: i32) {
// expected-error@+1 {{expected type}} // expected-error@+1 {{expected non-function type}}
%z = "casdasda"(%x) : (ppop32) -> i32 %z = "casdasda"(%x) : (ppop32) -> i32
} }
// ----- // -----
@ -767,7 +779,7 @@ func @type_alias_unknown(!unknown_alias) -> () { // expected-error {{undefined t
// ----- // -----
!missing_type_alias = type // expected-error@+2 {{expected type}} !missing_type_alias = type // expected-error@+2 {{expected non-function type}}
// ----- // -----

View File

@ -182,6 +182,28 @@ func @func_with_two_args(%a : f16, %b : i8) -> (i1, i32) {
return %c#0, %c#1 : i1, i32 // CHECK: return %0#0, %0#1 : i1, i32 return %c#0, %c#1 : i1, i32 // CHECK: return %0#0, %0#1 : i1, i32
} // CHECK: } } // CHECK: }
// CHECK-LABEL: func @second_order_func() -> (() -> ()) {
func @second_order_func() -> (() -> ()) {
// CHECK-NEXT: %f = constant @emptyMLF : () -> ()
%c = constant @emptyMLF : () -> ()
// CHECK-NEXT: return %f : () -> ()
return %c : () -> ()
}
// CHECK-LABEL: func @third_order_func() -> (() -> (() -> ())) {
func @third_order_func() -> (() -> (() -> ())) {
// CHECK-NEXT: %f = constant @second_order_func : () -> (() -> ())
%c = constant @second_order_func : () -> (() -> ())
// CHECK-NEXT: return %f : () -> (() -> ())
return %c : () -> (() -> ())
}
// CHECK-LABEL: func @identity_functor(%arg0: () -> ()) -> (() -> ()) {
func @identity_functor(%a : () -> ()) -> (() -> ()) {
// CHECK-NEXT: return %arg0 : () -> ()
return %a : () -> ()
}
// CHECK-LABEL: func @func_ops_in_loop() { // CHECK-LABEL: func @func_ops_in_loop() {
func @func_ops_in_loop() { func @func_ops_in_loop() {
// CHECK: %0 = "foo"() : () -> i64 // CHECK: %0 = "foo"() : () -> i64