From 0126dcf1f0a132388286d3e979b50cc464c352fd Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 8 Aug 2019 09:41:48 -0700 Subject: [PATCH] Introduce support for variadic function signatures for the LLVM dialect LLVM function type has first-class support for variadic functions. In the current lowering pipeline, it is emulated using an attribute on functions of standard function type. In LLVMFuncOp that has LLVM function type, this can be modeled directly. Introduce parsing support for variadic arguments to the function and use it to support variadic function declarations in LLVMFuncOp. Function definitions are currently not supported as that would require modeling va_start/va_end LLVM intrinsics in the dialect and we don't yet have a consistent story for LLVM intrinsics. PiperOrigin-RevId: 262372651 --- mlir/include/mlir/IR/FunctionSupport.h | 19 +++++++---- mlir/include/mlir/IR/OpImplementation.h | 3 ++ mlir/include/mlir/LLVMIR/LLVMOps.td | 3 +- mlir/lib/IR/Function.cpp | 14 +++++--- mlir/lib/IR/FunctionSupport.cpp | 45 +++++++++++++++++++------ mlir/lib/LLVMIR/IR/LLVMDialect.cpp | 10 +++--- mlir/lib/Parser/Lexer.cpp | 16 +++++++++ mlir/lib/Parser/Lexer.h | 1 + mlir/lib/Parser/Parser.cpp | 5 +++ mlir/lib/Parser/TokenKinds.def | 1 + mlir/test/IR/invalid.mlir | 15 +++++++++ mlir/test/LLVMIR/func.mlir | 22 ++++++++++++ 12 files changed, 127 insertions(+), 27 deletions(-) diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h index 192f5dd33427..ec1200167924 100644 --- a/mlir/include/mlir/IR/FunctionSupport.h +++ b/mlir/include/mlir/IR/FunctionSupport.h @@ -55,23 +55,28 @@ inline ArrayRef getArgAttrs(Operation *op, unsigned index) { /// Callback type for `parseFunctionLikeOp`, the callback should produce the /// type that will be associated with a function-like operation from lists of -/// function arguments and results; in case of error, it may populate the last +/// function arguments and results, the boolean operand is true if the function +/// should have variadic arguments; in case of error, it may populate the last /// argument with a message. -using FuncTypeBuilder = llvm::function_ref, - ArrayRef, std::string &)>; +using FuncTypeBuilder = llvm::function_ref, ArrayRef, bool, std::string &)>; /// Parser implementation for function-like operations. Uses /// `funcTypeBuilder` to construct the custom function type given lists of -/// input and output types. If the builder returns a null type, `result` will -/// not contain the `type` attribute. The caller can then add a type, report -/// the error or delegate the reporting to the op's verifier. +/// input and output types. If `allowVariadic` is set, the parser will accept +/// trailing ellipsis in the function signature and indicate to the builder +/// whether the function is variadic. If the builder returns a null type, +/// `result` will not contain the `type` attribute. The caller can then add a +/// type, report the error or delegate the reporting to the op's verifier. ParseResult parseFunctionLikeOp(OpAsmParser *parser, OperationState *result, + bool allowVariadic, FuncTypeBuilder funcTypeBuilder); /// Printer implementation for function-like operations. Accepts lists of /// argument and result types to use while printing. void printFunctionLikeOp(OpAsmPrinter *p, Operation *op, - ArrayRef argTypes, ArrayRef results); + ArrayRef argTypes, bool isVariadic, + ArrayRef results); } // namespace impl diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index cd26653eca06..23a0cd1e5b79 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -260,6 +260,9 @@ public: /// Parse a `]` token if present. virtual ParseResult parseOptionalRSquare() = 0; + /// Parse a `...` token if present; + virtual ParseResult parseOptionalEllipsis() = 0; + //===--------------------------------------------------------------------===// // Attribute Parsing //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/LLVMIR/LLVMOps.td b/mlir/include/mlir/LLVMIR/LLVMOps.td index c0a30b1834cf..7c23330eae87 100644 --- a/mlir/include/mlir/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/LLVMIR/LLVMOps.td @@ -344,7 +344,8 @@ def LLVM_LLVMFuncOp : LLVM_ZeroResultOp<"func", let verifier = [{ return ::verify(*this); }]; let printer = [{ printLLVMFuncOp(p, *this); }]; let parser = [{ - return impl::parseFunctionLikeOp(parser, result, buildLLVMFunctionType); + return impl::parseFunctionLikeOp(parser, result, /*allowVariadic=*/true, + buildLLVMFunctionType); }]; } diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index 42bc03fc77f2..e4d1960a40d2 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -76,15 +76,19 @@ void FuncOp::build(Builder *builder, OperationState *result, StringRef name, /// Parsing/Printing methods. ParseResult FuncOp::parse(OpAsmParser *parser, OperationState *result) { - return impl::parseFunctionLikeOp( - parser, result, - [](Builder &builder, ArrayRef argTypes, ArrayRef results, - std::string &) { return builder.getFunctionType(argTypes, results); }); + auto buildFuncType = [](Builder &builder, ArrayRef argTypes, + ArrayRef results, bool, std::string &) { + return builder.getFunctionType(argTypes, results); + }; + + return impl::parseFunctionLikeOp(parser, result, /*allowVariadic=*/false, + buildFuncType); } void FuncOp::print(OpAsmPrinter *p) { FunctionType fnType = getType(); - impl::printFunctionLikeOp(p, *this, fnType.getInputs(), fnType.getResults()); + impl::printFunctionLikeOp(p, *this, fnType.getInputs(), /*isVariadic=*/false, + fnType.getResults()); } LogicalResult FuncOp::verify() { diff --git a/mlir/lib/IR/FunctionSupport.cpp b/mlir/lib/IR/FunctionSupport.cpp index 92285e4ba21f..7416e64cd913 100644 --- a/mlir/lib/IR/FunctionSupport.cpp +++ b/mlir/lib/IR/FunctionSupport.cpp @@ -22,9 +22,11 @@ using namespace mlir; static ParseResult -parseArgumentList(OpAsmParser *parser, SmallVectorImpl &argTypes, +parseArgumentList(OpAsmParser *parser, bool allowVariadic, + SmallVectorImpl &argTypes, SmallVectorImpl &argNames, - SmallVectorImpl> &argAttrs) { + SmallVectorImpl> &argAttrs, + bool &isVariadic) { if (parser->parseLParen()) return failure(); @@ -47,6 +49,9 @@ parseArgumentList(OpAsmParser *parser, SmallVectorImpl &argTypes, if (parser->parseColonType(argumentType)) return failure(); + } else if (allowVariadic && succeeded(parser->parseOptionalEllipsis())) { + isVariadic = true; + return success(); } else if (!argNames.empty()) { // Reject this if the preceding argument had a name. return parser->emitError(loc, "expected SSA identifier"); @@ -68,8 +73,15 @@ parseArgumentList(OpAsmParser *parser, SmallVectorImpl &argTypes, // Parse the function arguments. if (parser->parseOptionalRParen()) { do { + unsigned numTypedArguments = argTypes.size(); if (parseArgument()) return failure(); + + llvm::SMLoc loc = parser->getCurrentLocation(); + if (argTypes.size() == numTypedArguments && + succeeded(parser->parseOptionalComma())) + return parser->emitError( + loc, "variadic arguments must be in the end of the argument list"); } while (succeeded(parser->parseOptionalComma())); parser->parseRParen(); } @@ -80,11 +92,13 @@ parseArgumentList(OpAsmParser *parser, SmallVectorImpl &argTypes, /// Parse a function signature, starting with a name and including the /// parameter list. static ParseResult parseFunctionSignature( - OpAsmParser *parser, SmallVectorImpl &argNames, + OpAsmParser *parser, bool allowVariadic, + SmallVectorImpl &argNames, SmallVectorImpl &argTypes, - SmallVectorImpl> &argAttrs, + SmallVectorImpl> &argAttrs, bool &isVariadic, SmallVectorImpl &results) { - if (parseArgumentList(parser, argTypes, argNames, argAttrs)) + if (parseArgumentList(parser, allowVariadic, argTypes, argNames, argAttrs, + isVariadic)) return failure(); // Parse the return types if present. return parser->parseOptionalArrowTypeList(results); @@ -94,6 +108,7 @@ static ParseResult parseFunctionSignature( /// to construct the custom function type given lists of input and output types. ParseResult mlir::impl::parseFunctionLikeOp(OpAsmParser *parser, OperationState *result, + bool allowVariadic, mlir::impl::FuncTypeBuilder funcTypeBuilder) { SmallVector entryArgs; SmallVector, 4> argAttrs; @@ -111,11 +126,14 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser *parser, OperationState *result, // Parse the function signature. auto signatureLocation = parser->getCurrentLocation(); - if (parseFunctionSignature(parser, entryArgs, argTypes, argAttrs, results)) + bool isVariadic = false; + if (parseFunctionSignature(parser, allowVariadic, entryArgs, argTypes, + argAttrs, isVariadic, results)) return failure(); std::string errorMessage; - if (auto type = funcTypeBuilder(builder, argTypes, results, errorMessage)) + if (auto type = + funcTypeBuilder(builder, argTypes, results, isVariadic, errorMessage)) result->addAttribute(getTypeAttrName(), builder.getTypeAttr(type)); else return parser->emitError(signatureLocation) @@ -147,7 +165,8 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser *parser, OperationState *result, /// Print the signature of the function-like operation `op`. Assumes `op` has /// the FunctionLike trait and passed the verification. static void printSignature(OpAsmPrinter *p, Operation *op, - ArrayRef argTypes, ArrayRef results) { + ArrayRef argTypes, bool isVariadic, + ArrayRef results) { Region &body = op->getRegion(0); bool isExternal = body.empty(); @@ -165,6 +184,12 @@ static void printSignature(OpAsmPrinter *p, Operation *op, p->printOptionalAttrDict(::mlir::impl::getArgAttrs(op, i)); } + if (isVariadic) { + if (!argTypes.empty()) + *p << ", "; + *p << "..."; + } + *p << ')'; p->printOptionalArrowTypeList(results); } @@ -172,7 +197,7 @@ static void printSignature(OpAsmPrinter *p, Operation *op, /// Printer implementation for function-like operations. Accepts lists of /// argument and result types to use while printing. void mlir::impl::printFunctionLikeOp(OpAsmPrinter *p, Operation *op, - ArrayRef argTypes, + ArrayRef argTypes, bool isVariadic, ArrayRef results) { // Print the operation and the function name. auto funcName = @@ -181,7 +206,7 @@ void mlir::impl::printFunctionLikeOp(OpAsmPrinter *p, Operation *op, *p << op->getName() << " @" << funcName; // Print the signature. - printSignature(p, op, argTypes, results); + printSignature(p, op, argTypes, isVariadic, results); // Print out function attributes, if present. SmallVector ignoredAttrs = { diff --git a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp index 1315fdd6bd22..9c4933830bd1 100644 --- a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp @@ -730,7 +730,7 @@ void LLVMFuncOp::build(Builder *builder, OperationState *result, StringRef name, // Returns a null type if any of the types provided are non-LLVM types, or if // there is more than one output type. static Type buildLLVMFunctionType(Builder &b, ArrayRef inputs, - ArrayRef outputs, + ArrayRef outputs, bool isVariadic, std::string &errorMessage) { if (outputs.size() > 1) { errorMessage = "expected zero or one function result"; @@ -761,8 +761,7 @@ static Type buildLLVMFunctionType(Builder &b, ArrayRef inputs, errorMessage = "expected LLVM type for function results"; return {}; } - return LLVMType::getFunctionTy(llvmOutput, llvmInputs, - /*isVarArg=*/false); + return LLVMType::getFunctionTy(llvmOutput, llvmInputs, isVariadic); } // Print the LLVMFuncOp. Collects argument and result types and passes them @@ -779,7 +778,7 @@ static void printLLVMFuncOp(OpAsmPrinter *p, LLVMFuncOp op) { if (!returnType.getUnderlyingType()->isVoidTy()) resTypes.push_back(returnType); - impl::printFunctionLikeOp(p, op, argTypes, resTypes); + impl::printFunctionLikeOp(p, op, argTypes, op.isVarArg(), resTypes); } // Hook for OpTrait::FunctionLike, called after verifying that the 'type' @@ -804,6 +803,9 @@ static LogicalResult verify(LLVMFuncOp op) { if (op.isExternal()) return success(); + if (op.isVarArg()) + return op.emitOpError("only external functions can be variadic"); + auto *funcType = cast(op.getType().getUnderlyingType()); unsigned numArguments = funcType->getNumParams(); Block &entryBlock = op.front(); diff --git a/mlir/lib/Parser/Lexer.cpp b/mlir/lib/Parser/Lexer.cpp index 29d093d3af53..4d95e4b343d7 100644 --- a/mlir/lib/Parser/Lexer.cpp +++ b/mlir/lib/Parser/Lexer.cpp @@ -106,6 +106,8 @@ Token Lexer::lexToken() { return formToken(Token::colon, tokStart); case ',': return formToken(Token::comma, tokStart); + case '.': + return lexEllipsis(tokStart); case '(': return formToken(Token::l_paren, tokStart); case ')': @@ -382,3 +384,17 @@ Token Lexer::lexString(const char *tokStart) { } } } + +/// Lex an ellipsis. +/// +/// ellipsis ::= '...' +/// +Token Lexer::lexEllipsis(const char *tokStart) { + assert(curPtr[-1] == '.'); + + if (curPtr == curBuffer.end() || *curPtr != '.' || *(curPtr + 1) != '.') + return emitError(curPtr, "expected three consecutive dots for an ellipsis"); + + curPtr += 2; + return formToken(Token::ellipsis, tokStart); +} diff --git a/mlir/lib/Parser/Lexer.h b/mlir/lib/Parser/Lexer.h index 7b731be54967..0db81827456f 100644 --- a/mlir/lib/Parser/Lexer.h +++ b/mlir/lib/Parser/Lexer.h @@ -60,6 +60,7 @@ private: Token lexPrefixedIdentifier(const char *tokStart); Token lexNumber(const char *tokStart); Token lexString(const char *tokStart); + Token lexEllipsis(const char *tokStart); const llvm::SourceMgr &sourceMgr; MLIRContext *context; diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index e4628cc6d3ed..5e722ad649b7 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -3319,6 +3319,11 @@ public: return success(parser.consumeIf(Token::r_square)); } + /// Parses a `...` if present. + ParseResult parseOptionalEllipsis() override { + return success(parser.consumeIf(Token::ellipsis)); + } + //===--------------------------------------------------------------------===// // Attribute Parsing //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def index 18067a8e77dc..32e9b1209380 100644 --- a/mlir/lib/Parser/TokenKinds.def +++ b/mlir/lib/Parser/TokenKinds.def @@ -78,6 +78,7 @@ TOK_PUNCTUATION(r_square, "]") TOK_PUNCTUATION(less, "<") TOK_PUNCTUATION(greater, ">") TOK_PUNCTUATION(equal, "=") +TOK_PUNCTUATION(ellipsis, "...") // TODO: More punctuation. // Operators. diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index da8b0911b593..2f4698b085bc 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -443,6 +443,21 @@ func @return_inside_loop() { // ----- +// expected-error@+1 {{expected three consecutive dots for an ellipsis}} +func @malformed_ellipsis_one(.) + +// ----- + +// expected-error@+1 {{expected three consecutive dots for an ellipsis}} +func @malformed_ellipsis_two(..) + +// ----- + +// expected-error@+1 {{expected non-function type}} +func @func_variadic(...) + +// ----- + func @redef() // expected-note {{see existing symbol definition here}} func @redef() // expected-error {{redefinition of symbol named 'redef'}} diff --git a/mlir/test/LLVMIR/func.mlir b/mlir/test/LLVMIR/func.mlir index f0dc3f5d7074..348ad7ae1b57 100644 --- a/mlir/test/LLVMIR/func.mlir +++ b/mlir/test/LLVMIR/func.mlir @@ -86,6 +86,12 @@ module { attributes {foo = 42 : i32} { llvm.return } + + // CHECK: llvm.func @variadic(...) + llvm.func @variadic(...) + + // CHECK: llvm.func @variadic_args(!llvm.i32, !llvm.i32, ...) + llvm.func @variadic_args(!llvm.i32, !llvm.i32, ...) } // ----- @@ -166,3 +172,19 @@ module { // expected-error@+1 {{failed to construct function type: expected zero or one function result}} llvm.func @foo() -> (!llvm.i64, !llvm.i64) } + +// ----- + +module { + // expected-error@+1 {{only external functions can be variadic}} + llvm.func @variadic_def(...) { + llvm.return + } +} + +// ----- + +module { + // expected-error@+1 {{variadic arguments must be in the end of the argument list}} + llvm.func @variadic_inside(%arg0: !llvm.i32, ..., %arg1: !llvm.i32) +}