forked from OSchip/llvm-project
Print parens around the return type of a function if it is also a function type
Existing type syntax contains the following productions: function-type ::= type-list-parens `->` type-list type-list ::= type | type-list-parens type ::= <..> | function-type Due to these rules, when the parser sees `->` followed by `(`, it cannot disambiguate if `(` starts a parenthesized list of function result types, or a parenthesized list of operands of another function type, returned from the current function. We would need an unknown amount of lookahead to try to find the `->` at the right level of function nesting to differentiate between type lists and singular function types. Instead, require the result type of the function that is a function type itself to be always parenthesized, at the syntax level. Update the spec and the parser to correspond to the production rule names used in the spec (although it would have worked without modifications). Fix the function type parsing bug in the process, as it used to accept the non-parenthesized list of types for arguments, disallowed by the spec. PiperOrigin-RevId: 232528361
This commit is contained in:
parent
1b1f293a5d
commit
40d5d09f9d
|
@ -504,18 +504,18 @@ have application-specific semantics. For example, MLIR supports a set of
|
|||
[dialect-specific types](#dialect-specific-types).
|
||||
|
||||
``` {.ebnf}
|
||||
type ::= integer-type
|
||||
type ::= non-function-type
|
||||
| function-type
|
||||
|
||||
non-function-type ::= integer-type
|
||||
| index-type
|
||||
| float-type
|
||||
| vector-type
|
||||
| tensor-type
|
||||
| memref-type
|
||||
| function-type
|
||||
| dialect-type
|
||||
| type-alias
|
||||
|
||||
// MLIR doesn't have a tuple type but functions can return multiple values.
|
||||
type-list ::= type-list-parens | type
|
||||
type-list-no-parens ::= type (`,` type)*
|
||||
type-list-parens ::= `(` `)`
|
||||
| `(` type-list-no-parens `)`
|
||||
|
@ -559,7 +559,11 @@ Builtin types consist of only the types needed for the validity of the IR.
|
|||
Syntax:
|
||||
|
||||
``` {.ebnf}
|
||||
function-type ::= type-list-parens `->` type-list
|
||||
// MLIR doesn't have a tuple type but functions can return multiple values.
|
||||
function-result-type ::= type-list-parens
|
||||
| non-function-type
|
||||
|
||||
function-type ::= type-list-parens `->` function-result-type
|
||||
```
|
||||
|
||||
MLIR supports first-class functions: the
|
||||
|
@ -897,7 +901,7 @@ associated attributes according to the following grammar:
|
|||
``` {.ebnf}
|
||||
function ::= `func` function-signature function-attributes? function-body?
|
||||
|
||||
function-signature ::= function-id `(` argument-list `)` (`->` type-list)?
|
||||
function-signature ::= function-id `(` argument-list `)` (`->` function-result-type)?
|
||||
argument-list ::= named-argument (`,` named-argument)* | /*empty*/
|
||||
argument-list ::= type (`,` type)* | /*empty*/ named-argument ::= ssa-id `:`
|
||||
type
|
||||
|
|
|
@ -746,7 +746,7 @@ void ModulePrinter::printType(Type type) {
|
|||
interleaveComma(func.getInputs(), [&](Type type) { printType(type); });
|
||||
os << ") -> ";
|
||||
auto results = func.getResults();
|
||||
if (results.size() == 1)
|
||||
if (results.size() == 1 && !results[0].isa<FunctionType>())
|
||||
os << results[0];
|
||||
else {
|
||||
os << '(';
|
||||
|
@ -1313,10 +1313,17 @@ void FunctionPrinter::printFunctionSignature() {
|
|||
switch (fnType.getResults().size()) {
|
||||
case 0:
|
||||
break;
|
||||
case 1:
|
||||
case 1: {
|
||||
os << " -> ";
|
||||
printType(fnType.getResults()[0]);
|
||||
auto resultType = fnType.getResults()[0];
|
||||
bool resultIsFunc = resultType.isa<FunctionType>();
|
||||
if (resultIsFunc)
|
||||
os << '(';
|
||||
printType(resultType);
|
||||
if (resultIsFunc)
|
||||
os << ')';
|
||||
break;
|
||||
}
|
||||
default:
|
||||
os << " -> (";
|
||||
interleaveComma(fnType.getResults(),
|
||||
|
@ -1482,7 +1489,8 @@ void FunctionPrinter::printGenericOp(const Instruction *op) {
|
|||
[&](const Value *value) { printType(value->getType()); });
|
||||
os << ") -> ";
|
||||
|
||||
if (op->getNumResults() == 1) {
|
||||
if (op->getNumResults() == 1 &&
|
||||
!op->getResult(0)->getType().isa<FunctionType>()) {
|
||||
printType(op->getResult(0)->getType());
|
||||
} else {
|
||||
os << '(';
|
||||
|
|
|
@ -187,9 +187,11 @@ public:
|
|||
Type parseTensorType();
|
||||
Type parseMemRefType();
|
||||
Type parseFunctionType();
|
||||
Type parseNonFunctionType();
|
||||
Type parseType();
|
||||
ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements);
|
||||
ParseResult parseTypeList(SmallVectorImpl<Type> &elements);
|
||||
ParseResult parseTypeListParens(SmallVectorImpl<Type> &elements);
|
||||
ParseResult parseFunctionResultTypes(SmallVectorImpl<Type> &elements);
|
||||
|
||||
// Attribute parsing.
|
||||
Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
|
||||
|
@ -308,32 +310,29 @@ ParseResult Parser::parseCommaSeparatedListUntil(
|
|||
// Type Parsing
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Parse an arbitrary type.
|
||||
/// Parse any type except the function type.
|
||||
///
|
||||
/// type ::= integer-type
|
||||
/// non-function-type ::= integer-type
|
||||
/// | index-type
|
||||
/// | float-type
|
||||
/// | extended-type
|
||||
/// | vector-type
|
||||
/// | tensor-type
|
||||
/// | memref-type
|
||||
/// | function-type
|
||||
///
|
||||
/// index-type ::= `index`
|
||||
/// float-type ::= `f16` | `bf16` | `f32` | `f64`
|
||||
///
|
||||
Type Parser::parseType() {
|
||||
Type Parser::parseNonFunctionType() {
|
||||
switch (getToken().getKind()) {
|
||||
default:
|
||||
return (emitError("expected type"), nullptr);
|
||||
return (emitError("expected non-function type"), nullptr);
|
||||
case Token::kw_memref:
|
||||
return parseMemRefType();
|
||||
case Token::kw_tensor:
|
||||
return parseTensorType();
|
||||
case Token::kw_vector:
|
||||
return parseVectorType();
|
||||
case Token::l_paren:
|
||||
return parseFunctionType();
|
||||
// integer-type
|
||||
case Token::inttype: {
|
||||
auto width = getToken().getIntTypeBitwidth();
|
||||
|
@ -369,6 +368,17 @@ Type Parser::parseType() {
|
|||
}
|
||||
}
|
||||
|
||||
/// Parse an arbitrary type.
|
||||
///
|
||||
/// type ::= function-type
|
||||
/// | non-function-type
|
||||
///
|
||||
Type Parser::parseType() {
|
||||
if (getToken().is(Token::l_paren))
|
||||
return parseFunctionType();
|
||||
return parseNonFunctionType();
|
||||
}
|
||||
|
||||
/// Parse a vector type.
|
||||
///
|
||||
/// vector-type ::= `vector` `<` const-dimension-list primitive-type `>`
|
||||
|
@ -640,9 +650,9 @@ Type Parser::parseFunctionType() {
|
|||
assert(getToken().is(Token::l_paren));
|
||||
|
||||
SmallVector<Type, 4> arguments, results;
|
||||
if (parseTypeList(arguments) ||
|
||||
if (parseTypeListParens(arguments) ||
|
||||
parseToken(Token::arrow, "expected '->' in function type") ||
|
||||
parseTypeList(results))
|
||||
parseFunctionResultTypes(results))
|
||||
return nullptr;
|
||||
|
||||
return builder.getFunctionType(arguments, results);
|
||||
|
@ -663,27 +673,38 @@ ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) {
|
|||
return parseCommaSeparatedList(parseElt);
|
||||
}
|
||||
|
||||
/// Parse a "type list", which is a singular type, or a parenthesized list of
|
||||
/// types.
|
||||
/// Parse a parenthesized list of types.
|
||||
///
|
||||
/// type-list ::= type-list-parens | type
|
||||
/// type-list-parens ::= `(` `)`
|
||||
/// | `(` type-list-no-parens `)`
|
||||
///
|
||||
ParseResult Parser::parseTypeList(SmallVectorImpl<Type> &elements) {
|
||||
auto parseElt = [&]() -> ParseResult {
|
||||
auto elt = parseType();
|
||||
elements.push_back(elt);
|
||||
return elt ? ParseSuccess : ParseFailure;
|
||||
};
|
||||
|
||||
// If there is no parens, then it must be a singular type.
|
||||
if (!consumeIf(Token::l_paren))
|
||||
return parseElt();
|
||||
|
||||
if (parseCommaSeparatedListUntil(Token::r_paren, parseElt))
|
||||
ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) {
|
||||
if (parseToken(Token::l_paren, "expected '('"))
|
||||
return ParseFailure;
|
||||
|
||||
// Handle empty lists.
|
||||
if (getToken().is(Token::r_paren))
|
||||
return consumeToken(), ParseSuccess;
|
||||
|
||||
if (parseTypeListNoParens(elements) ||
|
||||
parseToken(Token::r_paren, "expected ')'"))
|
||||
return ParseFailure;
|
||||
return ParseSuccess;
|
||||
}
|
||||
|
||||
/// Parse a function result type.
|
||||
///
|
||||
/// function-result-type ::= type-list-parens
|
||||
/// | non-function-type
|
||||
///
|
||||
ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) {
|
||||
if (getToken().is(Token::l_paren))
|
||||
return parseTypeListParens(elements);
|
||||
|
||||
Type t = parseNonFunctionType();
|
||||
if (!t)
|
||||
return ParseFailure;
|
||||
elements.push_back(t);
|
||||
return ParseSuccess;
|
||||
}
|
||||
|
||||
|
@ -3489,7 +3510,7 @@ ModuleParser::parseFunctionSignature(StringRef &name, FunctionType &type,
|
|||
// Parse the return type if present.
|
||||
SmallVector<Type, 4> results;
|
||||
if (consumeIf(Token::arrow)) {
|
||||
if (parseTypeList(results))
|
||||
if (parseFunctionResultTypes(results))
|
||||
return ParseFailure;
|
||||
}
|
||||
type = builder.getFunctionType(argTypes, results);
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
// RUN: mlir-opt %s | FileCheck %s
|
||||
// Verify the printed output can be parsed.
|
||||
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
|
||||
// TODO(b/123888077): The following fails due to constant with function pointer.
|
||||
// Disabled: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
|
||||
// Verify the generic form can be parsed.
|
||||
// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
|
||||
|
||||
// CHECK: #map0 = (d0) -> (d0 + 1)
|
||||
|
||||
|
|
|
@ -67,7 +67,7 @@ func @location_fused_missing_greater() {
|
|||
|
||||
func @location_fused_missing_metadata() {
|
||||
^bb:
|
||||
// expected-error@+1 {{expected type}}
|
||||
// expected-error@+1 {{expected non-function type}}
|
||||
return loc(fused<) // expected-error {{expected valid attribute metadata}}
|
||||
}
|
||||
|
||||
|
|
|
@ -209,7 +209,7 @@ func @func_with_ops(i32, i64) {
|
|||
// Comparisons must have the "predicate" attribute.
|
||||
func @func_with_ops(i32, i32) {
|
||||
^bb0(%a : i32, %b : i32):
|
||||
%r = cmpi %a, %b : i32 // expected-error {{expected type}}
|
||||
%r = cmpi %a, %b : i32 // expected-error {{expected non-function type}}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
|
|
@ -3,12 +3,12 @@
|
|||
// Check different error cases.
|
||||
// -----
|
||||
|
||||
func @illegaltype(i) // expected-error {{expected type}}
|
||||
func @illegaltype(i) // expected-error {{expected non-function type}}
|
||||
|
||||
// -----
|
||||
|
||||
func @illegaltype() {
|
||||
%0 = constant splat<<vector 4 x f32>, 0> : vector<4 x f32> // expected-error {{expected type}}
|
||||
%0 = constant splat<<vector 4 x f32>, 0> : vector<4 x f32> // expected-error {{expected non-function type}}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -227,7 +227,7 @@ func @incomplete_for() {
|
|||
// -----
|
||||
|
||||
func @nonconstant_step(%1 : i32) {
|
||||
for %2 = 1 to 5 step %1 { // expected-error {{expected type}}
|
||||
for %2 = 1 to 5 step %1 { // expected-error {{expected non-function type}}
|
||||
|
||||
// -----
|
||||
|
||||
|
@ -326,7 +326,7 @@ func @d() {return} // expected-error {{custom op 'func' is unknown}}
|
|||
|
||||
// -----
|
||||
|
||||
func @malformed_type(%a : intt) { // expected-error {{expected type}}
|
||||
func @malformed_type(%a : intt) { // expected-error {{expected non-function type}}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -392,7 +392,7 @@ func @condbr_badtype() {
|
|||
^bb0:
|
||||
%c = "foo"() : () -> i1
|
||||
%a = "foo"() : () -> i32
|
||||
cond_br %c, ^bb0(%a, %a : i32, ^bb0) // expected-error {{expected type}}
|
||||
cond_br %c, ^bb0(%a, %a : i32, ^bb0) // expected-error {{expected non-function type}}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -506,6 +506,18 @@ func @undefined_function() {
|
|||
|
||||
// -----
|
||||
|
||||
func @invalid_result_type() -> () -> () // expected-error {{expected a top level entity}}
|
||||
|
||||
// -----
|
||||
|
||||
func @func() -> (() -> ())
|
||||
func @referer() {
|
||||
%f = constant @func : () -> () -> () // expected-error {{reference to function with mismatched type}}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#map1 = (i)[j] -> (i+j)
|
||||
|
||||
func @bound_symbol_mismatch(%N : index) {
|
||||
|
@ -538,7 +550,7 @@ func @large_bound() {
|
|||
// -----
|
||||
|
||||
func @max_in_upper_bound(%N : index) {
|
||||
for %i = 1 to max (i)->(N, 100) { //expected-error {{expected type}}
|
||||
for %i = 1 to max (i)->(N, 100) { //expected-error {{expected non-function type}}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -595,7 +607,7 @@ func @invalid_if_operands3(%N : index) {
|
|||
// expected-error@+1 {{expected '"' in string literal}}
|
||||
"J// -----
|
||||
func @calls(%arg0: i32) {
|
||||
// expected-error@+1 {{expected type}}
|
||||
// expected-error@+1 {{expected non-function type}}
|
||||
%z = "casdasda"(%x) : (ppop32) -> i32
|
||||
}
|
||||
// -----
|
||||
|
@ -767,7 +779,7 @@ func @type_alias_unknown(!unknown_alias) -> () { // expected-error {{undefined t
|
|||
|
||||
// -----
|
||||
|
||||
!missing_type_alias = type // expected-error@+2 {{expected type}}
|
||||
!missing_type_alias = type // expected-error@+2 {{expected non-function type}}
|
||||
|
||||
// -----
|
||||
|
||||
|
|
|
@ -182,6 +182,28 @@ func @func_with_two_args(%a : f16, %b : i8) -> (i1, i32) {
|
|||
return %c#0, %c#1 : i1, i32 // CHECK: return %0#0, %0#1 : i1, i32
|
||||
} // CHECK: }
|
||||
|
||||
// CHECK-LABEL: func @second_order_func() -> (() -> ()) {
|
||||
func @second_order_func() -> (() -> ()) {
|
||||
// CHECK-NEXT: %f = constant @emptyMLF : () -> ()
|
||||
%c = constant @emptyMLF : () -> ()
|
||||
// CHECK-NEXT: return %f : () -> ()
|
||||
return %c : () -> ()
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @third_order_func() -> (() -> (() -> ())) {
|
||||
func @third_order_func() -> (() -> (() -> ())) {
|
||||
// CHECK-NEXT: %f = constant @second_order_func : () -> (() -> ())
|
||||
%c = constant @second_order_func : () -> (() -> ())
|
||||
// CHECK-NEXT: return %f : () -> (() -> ())
|
||||
return %c : () -> (() -> ())
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @identity_functor(%arg0: () -> ()) -> (() -> ()) {
|
||||
func @identity_functor(%a : () -> ()) -> (() -> ()) {
|
||||
// CHECK-NEXT: return %arg0 : () -> ()
|
||||
return %a : () -> ()
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @func_ops_in_loop() {
|
||||
func @func_ops_in_loop() {
|
||||
// CHECK: %0 = "foo"() : () -> i64
|
||||
|
|
Loading…
Reference in New Issue