forked from OSchip/llvm-project
Print parens around the return type of a function if it is also a function type
Existing type syntax contains the following productions: function-type ::= type-list-parens `->` type-list type-list ::= type | type-list-parens type ::= <..> | function-type Due to these rules, when the parser sees `->` followed by `(`, it cannot disambiguate if `(` starts a parenthesized list of function result types, or a parenthesized list of operands of another function type, returned from the current function. We would need an unknown amount of lookahead to try to find the `->` at the right level of function nesting to differentiate between type lists and singular function types. Instead, require the result type of the function that is a function type itself to be always parenthesized, at the syntax level. Update the spec and the parser to correspond to the production rule names used in the spec (although it would have worked without modifications). Fix the function type parsing bug in the process, as it used to accept the non-parenthesized list of types for arguments, disallowed by the spec. PiperOrigin-RevId: 232528361
This commit is contained in:
parent
1b1f293a5d
commit
40d5d09f9d
|
@ -504,18 +504,18 @@ have application-specific semantics. For example, MLIR supports a set of
|
||||||
[dialect-specific types](#dialect-specific-types).
|
[dialect-specific types](#dialect-specific-types).
|
||||||
|
|
||||||
``` {.ebnf}
|
``` {.ebnf}
|
||||||
type ::= integer-type
|
type ::= non-function-type
|
||||||
| index-type
|
|
||||||
| float-type
|
|
||||||
| vector-type
|
|
||||||
| tensor-type
|
|
||||||
| memref-type
|
|
||||||
| function-type
|
| function-type
|
||||||
| dialect-type
|
|
||||||
| type-alias
|
|
||||||
|
|
||||||
// MLIR doesn't have a tuple type but functions can return multiple values.
|
non-function-type ::= integer-type
|
||||||
type-list ::= type-list-parens | type
|
| index-type
|
||||||
|
| float-type
|
||||||
|
| vector-type
|
||||||
|
| tensor-type
|
||||||
|
| memref-type
|
||||||
|
| dialect-type
|
||||||
|
| type-alias
|
||||||
|
|
||||||
type-list-no-parens ::= type (`,` type)*
|
type-list-no-parens ::= type (`,` type)*
|
||||||
type-list-parens ::= `(` `)`
|
type-list-parens ::= `(` `)`
|
||||||
| `(` type-list-no-parens `)`
|
| `(` type-list-no-parens `)`
|
||||||
|
@ -559,7 +559,11 @@ Builtin types consist of only the types needed for the validity of the IR.
|
||||||
Syntax:
|
Syntax:
|
||||||
|
|
||||||
``` {.ebnf}
|
``` {.ebnf}
|
||||||
function-type ::= type-list-parens `->` type-list
|
// MLIR doesn't have a tuple type but functions can return multiple values.
|
||||||
|
function-result-type ::= type-list-parens
|
||||||
|
| non-function-type
|
||||||
|
|
||||||
|
function-type ::= type-list-parens `->` function-result-type
|
||||||
```
|
```
|
||||||
|
|
||||||
MLIR supports first-class functions: the
|
MLIR supports first-class functions: the
|
||||||
|
@ -897,7 +901,7 @@ associated attributes according to the following grammar:
|
||||||
``` {.ebnf}
|
``` {.ebnf}
|
||||||
function ::= `func` function-signature function-attributes? function-body?
|
function ::= `func` function-signature function-attributes? function-body?
|
||||||
|
|
||||||
function-signature ::= function-id `(` argument-list `)` (`->` type-list)?
|
function-signature ::= function-id `(` argument-list `)` (`->` function-result-type)?
|
||||||
argument-list ::= named-argument (`,` named-argument)* | /*empty*/
|
argument-list ::= named-argument (`,` named-argument)* | /*empty*/
|
||||||
argument-list ::= type (`,` type)* | /*empty*/ named-argument ::= ssa-id `:`
|
argument-list ::= type (`,` type)* | /*empty*/ named-argument ::= ssa-id `:`
|
||||||
type
|
type
|
||||||
|
|
|
@ -746,7 +746,7 @@ void ModulePrinter::printType(Type type) {
|
||||||
interleaveComma(func.getInputs(), [&](Type type) { printType(type); });
|
interleaveComma(func.getInputs(), [&](Type type) { printType(type); });
|
||||||
os << ") -> ";
|
os << ") -> ";
|
||||||
auto results = func.getResults();
|
auto results = func.getResults();
|
||||||
if (results.size() == 1)
|
if (results.size() == 1 && !results[0].isa<FunctionType>())
|
||||||
os << results[0];
|
os << results[0];
|
||||||
else {
|
else {
|
||||||
os << '(';
|
os << '(';
|
||||||
|
@ -1313,10 +1313,17 @@ void FunctionPrinter::printFunctionSignature() {
|
||||||
switch (fnType.getResults().size()) {
|
switch (fnType.getResults().size()) {
|
||||||
case 0:
|
case 0:
|
||||||
break;
|
break;
|
||||||
case 1:
|
case 1: {
|
||||||
os << " -> ";
|
os << " -> ";
|
||||||
printType(fnType.getResults()[0]);
|
auto resultType = fnType.getResults()[0];
|
||||||
|
bool resultIsFunc = resultType.isa<FunctionType>();
|
||||||
|
if (resultIsFunc)
|
||||||
|
os << '(';
|
||||||
|
printType(resultType);
|
||||||
|
if (resultIsFunc)
|
||||||
|
os << ')';
|
||||||
break;
|
break;
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
os << " -> (";
|
os << " -> (";
|
||||||
interleaveComma(fnType.getResults(),
|
interleaveComma(fnType.getResults(),
|
||||||
|
@ -1482,7 +1489,8 @@ void FunctionPrinter::printGenericOp(const Instruction *op) {
|
||||||
[&](const Value *value) { printType(value->getType()); });
|
[&](const Value *value) { printType(value->getType()); });
|
||||||
os << ") -> ";
|
os << ") -> ";
|
||||||
|
|
||||||
if (op->getNumResults() == 1) {
|
if (op->getNumResults() == 1 &&
|
||||||
|
!op->getResult(0)->getType().isa<FunctionType>()) {
|
||||||
printType(op->getResult(0)->getType());
|
printType(op->getResult(0)->getType());
|
||||||
} else {
|
} else {
|
||||||
os << '(';
|
os << '(';
|
||||||
|
|
|
@ -187,9 +187,11 @@ public:
|
||||||
Type parseTensorType();
|
Type parseTensorType();
|
||||||
Type parseMemRefType();
|
Type parseMemRefType();
|
||||||
Type parseFunctionType();
|
Type parseFunctionType();
|
||||||
|
Type parseNonFunctionType();
|
||||||
Type parseType();
|
Type parseType();
|
||||||
ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements);
|
ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements);
|
||||||
ParseResult parseTypeList(SmallVectorImpl<Type> &elements);
|
ParseResult parseTypeListParens(SmallVectorImpl<Type> &elements);
|
||||||
|
ParseResult parseFunctionResultTypes(SmallVectorImpl<Type> &elements);
|
||||||
|
|
||||||
// Attribute parsing.
|
// Attribute parsing.
|
||||||
Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
|
Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
|
||||||
|
@ -308,32 +310,29 @@ ParseResult Parser::parseCommaSeparatedListUntil(
|
||||||
// Type Parsing
|
// Type Parsing
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Parse an arbitrary type.
|
/// Parse any type except the function type.
|
||||||
///
|
///
|
||||||
/// type ::= integer-type
|
/// non-function-type ::= integer-type
|
||||||
/// | index-type
|
/// | index-type
|
||||||
/// | float-type
|
/// | float-type
|
||||||
/// | extended-type
|
/// | extended-type
|
||||||
/// | vector-type
|
/// | vector-type
|
||||||
/// | tensor-type
|
/// | tensor-type
|
||||||
/// | memref-type
|
/// | memref-type
|
||||||
/// | function-type
|
|
||||||
///
|
///
|
||||||
/// index-type ::= `index`
|
/// index-type ::= `index`
|
||||||
/// float-type ::= `f16` | `bf16` | `f32` | `f64`
|
/// float-type ::= `f16` | `bf16` | `f32` | `f64`
|
||||||
///
|
///
|
||||||
Type Parser::parseType() {
|
Type Parser::parseNonFunctionType() {
|
||||||
switch (getToken().getKind()) {
|
switch (getToken().getKind()) {
|
||||||
default:
|
default:
|
||||||
return (emitError("expected type"), nullptr);
|
return (emitError("expected non-function type"), nullptr);
|
||||||
case Token::kw_memref:
|
case Token::kw_memref:
|
||||||
return parseMemRefType();
|
return parseMemRefType();
|
||||||
case Token::kw_tensor:
|
case Token::kw_tensor:
|
||||||
return parseTensorType();
|
return parseTensorType();
|
||||||
case Token::kw_vector:
|
case Token::kw_vector:
|
||||||
return parseVectorType();
|
return parseVectorType();
|
||||||
case Token::l_paren:
|
|
||||||
return parseFunctionType();
|
|
||||||
// integer-type
|
// integer-type
|
||||||
case Token::inttype: {
|
case Token::inttype: {
|
||||||
auto width = getToken().getIntTypeBitwidth();
|
auto width = getToken().getIntTypeBitwidth();
|
||||||
|
@ -369,6 +368,17 @@ Type Parser::parseType() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Parse an arbitrary type.
|
||||||
|
///
|
||||||
|
/// type ::= function-type
|
||||||
|
/// | non-function-type
|
||||||
|
///
|
||||||
|
Type Parser::parseType() {
|
||||||
|
if (getToken().is(Token::l_paren))
|
||||||
|
return parseFunctionType();
|
||||||
|
return parseNonFunctionType();
|
||||||
|
}
|
||||||
|
|
||||||
/// Parse a vector type.
|
/// Parse a vector type.
|
||||||
///
|
///
|
||||||
/// vector-type ::= `vector` `<` const-dimension-list primitive-type `>`
|
/// vector-type ::= `vector` `<` const-dimension-list primitive-type `>`
|
||||||
|
@ -640,9 +650,9 @@ Type Parser::parseFunctionType() {
|
||||||
assert(getToken().is(Token::l_paren));
|
assert(getToken().is(Token::l_paren));
|
||||||
|
|
||||||
SmallVector<Type, 4> arguments, results;
|
SmallVector<Type, 4> arguments, results;
|
||||||
if (parseTypeList(arguments) ||
|
if (parseTypeListParens(arguments) ||
|
||||||
parseToken(Token::arrow, "expected '->' in function type") ||
|
parseToken(Token::arrow, "expected '->' in function type") ||
|
||||||
parseTypeList(results))
|
parseFunctionResultTypes(results))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
return builder.getFunctionType(arguments, results);
|
return builder.getFunctionType(arguments, results);
|
||||||
|
@ -663,27 +673,38 @@ ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) {
|
||||||
return parseCommaSeparatedList(parseElt);
|
return parseCommaSeparatedList(parseElt);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Parse a "type list", which is a singular type, or a parenthesized list of
|
/// Parse a parenthesized list of types.
|
||||||
/// types.
|
|
||||||
///
|
///
|
||||||
/// type-list ::= type-list-parens | type
|
|
||||||
/// type-list-parens ::= `(` `)`
|
/// type-list-parens ::= `(` `)`
|
||||||
/// | `(` type-list-no-parens `)`
|
/// | `(` type-list-no-parens `)`
|
||||||
///
|
///
|
||||||
ParseResult Parser::parseTypeList(SmallVectorImpl<Type> &elements) {
|
ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) {
|
||||||
auto parseElt = [&]() -> ParseResult {
|
if (parseToken(Token::l_paren, "expected '('"))
|
||||||
auto elt = parseType();
|
|
||||||
elements.push_back(elt);
|
|
||||||
return elt ? ParseSuccess : ParseFailure;
|
|
||||||
};
|
|
||||||
|
|
||||||
// If there is no parens, then it must be a singular type.
|
|
||||||
if (!consumeIf(Token::l_paren))
|
|
||||||
return parseElt();
|
|
||||||
|
|
||||||
if (parseCommaSeparatedListUntil(Token::r_paren, parseElt))
|
|
||||||
return ParseFailure;
|
return ParseFailure;
|
||||||
|
|
||||||
|
// Handle empty lists.
|
||||||
|
if (getToken().is(Token::r_paren))
|
||||||
|
return consumeToken(), ParseSuccess;
|
||||||
|
|
||||||
|
if (parseTypeListNoParens(elements) ||
|
||||||
|
parseToken(Token::r_paren, "expected ')'"))
|
||||||
|
return ParseFailure;
|
||||||
|
return ParseSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse a function result type.
|
||||||
|
///
|
||||||
|
/// function-result-type ::= type-list-parens
|
||||||
|
/// | non-function-type
|
||||||
|
///
|
||||||
|
ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) {
|
||||||
|
if (getToken().is(Token::l_paren))
|
||||||
|
return parseTypeListParens(elements);
|
||||||
|
|
||||||
|
Type t = parseNonFunctionType();
|
||||||
|
if (!t)
|
||||||
|
return ParseFailure;
|
||||||
|
elements.push_back(t);
|
||||||
return ParseSuccess;
|
return ParseSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3489,7 +3510,7 @@ ModuleParser::parseFunctionSignature(StringRef &name, FunctionType &type,
|
||||||
// Parse the return type if present.
|
// Parse the return type if present.
|
||||||
SmallVector<Type, 4> results;
|
SmallVector<Type, 4> results;
|
||||||
if (consumeIf(Token::arrow)) {
|
if (consumeIf(Token::arrow)) {
|
||||||
if (parseTypeList(results))
|
if (parseFunctionResultTypes(results))
|
||||||
return ParseFailure;
|
return ParseFailure;
|
||||||
}
|
}
|
||||||
type = builder.getFunctionType(argTypes, results);
|
type = builder.getFunctionType(argTypes, results);
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
// RUN: mlir-opt %s | FileCheck %s
|
// RUN: mlir-opt %s | FileCheck %s
|
||||||
// Verify the printed output can be parsed.
|
// Verify the printed output can be parsed.
|
||||||
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
|
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
|
||||||
// TODO(b/123888077): The following fails due to constant with function pointer.
|
// Verify the generic form can be parsed.
|
||||||
// Disabled: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
|
// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
|
||||||
|
|
||||||
// CHECK: #map0 = (d0) -> (d0 + 1)
|
// CHECK: #map0 = (d0) -> (d0 + 1)
|
||||||
|
|
||||||
|
|
|
@ -67,7 +67,7 @@ func @location_fused_missing_greater() {
|
||||||
|
|
||||||
func @location_fused_missing_metadata() {
|
func @location_fused_missing_metadata() {
|
||||||
^bb:
|
^bb:
|
||||||
// expected-error@+1 {{expected type}}
|
// expected-error@+1 {{expected non-function type}}
|
||||||
return loc(fused<) // expected-error {{expected valid attribute metadata}}
|
return loc(fused<) // expected-error {{expected valid attribute metadata}}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -209,7 +209,7 @@ func @func_with_ops(i32, i64) {
|
||||||
// Comparisons must have the "predicate" attribute.
|
// Comparisons must have the "predicate" attribute.
|
||||||
func @func_with_ops(i32, i32) {
|
func @func_with_ops(i32, i32) {
|
||||||
^bb0(%a : i32, %b : i32):
|
^bb0(%a : i32, %b : i32):
|
||||||
%r = cmpi %a, %b : i32 // expected-error {{expected type}}
|
%r = cmpi %a, %b : i32 // expected-error {{expected non-function type}}
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
|
@ -3,12 +3,12 @@
|
||||||
// Check different error cases.
|
// Check different error cases.
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @illegaltype(i) // expected-error {{expected type}}
|
func @illegaltype(i) // expected-error {{expected non-function type}}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @illegaltype() {
|
func @illegaltype() {
|
||||||
%0 = constant splat<<vector 4 x f32>, 0> : vector<4 x f32> // expected-error {{expected type}}
|
%0 = constant splat<<vector 4 x f32>, 0> : vector<4 x f32> // expected-error {{expected non-function type}}
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
@ -227,7 +227,7 @@ func @incomplete_for() {
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @nonconstant_step(%1 : i32) {
|
func @nonconstant_step(%1 : i32) {
|
||||||
for %2 = 1 to 5 step %1 { // expected-error {{expected type}}
|
for %2 = 1 to 5 step %1 { // expected-error {{expected non-function type}}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
@ -326,7 +326,7 @@ func @d() {return} // expected-error {{custom op 'func' is unknown}}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @malformed_type(%a : intt) { // expected-error {{expected type}}
|
func @malformed_type(%a : intt) { // expected-error {{expected non-function type}}
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
@ -392,7 +392,7 @@ func @condbr_badtype() {
|
||||||
^bb0:
|
^bb0:
|
||||||
%c = "foo"() : () -> i1
|
%c = "foo"() : () -> i1
|
||||||
%a = "foo"() : () -> i32
|
%a = "foo"() : () -> i32
|
||||||
cond_br %c, ^bb0(%a, %a : i32, ^bb0) // expected-error {{expected type}}
|
cond_br %c, ^bb0(%a, %a : i32, ^bb0) // expected-error {{expected non-function type}}
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
@ -506,6 +506,18 @@ func @undefined_function() {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
func @invalid_result_type() -> () -> () // expected-error {{expected a top level entity}}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @func() -> (() -> ())
|
||||||
|
func @referer() {
|
||||||
|
%f = constant @func : () -> () -> () // expected-error {{reference to function with mismatched type}}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
#map1 = (i)[j] -> (i+j)
|
#map1 = (i)[j] -> (i+j)
|
||||||
|
|
||||||
func @bound_symbol_mismatch(%N : index) {
|
func @bound_symbol_mismatch(%N : index) {
|
||||||
|
@ -538,7 +550,7 @@ func @large_bound() {
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @max_in_upper_bound(%N : index) {
|
func @max_in_upper_bound(%N : index) {
|
||||||
for %i = 1 to max (i)->(N, 100) { //expected-error {{expected type}}
|
for %i = 1 to max (i)->(N, 100) { //expected-error {{expected non-function type}}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -595,7 +607,7 @@ func @invalid_if_operands3(%N : index) {
|
||||||
// expected-error@+1 {{expected '"' in string literal}}
|
// expected-error@+1 {{expected '"' in string literal}}
|
||||||
"J// -----
|
"J// -----
|
||||||
func @calls(%arg0: i32) {
|
func @calls(%arg0: i32) {
|
||||||
// expected-error@+1 {{expected type}}
|
// expected-error@+1 {{expected non-function type}}
|
||||||
%z = "casdasda"(%x) : (ppop32) -> i32
|
%z = "casdasda"(%x) : (ppop32) -> i32
|
||||||
}
|
}
|
||||||
// -----
|
// -----
|
||||||
|
@ -767,7 +779,7 @@ func @type_alias_unknown(!unknown_alias) -> () { // expected-error {{undefined t
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
!missing_type_alias = type // expected-error@+2 {{expected type}}
|
!missing_type_alias = type // expected-error@+2 {{expected non-function type}}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
|
|
@ -182,6 +182,28 @@ func @func_with_two_args(%a : f16, %b : i8) -> (i1, i32) {
|
||||||
return %c#0, %c#1 : i1, i32 // CHECK: return %0#0, %0#1 : i1, i32
|
return %c#0, %c#1 : i1, i32 // CHECK: return %0#0, %0#1 : i1, i32
|
||||||
} // CHECK: }
|
} // CHECK: }
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @second_order_func() -> (() -> ()) {
|
||||||
|
func @second_order_func() -> (() -> ()) {
|
||||||
|
// CHECK-NEXT: %f = constant @emptyMLF : () -> ()
|
||||||
|
%c = constant @emptyMLF : () -> ()
|
||||||
|
// CHECK-NEXT: return %f : () -> ()
|
||||||
|
return %c : () -> ()
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @third_order_func() -> (() -> (() -> ())) {
|
||||||
|
func @third_order_func() -> (() -> (() -> ())) {
|
||||||
|
// CHECK-NEXT: %f = constant @second_order_func : () -> (() -> ())
|
||||||
|
%c = constant @second_order_func : () -> (() -> ())
|
||||||
|
// CHECK-NEXT: return %f : () -> (() -> ())
|
||||||
|
return %c : () -> (() -> ())
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @identity_functor(%arg0: () -> ()) -> (() -> ()) {
|
||||||
|
func @identity_functor(%a : () -> ()) -> (() -> ()) {
|
||||||
|
// CHECK-NEXT: return %arg0 : () -> ()
|
||||||
|
return %a : () -> ()
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @func_ops_in_loop() {
|
// CHECK-LABEL: func @func_ops_in_loop() {
|
||||||
func @func_ops_in_loop() {
|
func @func_ops_in_loop() {
|
||||||
// CHECK: %0 = "foo"() : () -> i64
|
// CHECK: %0 = "foo"() : () -> i64
|
||||||
|
|
Loading…
Reference in New Issue