[mlir:OpAsm] Factor out the common bits of (Op/Dialect)Asm(Parser/Printer)

This has a few benefits:
* It allows for defining parsers/printer code blocks that
  can be shared between operations and attribute/types.
* It removes the weird duplication of generic parser/printer hooks,
  which means that newly added hooks only require touching one class.

Differential Revision: https://reviews.llvm.org/D110375
This commit is contained in:
River Riddle 2021-09-24 19:56:01 +00:00
parent 62cc6b0da2
commit 531206310a
8 changed files with 1014 additions and 1378 deletions

View File

@ -15,14 +15,9 @@
#define MLIR_IR_DIALECTIMPLEMENTATION_H
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/raw_ostream.h"
namespace mlir {
class Builder;
//===----------------------------------------------------------------------===//
// DialectAsmPrinter
//===----------------------------------------------------------------------===//
@ -30,360 +25,26 @@ class Builder;
/// This is a pure-virtual base class that exposes the asmprinter hooks
/// necessary to implement a custom printAttribute/printType() method on a
/// dialect.
class DialectAsmPrinter {
class DialectAsmPrinter : public AsmPrinter {
public:
DialectAsmPrinter() {}
virtual ~DialectAsmPrinter();
virtual raw_ostream &getStream() const = 0;
/// Print the given attribute to the stream.
virtual void printAttribute(Attribute attr) = 0;
/// Print the given attribute without its type. The corresponding parser must
/// provide a valid type for the attribute.
virtual void printAttributeWithoutType(Attribute attr) = 0;
/// Print the given floating point value in a stabilized form that can be
/// roundtripped through the IR. This is the companion to the 'parseFloat'
/// hook on the DialectAsmParser.
virtual void printFloat(const APFloat &value) = 0;
/// Print the given type to the stream.
virtual void printType(Type type) = 0;
private:
DialectAsmPrinter(const DialectAsmPrinter &) = delete;
void operator=(const DialectAsmPrinter &) = delete;
using AsmPrinter::AsmPrinter;
~DialectAsmPrinter() override;
};
// Make the implementations convenient to use.
inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, Attribute attr) {
p.printAttribute(attr);
return p;
}
inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p,
const APFloat &value) {
p.printFloat(value);
return p;
}
inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, float value) {
return p << APFloat(value);
}
inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, double value) {
return p << APFloat(value);
}
inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, Type type) {
p.printType(type);
return p;
}
// Support printing anything that isn't convertible to one of the above types,
// even if it isn't exactly one of them. For example, we want to print
// FunctionType with the Type version above, not have it match this.
template <typename T, typename std::enable_if<
!std::is_convertible<T &, Attribute &>::value &&
!std::is_convertible<T &, Type &>::value &&
!std::is_convertible<T &, APFloat &>::value &&
!llvm::is_one_of<T, double, float>::value,
T>::type * = nullptr>
inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, const T &other) {
p.getStream() << other;
return p;
}
//===----------------------------------------------------------------------===//
// DialectAsmParser
//===----------------------------------------------------------------------===//
/// The DialectAsmParser has methods for interacting with the asm parser:
/// parsing things from it, emitting errors etc. It has an intentionally
/// high-level API that is designed to reduce/constrain syntax innovation in
/// individual attributes or types.
class DialectAsmParser {
/// The DialectAsmParser has methods for interacting with the asm parser when
/// parsing attributes and types.
class DialectAsmParser : public AsmParser {
public:
virtual ~DialectAsmParser();
/// Emit a diagnostic at the specified location and return failure.
virtual InFlightDiagnostic emitError(llvm::SMLoc loc,
const Twine &message = {}) = 0;
/// Return a builder which provides useful access to MLIRContext, global
/// objects like types and attributes.
virtual Builder &getBuilder() const = 0;
/// Get the location of the next token and store it into the argument. This
/// always succeeds.
virtual llvm::SMLoc getCurrentLocation() = 0;
ParseResult getCurrentLocation(llvm::SMLoc *loc) {
*loc = getCurrentLocation();
return success();
}
/// Return the location of the original name token.
virtual llvm::SMLoc getNameLoc() const = 0;
/// Re-encode the given source location as an MLIR location and return it.
/// Note: This method should only be used when a `Location` is necessary, as
/// the encoding process is not efficient. In other cases a more suitable
/// alternative should be used, such as the `getChecked` methods defined
/// below.
virtual Location getEncodedSourceLoc(llvm::SMLoc loc) = 0;
using AsmParser::AsmParser;
~DialectAsmParser() override;
/// Returns the full specification of the symbol being parsed. This allows for
/// using a separate parser if necessary.
virtual StringRef getFullSymbolSpec() const = 0;
// These methods emit an error and return failure or success. This allows
// these to be chained together into a linear sequence of || expressions in
// many cases.
/// Parse a floating point value from the stream.
virtual ParseResult parseFloat(double &result) = 0;
/// Parse an integer value from the stream.
template <typename IntT>
ParseResult parseInteger(IntT &result) {
auto loc = getCurrentLocation();
OptionalParseResult parseResult = parseOptionalInteger(result);
if (!parseResult.hasValue())
return emitError(loc, "expected integer value");
return *parseResult;
}
/// Parse an optional integer value from the stream.
virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0;
template <typename IntT>
OptionalParseResult parseOptionalInteger(IntT &result) {
auto loc = getCurrentLocation();
// Parse the unsigned variant.
APInt uintResult;
OptionalParseResult parseResult = parseOptionalInteger(uintResult);
if (!parseResult.hasValue() || failed(*parseResult))
return parseResult;
// Try to convert to the provided integer type. sextOrTrunc is correct even
// for unsigned types because parseOptionalInteger ensures the sign bit is
// zero for non-negated integers.
result =
(IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT).getLimitedValue();
if (APInt(uintResult.getBitWidth(), result) != uintResult)
return emitError(loc, "integer value too large");
return success();
}
/// Invoke the `getChecked` method of the given Attribute or Type class, using
/// the provided location to emit errors in the case of failure. Note that
/// unlike `OpBuilder::getType`, this method does not implicitly insert a
/// context parameter.
template <typename T, typename... ParamsT>
T getChecked(llvm::SMLoc loc, ParamsT &&... params) {
return T::getChecked([&] { return emitError(loc); },
std::forward<ParamsT>(params)...);
}
/// A variant of `getChecked` that uses the result of `getNameLoc` to emit
/// errors.
template <typename T, typename... ParamsT>
T getChecked(ParamsT &&... params) {
return T::getChecked([&] { return emitError(getNameLoc()); },
std::forward<ParamsT>(params)...);
}
//===--------------------------------------------------------------------===//
// Token Parsing
//===--------------------------------------------------------------------===//
/// Parse a '->' token.
virtual ParseResult parseArrow() = 0;
/// Parse a '->' token if present
virtual ParseResult parseOptionalArrow() = 0;
/// Parse a '{' token.
virtual ParseResult parseLBrace() = 0;
/// Parse a '{' token if present
virtual ParseResult parseOptionalLBrace() = 0;
/// Parse a `}` token.
virtual ParseResult parseRBrace() = 0;
/// Parse a `}` token if present
virtual ParseResult parseOptionalRBrace() = 0;
/// Parse a `:` token.
virtual ParseResult parseColon() = 0;
/// Parse a `:` token if present.
virtual ParseResult parseOptionalColon() = 0;
/// Parse a `,` token.
virtual ParseResult parseComma() = 0;
/// Parse a `,` token if present.
virtual ParseResult parseOptionalComma() = 0;
/// Parse a `=` token.
virtual ParseResult parseEqual() = 0;
/// Parse a `=` token if present.
virtual ParseResult parseOptionalEqual() = 0;
/// Parse a quoted string token.
ParseResult parseString(std::string *string) {
auto loc = getCurrentLocation();
if (parseOptionalString(string))
return emitError(loc, "expected string");
return success();
}
/// Parse a quoted string token if present.
virtual ParseResult parseOptionalString(std::string *string) = 0;
/// Parse a given keyword.
ParseResult parseKeyword(StringRef keyword, const Twine &msg = "") {
auto loc = getCurrentLocation();
if (parseOptionalKeyword(keyword))
return emitError(loc, "expected '") << keyword << "'" << msg;
return success();
}
/// Parse a keyword into 'keyword'.
ParseResult parseKeyword(StringRef *keyword) {
auto loc = getCurrentLocation();
if (parseOptionalKeyword(keyword))
return emitError(loc, "expected valid keyword");
return success();
}
/// Parse the given keyword if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0;
/// Parse a keyword, if present, into 'keyword'.
virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0;
/// Parse a '<' token.
virtual ParseResult parseLess() = 0;
/// Parse a `<` token if present.
virtual ParseResult parseOptionalLess() = 0;
/// Parse a '>' token.
virtual ParseResult parseGreater() = 0;
/// Parse a `>` token if present.
virtual ParseResult parseOptionalGreater() = 0;
/// Parse a `(` token.
virtual ParseResult parseLParen() = 0;
/// Parse a `(` token if present.
virtual ParseResult parseOptionalLParen() = 0;
/// Parse a `)` token.
virtual ParseResult parseRParen() = 0;
/// Parse a `)` token if present.
virtual ParseResult parseOptionalRParen() = 0;
/// Parse a `[` token.
virtual ParseResult parseLSquare() = 0;
/// Parse a `[` token if present.
virtual ParseResult parseOptionalLSquare() = 0;
/// Parse a `]` token.
virtual ParseResult parseRSquare() = 0;
/// Parse a `]` token if present.
virtual ParseResult parseOptionalRSquare() = 0;
/// Parse a `...` token if present;
virtual ParseResult parseOptionalEllipsis() = 0;
/// Parse a `?` token.
virtual ParseResult parseOptionalQuestion() = 0;
/// Parse a `*` token.
virtual ParseResult parseOptionalStar() = 0;
//===--------------------------------------------------------------------===//
// Attribute Parsing
//===--------------------------------------------------------------------===//
/// Parse an arbitrary attribute and return it in result.
virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0;
/// Parse an attribute of a specific kind and type.
template <typename AttrType>
ParseResult parseAttribute(AttrType &result, Type type = {}) {
llvm::SMLoc loc = getCurrentLocation();
// Parse any kind of attribute.
Attribute attr;
if (parseAttribute(attr, type))
return failure();
// Check for the right kind of attribute.
result = attr.dyn_cast<AttrType>();
if (!result)
return emitError(loc, "invalid kind of attribute specified");
return success();
}
/// Parse an affine map instance into 'map'.
virtual ParseResult parseAffineMap(AffineMap &map) = 0;
/// Parse an integer set instance into 'set'.
virtual ParseResult printIntegerSet(IntegerSet &set) = 0;
//===--------------------------------------------------------------------===//
// Type Parsing
//===--------------------------------------------------------------------===//
/// Parse a type.
virtual ParseResult parseType(Type &result) = 0;
/// Parse a type of a specific kind, e.g. a FunctionType.
template <typename TypeType>
ParseResult parseType(TypeType &result) {
llvm::SMLoc loc = getCurrentLocation();
// Parse any kind of type.
Type type;
if (parseType(type))
return failure();
// Check for the right kind of attribute.
result = type.dyn_cast<TypeType>();
if (!result)
return emitError(loc, "invalid kind of type specified");
return success();
}
/// Parse a type if present.
virtual OptionalParseResult parseOptionalType(Type &result) = 0;
/// Parse a 'x' separated dimension list. This populates the dimension list,
/// using -1 for the `?` dimensions if `allowDynamic` is set and errors out on
/// `?` otherwise.
///
/// dimension-list ::= (dimension `x`)*
/// dimension ::= `?` | integer
///
/// When `allowDynamic` is not set, this is used to parse:
///
/// static-dimension-list ::= (integer `x`)*
virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
bool allowDynamic = true) = 0;
/// Parse an 'x' token in a dimension list, handling the case where the x is
/// juxtaposed with an element type, as in "xf32", leaving the "f32" as the
/// next token.
virtual ParseResult parseXInDimensionList() = 0;
};
} // end namespace mlir

View File

@ -24,17 +24,179 @@ namespace mlir {
class Builder;
//===----------------------------------------------------------------------===//
// AsmPrinter
//===----------------------------------------------------------------------===//
/// This base class exposes generic asm printer hooks, usable across the various
/// derived printers.
class AsmPrinter {
public:
/// This class contains the internal default implementation of the base
/// printer methods.
class Impl;
/// Initialize the printer with the given internal implementation.
AsmPrinter(Impl &impl) : impl(&impl) {}
virtual ~AsmPrinter();
/// Return the raw output stream used by this printer.
virtual raw_ostream &getStream() const;
/// Print the given floating point value in a stabilized form that can be
/// roundtripped through the IR. This is the companion to the 'parseFloat'
/// hook on the AsmParser.
virtual void printFloat(const APFloat &value);
virtual void printType(Type type);
virtual void printAttribute(Attribute attr);
/// Print the given attribute without its type. The corresponding parser must
/// provide a valid type for the attribute.
virtual void printAttributeWithoutType(Attribute attr);
/// Print the given string as a symbol reference, i.e. a form representable by
/// a SymbolRefAttr. A symbol reference is represented as a string prefixed
/// with '@'. The reference is surrounded with ""'s and escaped if it has any
/// special or non-printable characters in it.
virtual void printSymbolName(StringRef symbolRef);
/// Print an optional arrow followed by a type list.
template <typename TypeRange>
void printOptionalArrowTypeList(TypeRange &&types) {
if (types.begin() != types.end())
printArrowTypeList(types);
}
template <typename TypeRange>
void printArrowTypeList(TypeRange &&types) {
auto &os = getStream() << " -> ";
bool wrapped = !llvm::hasSingleElement(types) ||
(*types.begin()).template isa<FunctionType>();
if (wrapped)
os << '(';
llvm::interleaveComma(types, *this);
if (wrapped)
os << ')';
}
/// Print the two given type ranges in a functional form.
template <typename InputRangeT, typename ResultRangeT>
void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) {
auto &os = getStream();
os << '(';
llvm::interleaveComma(inputs, *this);
os << ')';
printArrowTypeList(results);
}
protected:
/// Initialize the printer with no internal implementation. In this case, all
/// virtual methods of this class must be overriden.
AsmPrinter() : impl(nullptr) {}
private:
AsmPrinter(const AsmPrinter &) = delete;
void operator=(const AsmPrinter &) = delete;
/// The internal implementation of the printer.
Impl *impl;
};
template <typename AsmPrinterT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
AsmPrinterT &>
operator<<(AsmPrinterT &p, Type type) {
p.printType(type);
return p;
}
template <typename AsmPrinterT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
AsmPrinterT &>
operator<<(AsmPrinterT &p, Attribute attr) {
p.printAttribute(attr);
return p;
}
template <typename AsmPrinterT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
AsmPrinterT &>
operator<<(AsmPrinterT &p, const APFloat &value) {
p.printFloat(value);
return p;
}
template <typename AsmPrinterT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
AsmPrinterT &>
operator<<(AsmPrinterT &p, float value) {
return p << APFloat(value);
}
template <typename AsmPrinterT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
AsmPrinterT &>
operator<<(AsmPrinterT &p, double value) {
return p << APFloat(value);
}
// Support printing anything that isn't convertible to one of the other
// streamable types, even if it isn't exactly one of them. For example, we want
// to print FunctionType with the Type version above, not have it match this.
template <
typename AsmPrinterT, typename T,
typename std::enable_if<!std::is_convertible<T &, Value &>::value &&
!std::is_convertible<T &, Type &>::value &&
!std::is_convertible<T &, Attribute &>::value &&
!std::is_convertible<T &, ValueRange>::value &&
!std::is_convertible<T &, APFloat &>::value &&
!llvm::is_one_of<T, bool, float, double>::value,
T>::type * = nullptr>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
AsmPrinterT &>
operator<<(AsmPrinterT &p, const T &other) {
p.getStream() << other;
return p;
}
template <typename AsmPrinterT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
AsmPrinterT &>
operator<<(AsmPrinterT &p, bool value) {
return p << (value ? StringRef("true") : "false");
}
template <typename AsmPrinterT, typename ValueRangeT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
AsmPrinterT &>
operator<<(AsmPrinterT &p, const ValueTypeRange<ValueRangeT> &types) {
llvm::interleaveComma(types, p);
return p;
}
template <typename AsmPrinterT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
AsmPrinterT &>
operator<<(AsmPrinterT &p, const TypeRange &types) {
llvm::interleaveComma(types, p);
return p;
}
template <typename AsmPrinterT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
AsmPrinterT &>
operator<<(AsmPrinterT &p, ArrayRef<Type> types) {
llvm::interleaveComma(types, p);
return p;
}
//===----------------------------------------------------------------------===//
// OpAsmPrinter
//===----------------------------------------------------------------------===//
/// This is a pure-virtual base class that exposes the asmprinter hooks
/// necessary to implement a custom print() method.
class OpAsmPrinter {
class OpAsmPrinter : public AsmPrinter {
public:
OpAsmPrinter() {}
virtual ~OpAsmPrinter();
virtual raw_ostream &getStream() const = 0;
using AsmPrinter::AsmPrinter;
~OpAsmPrinter() override;
/// Print a newline and indent the printer to the start of the current
/// operation.
@ -70,12 +232,6 @@ public:
printOperand(*it);
}
}
virtual void printType(Type type) = 0;
virtual void printAttribute(Attribute attr) = 0;
/// Print the given attribute without its type. The corresponding parser must
/// provide a valid type for the attribute.
virtual void printAttributeWithoutType(Attribute attr) = 0;
/// Print the given successor.
virtual void printSuccessor(Block *successor) = 0;
@ -131,47 +287,9 @@ public:
virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
ValueRange symOperands) = 0;
/// Print an optional arrow followed by a type list.
template <typename TypeRange>
void printOptionalArrowTypeList(TypeRange &&types) {
if (types.begin() != types.end())
printArrowTypeList(types);
}
template <typename TypeRange>
void printArrowTypeList(TypeRange &&types) {
auto &os = getStream() << " -> ";
bool wrapped = !llvm::hasSingleElement(types) ||
(*types.begin()).template isa<FunctionType>();
if (wrapped)
os << '(';
llvm::interleaveComma(types, *this);
if (wrapped)
os << ')';
}
/// Print the complete type of an operation in functional form.
void printFunctionalType(Operation *op);
/// Print the two given type ranges in a functional form.
template <typename InputRangeT, typename ResultRangeT>
void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) {
auto &os = getStream();
os << '(';
llvm::interleaveComma(inputs, *this);
os << ')';
printArrowTypeList(results);
}
/// Print the given string as a symbol reference, i.e. a form representable by
/// a SymbolRefAttr. A symbol reference is represented as a string prefixed
/// with '@'. The reference is surrounded with ""'s and escaped if it has any
/// special or non-printable characters in it.
virtual void printSymbolName(StringRef symbolRef) = 0;
private:
OpAsmPrinter(const OpAsmPrinter &) = delete;
void operator=(const OpAsmPrinter &) = delete;
using AsmPrinter::printFunctionalType;
};
// Make the implementations convenient to use.
@ -189,77 +307,28 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) {
return p;
}
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Type type) {
p.printType(type);
return p;
}
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Attribute attr) {
p.printAttribute(attr);
return p;
}
// Support printing anything that isn't convertible to one of the above types,
// even if it isn't exactly one of them. For example, we want to print
// FunctionType with the Type version above, not have it match this.
template <typename T, typename std::enable_if<
!std::is_convertible<T &, Value &>::value &&
!std::is_convertible<T &, Type &>::value &&
!std::is_convertible<T &, Attribute &>::value &&
!std::is_convertible<T &, ValueRange>::value &&
!llvm::is_one_of<T, bool>::value,
T>::type * = nullptr>
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &other) {
p.getStream() << other;
return p;
}
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, bool value) {
return p << (value ? StringRef("true") : "false");
}
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Block *value) {
p.printSuccessor(value);
return p;
}
template <typename ValueRangeT>
inline OpAsmPrinter &operator<<(OpAsmPrinter &p,
const ValueTypeRange<ValueRangeT> &types) {
llvm::interleaveComma(types, p);
return p;
}
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const TypeRange &types) {
llvm::interleaveComma(types, p);
return p;
}
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, ArrayRef<Type> types) {
llvm::interleaveComma(types, p);
return p;
}
//===----------------------------------------------------------------------===//
// OpAsmParser
// AsmParser
//===----------------------------------------------------------------------===//
/// The OpAsmParser has methods for interacting with the asm parser: parsing
/// things from it, emitting errors etc. It has an intentionally high-level API
/// that is designed to reduce/constrain syntax innovation in individual
/// operations.
///
/// For example, consider an op like this:
///
/// %x = load %p[%1, %2] : memref<...>
///
/// The "%x = load" tokens are already parsed and therefore invisible to the
/// custom op parser. This can be supported by calling `parseOperandList` to
/// parse the %p, then calling `parseOperandList` with a `SquareDelimiter` to
/// parse the indices, then calling `parseColonTypeList` to parse the result
/// type.
///
class OpAsmParser {
/// This base class exposes generic asm parser hooks, usable across the various
/// derived parsers.
class AsmParser {
public:
virtual ~OpAsmParser();
AsmParser() = default;
virtual ~AsmParser();
/// Return the location of the original name token.
virtual llvm::SMLoc getNameLoc() const = 0;
//===--------------------------------------------------------------------===//
// Utilities
//===--------------------------------------------------------------------===//
/// Emit a diagnostic at the specified location and return failure.
virtual InFlightDiagnostic emitError(llvm::SMLoc loc,
@ -277,44 +346,11 @@ public:
return success();
}
/// Return the name of the specified result in the specified syntax, as well
/// as the sub-element in the name. It returns an empty string and ~0U for
/// invalid result numbers. For example, in this operation:
///
/// %x, %y:2, %z = foo.op
///
/// getResultName(0) == {"x", 0 }
/// getResultName(1) == {"y", 0 }
/// getResultName(2) == {"y", 1 }
/// getResultName(3) == {"z", 0 }
/// getResultName(4) == {"", ~0U }
virtual std::pair<StringRef, unsigned>
getResultName(unsigned resultNo) const = 0;
/// Return the number of declared SSA results. This returns 4 for the foo.op
/// example in the comment for `getResultName`.
virtual size_t getNumResults() const = 0;
/// Return the location of the original name token.
virtual llvm::SMLoc getNameLoc() const = 0;
/// Re-encode the given source location as an MLIR location and return it.
/// Note: This method should only be used when a `Location` is necessary, as
/// the encoding process is not efficient.
virtual Location getEncodedSourceLoc(llvm::SMLoc loc) = 0;
// These methods emit an error and return failure or success. This allows
// these to be chained together into a linear sequence of || expressions in
// many cases.
/// Parse an operation in its generic form.
/// The parsed operation is parsed in the current context and inserted in the
/// provided block and insertion point. The results produced by this operation
/// aren't mapped to any named value in the parser. Returns nullptr on
/// failure.
virtual Operation *parseGenericOperation(Block *insertBlock,
Block::iterator insertPt) = 0;
//===--------------------------------------------------------------------===//
// Token Parsing
//===--------------------------------------------------------------------===//
@ -385,6 +421,17 @@ public:
/// Parse a '*' token if present.
virtual ParseResult parseOptionalStar() = 0;
/// Parse a quoted string token.
ParseResult parseString(std::string *string) {
auto loc = getCurrentLocation();
if (parseOptionalString(string))
return emitError(loc, "expected string");
return success();
}
/// Parse a quoted string token if present.
virtual ParseResult parseOptionalString(std::string *string) = 0;
/// Parse a given keyword.
ParseResult parseKeyword(StringRef keyword, const Twine &msg = "") {
auto loc = getCurrentLocation();
@ -440,6 +487,9 @@ public:
/// Parse a `...` token if present;
virtual ParseResult parseOptionalEllipsis() = 0;
/// Parse a floating point value from the stream.
virtual ParseResult parseFloat(double &result) = 0;
/// Parse an integer value from the stream.
template <typename IntT>
ParseResult parseInteger(IntT &result) {
@ -514,6 +564,27 @@ public:
return parseCommaSeparatedList(Delimiter::None, parseElementFn);
}
//===--------------------------------------------------------------------===//
// Attribute/Type Parsing
//===--------------------------------------------------------------------===//
/// Invoke the `getChecked` method of the given Attribute or Type class, using
/// the provided location to emit errors in the case of failure. Note that
/// unlike `OpBuilder::getType`, this method does not implicitly insert a
/// context parameter.
template <typename T, typename... ParamsT>
T getChecked(llvm::SMLoc loc, ParamsT &&... params) {
return T::getChecked([&] { return emitError(loc); },
std::forward<ParamsT>(params)...);
}
/// A variant of `getChecked` that uses the result of `getNameLoc` to emit
/// errors.
template <typename T, typename... ParamsT>
T getChecked(ParamsT &&... params) {
return T::getChecked([&] { return emitError(getNameLoc()); },
std::forward<ParamsT>(params)...);
}
//===--------------------------------------------------------------------===//
// Attribute Parsing
//===--------------------------------------------------------------------===//
@ -634,6 +705,180 @@ public:
virtual ParseResult
parseOptionalLocationSpecifier(Optional<Location> &result) = 0;
//===--------------------------------------------------------------------===//
// Type Parsing
//===--------------------------------------------------------------------===//
/// Parse a type.
virtual ParseResult parseType(Type &result) = 0;
/// Parse an optional type.
virtual OptionalParseResult parseOptionalType(Type &result) = 0;
/// Parse a type of a specific type.
template <typename TypeT>
ParseResult parseType(TypeT &result) {
llvm::SMLoc loc = getCurrentLocation();
// Parse any kind of type.
Type type;
if (parseType(type))
return failure();
// Check for the right kind of attribute.
result = type.dyn_cast<TypeT>();
if (!result)
return emitError(loc, "invalid kind of type specified");
return success();
}
/// Parse a type list.
ParseResult parseTypeList(SmallVectorImpl<Type> &result) {
do {
Type type;
if (parseType(type))
return failure();
result.push_back(type);
} while (succeeded(parseOptionalComma()));
return success();
}
/// Parse an arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) = 0;
/// Parse an optional arrow followed by a type list.
virtual ParseResult
parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0;
/// Parse a colon followed by a type.
virtual ParseResult parseColonType(Type &result) = 0;
/// Parse a colon followed by a type of a specific kind, e.g. a FunctionType.
template <typename TypeType>
ParseResult parseColonType(TypeType &result) {
llvm::SMLoc loc = getCurrentLocation();
// Parse any kind of type.
Type type;
if (parseColonType(type))
return failure();
// Check for the right kind of attribute.
result = type.dyn_cast<TypeType>();
if (!result)
return emitError(loc, "invalid kind of type specified");
return success();
}
/// Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0;
/// Parse an optional colon followed by a type list, which if present must
/// have at least one type.
virtual ParseResult
parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
/// Parse a keyword followed by a type.
ParseResult parseKeywordType(const char *keyword, Type &result) {
return failure(parseKeyword(keyword) || parseType(result));
}
/// Add the specified type to the end of the specified type list and return
/// success. This is a helper designed to allow parse methods to be simple
/// and chain through || operators.
ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) {
result.push_back(type);
return success();
}
/// Add the specified types to the end of the specified type list and return
/// success. This is a helper designed to allow parse methods to be simple
/// and chain through || operators.
ParseResult addTypesToList(ArrayRef<Type> types,
SmallVectorImpl<Type> &result) {
result.append(types.begin(), types.end());
return success();
}
/// Parse a 'x' separated dimension list. This populates the dimension list,
/// using -1 for the `?` dimensions if `allowDynamic` is set and errors out on
/// `?` otherwise.
///
/// dimension-list ::= (dimension `x`)*
/// dimension ::= `?` | integer
///
/// When `allowDynamic` is not set, this is used to parse:
///
/// static-dimension-list ::= (integer `x`)*
virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
bool allowDynamic = true) = 0;
/// Parse an 'x' token in a dimension list, handling the case where the x is
/// juxtaposed with an element type, as in "xf32", leaving the "f32" as the
/// next token.
virtual ParseResult parseXInDimensionList() = 0;
private:
AsmParser(const AsmParser &) = delete;
void operator=(const AsmParser &) = delete;
};
//===----------------------------------------------------------------------===//
// OpAsmParser
//===----------------------------------------------------------------------===//
/// The OpAsmParser has methods for interacting with the asm parser: parsing
/// things from it, emitting errors etc. It has an intentionally high-level API
/// that is designed to reduce/constrain syntax innovation in individual
/// operations.
///
/// For example, consider an op like this:
///
/// %x = load %p[%1, %2] : memref<...>
///
/// The "%x = load" tokens are already parsed and therefore invisible to the
/// custom op parser. This can be supported by calling `parseOperandList` to
/// parse the %p, then calling `parseOperandList` with a `SquareDelimiter` to
/// parse the indices, then calling `parseColonTypeList` to parse the result
/// type.
///
class OpAsmParser : public AsmParser {
public:
using AsmParser::AsmParser;
~OpAsmParser() override;
/// Return the name of the specified result in the specified syntax, as well
/// as the sub-element in the name. It returns an empty string and ~0U for
/// invalid result numbers. For example, in this operation:
///
/// %x, %y:2, %z = foo.op
///
/// getResultName(0) == {"x", 0 }
/// getResultName(1) == {"y", 0 }
/// getResultName(2) == {"y", 1 }
/// getResultName(3) == {"z", 0 }
/// getResultName(4) == {"", ~0U }
virtual std::pair<StringRef, unsigned>
getResultName(unsigned resultNo) const = 0;
/// Return the number of declared SSA results. This returns 4 for the foo.op
/// example in the comment for `getResultName`.
virtual size_t getNumResults() const = 0;
// These methods emit an error and return failure or success. This allows
// these to be chained together into a linear sequence of || expressions in
// many cases.
/// Parse an operation in its generic form.
/// The parsed operation is parsed in the current context and inserted in the
/// provided block and insertion point. The results produced by this operation
/// aren't mapped to any named value in the parser. Returns nullptr on
/// failure.
virtual Operation *parseGenericOperation(Block *insertBlock,
Block::iterator insertPt) = 0;
//===--------------------------------------------------------------------===//
// Operand Parsing
//===--------------------------------------------------------------------===//
@ -813,77 +1058,6 @@ public:
// Type Parsing
//===--------------------------------------------------------------------===//
/// Parse a type.
virtual ParseResult parseType(Type &result) = 0;
/// Parse an optional type.
virtual OptionalParseResult parseOptionalType(Type &result) = 0;
/// Parse a type of a specific type.
template <typename TypeT>
ParseResult parseType(TypeT &result) {
llvm::SMLoc loc = getCurrentLocation();
// Parse any kind of type.
Type type;
if (parseType(type))
return failure();
// Check for the right kind of attribute.
result = type.dyn_cast<TypeT>();
if (!result)
return emitError(loc, "invalid kind of type specified");
return success();
}
/// Parse a type list.
ParseResult parseTypeList(SmallVectorImpl<Type> &result) {
do {
Type type;
if (parseType(type))
return failure();
result.push_back(type);
} while (succeeded(parseOptionalComma()));
return success();
}
/// Parse an arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) = 0;
/// Parse an optional arrow followed by a type list.
virtual ParseResult
parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0;
/// Parse a colon followed by a type.
virtual ParseResult parseColonType(Type &result) = 0;
/// Parse a colon followed by a type of a specific kind, e.g. a FunctionType.
template <typename TypeType>
ParseResult parseColonType(TypeType &result) {
llvm::SMLoc loc = getCurrentLocation();
// Parse any kind of type.
Type type;
if (parseColonType(type))
return failure();
// Check for the right kind of attribute.
result = type.dyn_cast<TypeType>();
if (!result)
return emitError(loc, "invalid kind of type specified");
return success();
}
/// Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0;
/// Parse an optional colon followed by a type list, which if present must
/// have at least one type.
virtual ParseResult
parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
/// Parse a list of assignments of the form
/// (%x1 = %y1, %x2 = %y2, ...)
ParseResult parseAssignmentList(SmallVectorImpl<OperandType> &lhs,
@ -914,27 +1088,6 @@ public:
parseOptionalAssignmentListWithTypes(SmallVectorImpl<OperandType> &lhs,
SmallVectorImpl<OperandType> &rhs,
SmallVectorImpl<Type> &types) = 0;
/// Parse a keyword followed by a type.
ParseResult parseKeywordType(const char *keyword, Type &result) {
return failure(parseKeyword(keyword) || parseType(result));
}
/// Add the specified type to the end of the specified type list and return
/// success. This is a helper designed to allow parse methods to be simple
/// and chain through || operators.
ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) {
result.push_back(type);
return success();
}
/// Add the specified types to the end of the specified type list and return
/// success. This is a helper designed to allow parse methods to be simple
/// and chain through || operators.
ParseResult addTypesToList(ArrayRef<Type> types,
SmallVectorImpl<Type> &result) {
result.append(types.begin(), types.end());
return success();
}
private:
/// Parse either an operand list or a region argument list depending on

View File

@ -52,6 +52,18 @@ void OperationName::print(raw_ostream &os) const { os << getStringRef(); }
void OperationName::dump() const { print(llvm::errs()); }
//===--------------------------------------------------------------------===//
// AsmParser
//===--------------------------------------------------------------------===//
AsmParser::~AsmParser() {}
DialectAsmParser::~DialectAsmParser() {}
OpAsmParser::~OpAsmParser() {}
//===--------------------------------------------------------------------===//
// DialectAsmPrinter
//===--------------------------------------------------------------------===//
DialectAsmPrinter::~DialectAsmPrinter() {}
//===--------------------------------------------------------------------===//
@ -250,12 +262,12 @@ namespace {
struct NewLineCounter {
unsigned curLine = 1;
};
} // end anonymous namespace
static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) {
++newLine.curLine;
return os << '\n';
}
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// AliasInitializer
@ -492,6 +504,7 @@ private:
/// The following are hooks of `OpAsmPrinter` that are not necessary for
/// determining potential aliases.
void printFloat(const APFloat &value) override {}
void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {}
void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {}
void printNewline() override {}
@ -1202,18 +1215,17 @@ AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags,
AsmState::~AsmState() {}
//===----------------------------------------------------------------------===//
// ModulePrinter
// AsmPrinter::Impl
//===----------------------------------------------------------------------===//
namespace {
class ModulePrinter {
namespace mlir {
class AsmPrinter::Impl {
public:
ModulePrinter(raw_ostream &os, OpPrintingFlags flags = llvm::None,
AsmStateImpl *state = nullptr)
Impl(raw_ostream &os, OpPrintingFlags flags = llvm::None,
AsmStateImpl *state = nullptr)
: os(os), printerFlags(flags), state(state) {}
explicit ModulePrinter(ModulePrinter &printer)
: os(printer.os), printerFlags(printer.printerFlags),
state(printer.state) {}
explicit Impl(Impl &other)
: Impl(other.os, other.printerFlags, other.state) {}
/// Returns the output stream of the printer.
raw_ostream &getStream() { return os; }
@ -1298,9 +1310,9 @@ protected:
/// A tracker for the number of new lines emitted during printing.
NewLineCounter newLine;
};
} // end anonymous namespace
} // namespace mlir
void ModulePrinter::printTrailingLocation(Location loc, bool allowAlias) {
void AsmPrinter::Impl::printTrailingLocation(Location loc, bool allowAlias) {
// Check to see if we are printing debug information.
if (!printerFlags.shouldPrintDebugInfo())
return;
@ -1309,7 +1321,7 @@ void ModulePrinter::printTrailingLocation(Location loc, bool allowAlias) {
printLocation(loc, /*allowAlias=*/allowAlias);
}
void ModulePrinter::printLocationInternal(LocationAttr loc, bool pretty) {
void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty) {
TypeSwitch<LocationAttr>(loc)
.Case<OpaqueLoc>([&](OpaqueLoc loc) {
printLocationInternal(loc.getFallbackLocation(), pretty);
@ -1430,7 +1442,7 @@ static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
os << str;
}
void ModulePrinter::printLocation(LocationAttr loc, bool allowAlias) {
void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) {
if (printerFlags.shouldPrintDebugInfoPrettyForm())
return printLocationInternal(loc, /*pretty=*/true);
@ -1578,8 +1590,8 @@ static void printElidedElementsAttr(raw_ostream &os) {
os << R"(opaque<"_", "0xDEADBEEF">)";
}
void ModulePrinter::printAttribute(Attribute attr,
AttrTypeElision typeElision) {
void AsmPrinter::Impl::printAttribute(Attribute attr,
AttrTypeElision typeElision) {
if (!attr) {
os << "<<NULL ATTRIBUTE>>";
return;
@ -1780,8 +1792,8 @@ printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
os << ']';
}
void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
bool allowHex) {
void AsmPrinter::Impl::printDenseElementsAttr(DenseElementsAttr attr,
bool allowHex) {
if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>())
return printDenseStringElementsAttr(stringAttr);
@ -1789,8 +1801,8 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
allowHex);
}
void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
bool allowHex) {
void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
DenseIntOrFPElementsAttr attr, bool allowHex) {
auto type = attr.getType();
auto elementType = type.getElementType();
@ -1860,7 +1872,8 @@ void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
}
}
void ModulePrinter::printDenseStringElementsAttr(DenseStringElementsAttr attr) {
void AsmPrinter::Impl::printDenseStringElementsAttr(
DenseStringElementsAttr attr) {
ArrayRef<StringRef> data = attr.getRawStringData();
auto printFn = [&](unsigned index) {
os << "\"";
@ -1870,7 +1883,7 @@ void ModulePrinter::printDenseStringElementsAttr(DenseStringElementsAttr attr) {
printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
}
void ModulePrinter::printType(Type type) {
void AsmPrinter::Impl::printType(Type type) {
if (!type) {
os << "<<NULL TYPE>>";
return;
@ -1986,9 +1999,9 @@ void ModulePrinter::printType(Type type) {
.Default([&](Type type) { return printDialectType(type); });
}
void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<StringRef> elidedAttrs,
bool withKeyword) {
void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<StringRef> elidedAttrs,
bool withKeyword) {
// If there are no attributes, then there is nothing to be done.
if (attrs.empty())
return;
@ -2020,7 +2033,7 @@ void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
printFilteredAttributesFn(filteredAttrs);
}
void ModulePrinter::printNamedAttribute(NamedAttribute attr) {
void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
if (isBareIdentifier(attr.first)) {
os << attr.first;
} else {
@ -2037,81 +2050,82 @@ void ModulePrinter::printNamedAttribute(NamedAttribute attr) {
printAttribute(attr.second);
}
//===----------------------------------------------------------------------===//
// CustomDialectAsmPrinter
//===----------------------------------------------------------------------===//
namespace {
/// This class provides the main specialization of the DialectAsmPrinter that is
/// used to provide support for print attributes and types. This hooks allows
/// for dialects to hook into the main ModulePrinter.
struct CustomDialectAsmPrinter : public DialectAsmPrinter {
public:
CustomDialectAsmPrinter(ModulePrinter &printer) : printer(printer) {}
~CustomDialectAsmPrinter() override {}
raw_ostream &getStream() const override { return printer.getStream(); }
/// Print the given attribute to the stream.
void printAttribute(Attribute attr) override { printer.printAttribute(attr); }
/// Print the given attribute without its type. The corresponding parser must
/// provide a valid type for the attribute.
void printAttributeWithoutType(Attribute attr) override {
printer.printAttribute(attr, ModulePrinter::AttrTypeElision::Must);
}
/// Print the given floating point value in a stablized form.
void printFloat(const APFloat &value) override {
printFloatValue(value, getStream());
}
/// Print the given type to the stream.
void printType(Type type) override { printer.printType(type); }
/// The main module printer.
ModulePrinter &printer;
};
} // end anonymous namespace
void ModulePrinter::printDialectAttribute(Attribute attr) {
void AsmPrinter::Impl::printDialectAttribute(Attribute attr) {
auto &dialect = attr.getDialect();
// Ask the dialect to serialize the attribute to a string.
std::string attrName;
{
llvm::raw_string_ostream attrNameStr(attrName);
ModulePrinter subPrinter(attrNameStr, printerFlags, state);
CustomDialectAsmPrinter printer(subPrinter);
Impl subPrinter(attrNameStr, printerFlags, state);
DialectAsmPrinter printer(subPrinter);
dialect.printAttribute(attr, printer);
}
printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
}
void ModulePrinter::printDialectType(Type type) {
void AsmPrinter::Impl::printDialectType(Type type) {
auto &dialect = type.getDialect();
// Ask the dialect to serialize the type to a string.
std::string typeName;
{
llvm::raw_string_ostream typeNameStr(typeName);
ModulePrinter subPrinter(typeNameStr, printerFlags, state);
CustomDialectAsmPrinter printer(subPrinter);
Impl subPrinter(typeNameStr, printerFlags, state);
DialectAsmPrinter printer(subPrinter);
dialect.printType(type, printer);
}
printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
}
//===--------------------------------------------------------------------===//
// AsmPrinter
//===--------------------------------------------------------------------===//
AsmPrinter::~AsmPrinter() {}
raw_ostream &AsmPrinter::getStream() const {
assert(impl && "expected AsmPrinter::getStream to be overriden");
return impl->getStream();
}
/// Print the given floating point value in a stablized form.
void AsmPrinter::printFloat(const APFloat &value) {
assert(impl && "expected AsmPrinter::printFloat to be overriden");
printFloatValue(value, impl->getStream());
}
void AsmPrinter::printType(Type type) {
assert(impl && "expected AsmPrinter::printType to be overriden");
impl->printType(type);
}
void AsmPrinter::printAttribute(Attribute attr) {
assert(impl && "expected AsmPrinter::printAttribute to be overriden");
impl->printAttribute(attr);
}
void AsmPrinter::printAttributeWithoutType(Attribute attr) {
assert(impl &&
"expected AsmPrinter::printAttributeWithoutType to be overriden");
impl->printAttribute(attr, Impl::AttrTypeElision::Must);
}
void AsmPrinter::printSymbolName(StringRef symbolRef) {
assert(impl && "expected AsmPrinter::printSymbolName to be overriden");
::printSymbolReference(symbolRef, impl->getStream());
}
//===----------------------------------------------------------------------===//
// Affine expressions and maps
//===----------------------------------------------------------------------===//
void ModulePrinter::printAffineExpr(
void AsmPrinter::Impl::printAffineExpr(
AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) {
printAffineExprInternal(expr, BindingStrength::Weak, printValueName);
}
void ModulePrinter::printAffineExprInternal(
void AsmPrinter::Impl::printAffineExprInternal(
AffineExpr expr, BindingStrength enclosingTightness,
function_ref<void(unsigned, bool)> printValueName) {
const char *binopSpelling = nullptr;
@ -2244,12 +2258,12 @@ void ModulePrinter::printAffineExprInternal(
os << ')';
}
void ModulePrinter::printAffineConstraint(AffineExpr expr, bool isEq) {
void AsmPrinter::Impl::printAffineConstraint(AffineExpr expr, bool isEq) {
printAffineExprInternal(expr, BindingStrength::Weak);
isEq ? os << " == 0" : os << " >= 0";
}
void ModulePrinter::printAffineMap(AffineMap map) {
void AsmPrinter::Impl::printAffineMap(AffineMap map) {
// Dimension identifiers.
os << '(';
for (int i = 0; i < (int)map.getNumDims() - 1; ++i)
@ -2275,7 +2289,7 @@ void ModulePrinter::printAffineMap(AffineMap map) {
os << ')';
}
void ModulePrinter::printIntegerSet(IntegerSet set) {
void AsmPrinter::Impl::printIntegerSet(IntegerSet set) {
// Dimension identifiers.
os << '(';
for (unsigned i = 1; i < set.getNumDims(); ++i)
@ -2313,11 +2327,14 @@ void ModulePrinter::printIntegerSet(IntegerSet set) {
namespace {
/// This class contains the logic for printing operations, regions, and blocks.
class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
public:
using Impl = AsmPrinter::Impl;
using Impl::printType;
explicit OperationPrinter(raw_ostream &os, OpPrintingFlags flags,
AsmStateImpl &state)
: ModulePrinter(os, flags, &state) {}
: Impl(os, flags, &state), OpAsmPrinter(static_cast<Impl &>(*this)) {}
/// Print the given top-level operation.
void printTopLevelOperation(Operation *op);
@ -2346,9 +2363,6 @@ public:
// OpAsmPrinter methods
//===--------------------------------------------------------------------===//
/// Return the current stream of the printer.
raw_ostream &getStream() const override { return os; }
/// Print a newline and indent the printer to the start of the current
/// operation.
void printNewline() override {
@ -2356,20 +2370,6 @@ public:
os.indent(currentIndent);
}
/// Print the given type.
void printType(Type type) override { ModulePrinter::printType(type); }
/// Print the given attribute.
void printAttribute(Attribute attr) override {
ModulePrinter::printAttribute(attr);
}
/// Print the given attribute without its type. The corresponding parser must
/// provide a valid type for the attribute.
void printAttributeWithoutType(Attribute attr) override {
ModulePrinter::printAttribute(attr, AttrTypeElision::Must);
}
/// Print a block argument in the usual format of:
/// %ssaName : type {attr1=42} loc("here")
/// where location printing is controlled by the standard internal option.
@ -2388,13 +2388,13 @@ public:
/// Print an optional attribute dictionary with a given set of elided values.
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<StringRef> elidedAttrs = {}) override {
ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs);
Impl::printOptionalAttrDict(attrs, elidedAttrs);
}
void printOptionalAttrDictWithKeyword(
ArrayRef<NamedAttribute> attrs,
ArrayRef<StringRef> elidedAttrs = {}) override {
ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs,
/*withKeyword=*/true);
Impl::printOptionalAttrDict(attrs, elidedAttrs,
/*withKeyword=*/true);
}
/// Print the given successor.
@ -2427,11 +2427,6 @@ public:
void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
ValueRange symOperands) override;
/// Print the given string as a symbol reference.
void printSymbolName(StringRef symbolRef) override {
::printSymbolReference(symbolRef, os);
}
private:
// Contains the stack of default dialects to use when printing regions.
// A new dialect is pushed to the stack before parsing regions nested under an
@ -2732,7 +2727,7 @@ void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr,
//===----------------------------------------------------------------------===//
void Attribute::print(raw_ostream &os) const {
ModulePrinter(os).printAttribute(*this);
AsmPrinter::Impl(os).printAttribute(*this);
}
void Attribute::dump() const {
@ -2740,7 +2735,9 @@ void Attribute::dump() const {
llvm::errs() << "\n";
}
void Type::print(raw_ostream &os) const { ModulePrinter(os).printType(*this); }
void Type::print(raw_ostream &os) const {
AsmPrinter::Impl(os).printType(*this);
}
void Type::dump() const { print(llvm::errs()); }
@ -2759,7 +2756,7 @@ void AffineExpr::print(raw_ostream &os) const {
os << "<<NULL AFFINE EXPR>>";
return;
}
ModulePrinter(os).printAffineExpr(*this);
AsmPrinter::Impl(os).printAffineExpr(*this);
}
void AffineExpr::dump() const {
@ -2772,11 +2769,11 @@ void AffineMap::print(raw_ostream &os) const {
os << "<<NULL AFFINE MAP>>";
return;
}
ModulePrinter(os).printAffineMap(*this);
AsmPrinter::Impl(os).printAffineMap(*this);
}
void IntegerSet::print(raw_ostream &os) const {
ModulePrinter(os).printIntegerSet(*this);
AsmPrinter::Impl(os).printIntegerSet(*this);
}
void Value::print(raw_ostream &os) {

View File

@ -24,8 +24,6 @@
using namespace mlir;
using namespace detail;
DialectAsmParser::~DialectAsmParser() {}
//===----------------------------------------------------------------------===//
// DialectRegistry
//===----------------------------------------------------------------------===//

View File

@ -19,8 +19,6 @@
using namespace mlir;
OpAsmParser::~OpAsmParser() {}
//===----------------------------------------------------------------------===//
// OperationName
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1,501 @@
//===- AsmParserImpl.h - MLIR AsmParserImpl Class ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_LIB_PARSER_ASMPARSERIMPL_H
#define MLIR_LIB_PARSER_ASMPARSERIMPL_H
#include "Parser.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Parser/AsmParserState.h"
namespace mlir {
namespace detail {
//===----------------------------------------------------------------------===//
// AsmParserImpl
//===----------------------------------------------------------------------===//
/// This class provides the implementation of the generic parser methods within
/// AsmParser.
template <typename BaseT>
class AsmParserImpl : public BaseT {
public:
AsmParserImpl(llvm::SMLoc nameLoc, Parser &parser)
: nameLoc(nameLoc), parser(parser) {}
~AsmParserImpl() override {}
/// Return the location of the original name token.
llvm::SMLoc getNameLoc() const override { return nameLoc; }
//===--------------------------------------------------------------------===//
// Utilities
//===--------------------------------------------------------------------===//
/// Return if any errors were emitted during parsing.
bool didEmitError() const { return emittedError; }
/// Emit a diagnostic at the specified location and return failure.
InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override {
emittedError = true;
return parser.emitError(loc, message);
}
/// Return a builder which provides useful access to MLIRContext, global
/// objects like types and attributes.
Builder &getBuilder() const override { return parser.builder; }
/// Get the location of the next token and store it into the argument. This
/// always succeeds.
llvm::SMLoc getCurrentLocation() override {
return parser.getToken().getLoc();
}
/// Re-encode the given source location as an MLIR location and return it.
Location getEncodedSourceLoc(llvm::SMLoc loc) override {
return parser.getEncodedSourceLocation(loc);
}
//===--------------------------------------------------------------------===//
// Token Parsing
//===--------------------------------------------------------------------===//
using Delimiter = AsmParser::Delimiter;
/// Parse a `->` token.
ParseResult parseArrow() override {
return parser.parseToken(Token::arrow, "expected '->'");
}
/// Parses a `->` if present.
ParseResult parseOptionalArrow() override {
return success(parser.consumeIf(Token::arrow));
}
/// Parse a '{' token.
ParseResult parseLBrace() override {
return parser.parseToken(Token::l_brace, "expected '{'");
}
/// Parse a '{' token if present
ParseResult parseOptionalLBrace() override {
return success(parser.consumeIf(Token::l_brace));
}
/// Parse a `}` token.
ParseResult parseRBrace() override {
return parser.parseToken(Token::r_brace, "expected '}'");
}
/// Parse a `}` token if present
ParseResult parseOptionalRBrace() override {
return success(parser.consumeIf(Token::r_brace));
}
/// Parse a `:` token.
ParseResult parseColon() override {
return parser.parseToken(Token::colon, "expected ':'");
}
/// Parse a `:` token if present.
ParseResult parseOptionalColon() override {
return success(parser.consumeIf(Token::colon));
}
/// Parse a `,` token.
ParseResult parseComma() override {
return parser.parseToken(Token::comma, "expected ','");
}
/// Parse a `,` token if present.
ParseResult parseOptionalComma() override {
return success(parser.consumeIf(Token::comma));
}
/// Parses a `...` if present.
ParseResult parseOptionalEllipsis() override {
return success(parser.consumeIf(Token::ellipsis));
}
/// Parse a `=` token.
ParseResult parseEqual() override {
return parser.parseToken(Token::equal, "expected '='");
}
/// Parse a `=` token if present.
ParseResult parseOptionalEqual() override {
return success(parser.consumeIf(Token::equal));
}
/// Parse a '<' token.
ParseResult parseLess() override {
return parser.parseToken(Token::less, "expected '<'");
}
/// Parse a `<` token if present.
ParseResult parseOptionalLess() override {
return success(parser.consumeIf(Token::less));
}
/// Parse a '>' token.
ParseResult parseGreater() override {
return parser.parseToken(Token::greater, "expected '>'");
}
/// Parse a `>` token if present.
ParseResult parseOptionalGreater() override {
return success(parser.consumeIf(Token::greater));
}
/// Parse a `(` token.
ParseResult parseLParen() override {
return parser.parseToken(Token::l_paren, "expected '('");
}
/// Parses a '(' if present.
ParseResult parseOptionalLParen() override {
return success(parser.consumeIf(Token::l_paren));
}
/// Parse a `)` token.
ParseResult parseRParen() override {
return parser.parseToken(Token::r_paren, "expected ')'");
}
/// Parses a ')' if present.
ParseResult parseOptionalRParen() override {
return success(parser.consumeIf(Token::r_paren));
}
/// Parse a `[` token.
ParseResult parseLSquare() override {
return parser.parseToken(Token::l_square, "expected '['");
}
/// Parses a '[' if present.
ParseResult parseOptionalLSquare() override {
return success(parser.consumeIf(Token::l_square));
}
/// Parse a `]` token.
ParseResult parseRSquare() override {
return parser.parseToken(Token::r_square, "expected ']'");
}
/// Parses a ']' if present.
ParseResult parseOptionalRSquare() override {
return success(parser.consumeIf(Token::r_square));
}
/// Parses a '?' token.
ParseResult parseQuestion() override {
return parser.parseToken(Token::question, "expected '?'");
}
/// Parses a '?' if present.
ParseResult parseOptionalQuestion() override {
return success(parser.consumeIf(Token::question));
}
/// Parses a '*' token.
ParseResult parseStar() override {
return parser.parseToken(Token::star, "expected '*'");
}
/// Parses a '*' if present.
ParseResult parseOptionalStar() override {
return success(parser.consumeIf(Token::star));
}
/// Parses a '+' token.
ParseResult parsePlus() override {
return parser.parseToken(Token::plus, "expected '+'");
}
/// Parses a '+' token if present.
ParseResult parseOptionalPlus() override {
return success(parser.consumeIf(Token::plus));
}
/// Parses a quoted string token if present.
ParseResult parseOptionalString(std::string *string) override {
if (!parser.getToken().is(Token::string))
return failure();
if (string)
*string = parser.getToken().getStringValue();
parser.consumeToken();
return success();
}
/// Returns true if the current token corresponds to a keyword.
bool isCurrentTokenAKeyword() const {
return parser.getToken().isAny(Token::bare_identifier, Token::inttype) ||
parser.getToken().isKeyword();
}
/// Parse the given keyword if present.
ParseResult parseOptionalKeyword(StringRef keyword) override {
// Check that the current token has the same spelling.
if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword)
return failure();
parser.consumeToken();
return success();
}
/// Parse a keyword, if present, into 'keyword'.
ParseResult parseOptionalKeyword(StringRef *keyword) override {
// Check that the current token is a keyword.
if (!isCurrentTokenAKeyword())
return failure();
*keyword = parser.getTokenSpelling();
parser.consumeToken();
return success();
}
/// Parse a keyword if it is one of the 'allowedKeywords'.
ParseResult
parseOptionalKeyword(StringRef *keyword,
ArrayRef<StringRef> allowedKeywords) override {
// Check that the current token is a keyword.
if (!isCurrentTokenAKeyword())
return failure();
StringRef currentKeyword = parser.getTokenSpelling();
if (llvm::is_contained(allowedKeywords, currentKeyword)) {
*keyword = currentKeyword;
parser.consumeToken();
return success();
}
return failure();
}
/// Parse a floating point value from the stream.
ParseResult parseFloat(double &result) override {
bool isNegative = parser.consumeIf(Token::minus);
Token curTok = parser.getToken();
llvm::SMLoc loc = curTok.getLoc();
// Check for a floating point value.
if (curTok.is(Token::floatliteral)) {
auto val = curTok.getFloatingPointValue();
if (!val.hasValue())
return emitError(loc, "floating point value too large");
parser.consumeToken(Token::floatliteral);
result = isNegative ? -*val : *val;
return success();
}
// Check for a hexadecimal float value.
if (curTok.is(Token::integer)) {
Optional<APFloat> apResult;
if (failed(parser.parseFloatFromIntegerLiteral(
apResult, curTok, isNegative, APFloat::IEEEdouble(),
/*typeSizeInBits=*/64)))
return failure();
parser.consumeToken(Token::integer);
result = apResult->convertToDouble();
return success();
}
return emitError(loc, "expected floating point literal");
}
/// Parse an optional integer value from the stream.
OptionalParseResult parseOptionalInteger(APInt &result) override {
return parser.parseOptionalInteger(result);
}
/// Parse a list of comma-separated items with an optional delimiter. If a
/// delimiter is provided, then an empty list is allowed. If not, then at
/// least one element will be parsed.
ParseResult parseCommaSeparatedList(Delimiter delimiter,
function_ref<ParseResult()> parseElt,
StringRef contextMessage) override {
return parser.parseCommaSeparatedList(delimiter, parseElt, contextMessage);
}
//===--------------------------------------------------------------------===//
// Attribute Parsing
//===--------------------------------------------------------------------===//
/// Parse an arbitrary attribute and return it in result.
ParseResult parseAttribute(Attribute &result, Type type) override {
result = parser.parseAttribute(type);
return success(static_cast<bool>(result));
}
/// Parse an optional attribute.
template <typename AttrT>
OptionalParseResult
parseOptionalAttributeAndAddToList(AttrT &result, Type type,
StringRef attrName, NamedAttrList &attrs) {
OptionalParseResult parseResult =
parser.parseOptionalAttribute(result, type);
if (parseResult.hasValue() && succeeded(*parseResult))
attrs.push_back(parser.builder.getNamedAttr(attrName, result));
return parseResult;
}
OptionalParseResult parseOptionalAttribute(Attribute &result, Type type,
StringRef attrName,
NamedAttrList &attrs) override {
return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
}
OptionalParseResult parseOptionalAttribute(ArrayAttr &result, Type type,
StringRef attrName,
NamedAttrList &attrs) override {
return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
}
OptionalParseResult parseOptionalAttribute(StringAttr &result, Type type,
StringRef attrName,
NamedAttrList &attrs) override {
return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
}
/// Parse a named dictionary into 'result' if it is present.
ParseResult parseOptionalAttrDict(NamedAttrList &result) override {
if (parser.getToken().isNot(Token::l_brace))
return success();
return parser.parseAttributeDict(result);
}
/// Parse a named dictionary into 'result' if the `attributes` keyword is
/// present.
ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result) override {
if (failed(parseOptionalKeyword("attributes")))
return success();
return parser.parseAttributeDict(result);
}
/// Parse an affine map instance into 'map'.
ParseResult parseAffineMap(AffineMap &map) override {
return parser.parseAffineMapReference(map);
}
/// Parse an integer set instance into 'set'.
ParseResult printIntegerSet(IntegerSet &set) override {
return parser.parseIntegerSetReference(set);
}
//===--------------------------------------------------------------------===//
// Identifier Parsing
//===--------------------------------------------------------------------===//
/// Parse an optional @-identifier and store it (without the '@' symbol) in a
/// string attribute named 'attrName'.
ParseResult parseOptionalSymbolName(StringAttr &result, StringRef attrName,
NamedAttrList &attrs) override {
Token atToken = parser.getToken();
if (atToken.isNot(Token::at_identifier))
return failure();
result = getBuilder().getStringAttr(atToken.getSymbolReference());
attrs.push_back(getBuilder().getNamedAttr(attrName, result));
parser.consumeToken();
// If we are populating the assembly parser state, record this as a symbol
// reference.
if (parser.getState().asmState) {
parser.getState().asmState->addUses(SymbolRefAttr::get(result),
atToken.getLocRange());
}
return success();
}
/// Parse a loc(...) specifier if present, filling in result if so.
ParseResult
parseOptionalLocationSpecifier(Optional<Location> &result) override {
// If there is a 'loc' we parse a trailing location.
if (!parser.consumeIf(Token::kw_loc))
return success();
LocationAttr directLoc;
if (parser.parseToken(Token::l_paren, "expected '(' in location") ||
parser.parseLocationInstance(directLoc) ||
parser.parseToken(Token::r_paren, "expected ')' in location"))
return failure();
result = directLoc;
return success();
}
//===--------------------------------------------------------------------===//
// Type Parsing
//===--------------------------------------------------------------------===//
/// Parse a type.
ParseResult parseType(Type &result) override {
return failure(!(result = parser.parseType()));
}
/// Parse an optional type.
OptionalParseResult parseOptionalType(Type &result) override {
return parser.parseOptionalType(result);
}
/// Parse an arrow followed by a type list.
ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) override {
if (parseArrow() || parser.parseFunctionResultTypes(result))
return failure();
return success();
}
/// Parse an optional arrow followed by a type list.
ParseResult
parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) override {
if (!parser.consumeIf(Token::arrow))
return success();
return parser.parseFunctionResultTypes(result);
}
/// Parse a colon followed by a type.
ParseResult parseColonType(Type &result) override {
return failure(parser.parseToken(Token::colon, "expected ':'") ||
!(result = parser.parseType()));
}
/// Parse a colon followed by a type list, which must have at least one type.
ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) override {
if (parser.parseToken(Token::colon, "expected ':'"))
return failure();
return parser.parseTypeListNoParens(result);
}
/// Parse an optional colon followed by a type list, which if present must
/// have at least one type.
ParseResult
parseOptionalColonTypeList(SmallVectorImpl<Type> &result) override {
if (!parser.consumeIf(Token::colon))
return success();
return parser.parseTypeListNoParens(result);
}
ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
bool allowDynamic) override {
return parser.parseDimensionListRanked(dimensions, allowDynamic);
}
ParseResult parseXInDimensionList() override {
return parser.parseXInDimensionList();
}
protected:
/// The source location of the dialect symbol.
llvm::SMLoc nameLoc;
/// The main parser.
Parser &parser;
/// A flag that indicates if any errors were emitted during parsing.
bool emittedError = false;
};
} // namespace detail
} // end namespace mlir
#endif // MLIR_LIB_PARSER_ASMPARSERIMPL_H

View File

@ -11,7 +11,7 @@
//
//===----------------------------------------------------------------------===//
#include "Parser.h"
#include "AsmParserImpl.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
@ -27,304 +27,20 @@ namespace {
/// This class provides the main implementation of the DialectAsmParser that
/// allows for dialects to parse attributes and types. This allows for dialect
/// hooking into the main MLIR parsing logic.
class CustomDialectAsmParser : public DialectAsmParser {
class CustomDialectAsmParser : public AsmParserImpl<DialectAsmParser> {
public:
CustomDialectAsmParser(StringRef fullSpec, Parser &parser)
: fullSpec(fullSpec), nameLoc(parser.getToken().getLoc()),
parser(parser) {}
: AsmParserImpl<DialectAsmParser>(parser.getToken().getLoc(), parser),
fullSpec(fullSpec) {}
~CustomDialectAsmParser() override {}
/// Emit a diagnostic at the specified location and return failure.
InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override {
return parser.emitError(loc, message);
}
/// Return a builder which provides useful access to MLIRContext, global
/// objects like types and attributes.
Builder &getBuilder() const override { return parser.builder; }
/// Get the location of the next token and store it into the argument. This
/// always succeeds.
llvm::SMLoc getCurrentLocation() override {
return parser.getToken().getLoc();
}
/// Return the location of the original name token.
llvm::SMLoc getNameLoc() const override { return nameLoc; }
/// Re-encode the given source location as an MLIR location and return it.
Location getEncodedSourceLoc(llvm::SMLoc loc) override {
return parser.getEncodedSourceLocation(loc);
}
/// Returns the full specification of the symbol being parsed. This allows
/// for using a separate parser if necessary.
StringRef getFullSymbolSpec() const override { return fullSpec; }
/// Parse a floating point value from the stream.
ParseResult parseFloat(double &result) override {
bool isNegative = parser.consumeIf(Token::minus);
Token curTok = parser.getToken();
llvm::SMLoc loc = curTok.getLoc();
// Check for a floating point value.
if (curTok.is(Token::floatliteral)) {
auto val = curTok.getFloatingPointValue();
if (!val.hasValue())
return emitError(loc, "floating point value too large");
parser.consumeToken(Token::floatliteral);
result = isNegative ? -*val : *val;
return success();
}
// Check for a hexadecimal float value.
if (curTok.is(Token::integer)) {
Optional<APFloat> apResult;
if (failed(parser.parseFloatFromIntegerLiteral(
apResult, curTok, isNegative, APFloat::IEEEdouble(),
/*typeSizeInBits=*/64)))
return failure();
parser.consumeToken(Token::integer);
result = apResult->convertToDouble();
return success();
}
return emitError(loc, "expected floating point literal");
}
/// Parse an optional integer value from the stream.
OptionalParseResult parseOptionalInteger(APInt &result) override {
return parser.parseOptionalInteger(result);
}
//===--------------------------------------------------------------------===//
// Token Parsing
//===--------------------------------------------------------------------===//
/// Parse a `->` token.
ParseResult parseArrow() override {
return parser.parseToken(Token::arrow, "expected '->'");
}
/// Parses a `->` if present.
ParseResult parseOptionalArrow() override {
return success(parser.consumeIf(Token::arrow));
}
/// Parse a '{' token.
ParseResult parseLBrace() override {
return parser.parseToken(Token::l_brace, "expected '{'");
}
/// Parse a '{' token if present
ParseResult parseOptionalLBrace() override {
return success(parser.consumeIf(Token::l_brace));
}
/// Parse a `}` token.
ParseResult parseRBrace() override {
return parser.parseToken(Token::r_brace, "expected '}'");
}
/// Parse a `}` token if present
ParseResult parseOptionalRBrace() override {
return success(parser.consumeIf(Token::r_brace));
}
/// Parse a `:` token.
ParseResult parseColon() override {
return parser.parseToken(Token::colon, "expected ':'");
}
/// Parse a `:` token if present.
ParseResult parseOptionalColon() override {
return success(parser.consumeIf(Token::colon));
}
/// Parse a `,` token.
ParseResult parseComma() override {
return parser.parseToken(Token::comma, "expected ','");
}
/// Parse a `,` token if present.
ParseResult parseOptionalComma() override {
return success(parser.consumeIf(Token::comma));
}
/// Parses a `...` if present.
ParseResult parseOptionalEllipsis() override {
return success(parser.consumeIf(Token::ellipsis));
}
/// Parse a `=` token.
ParseResult parseEqual() override {
return parser.parseToken(Token::equal, "expected '='");
}
/// Parse a `=` token if present.
ParseResult parseOptionalEqual() override {
return success(parser.consumeIf(Token::equal));
}
/// Parse a '<' token.
ParseResult parseLess() override {
return parser.parseToken(Token::less, "expected '<'");
}
/// Parse a `<` token if present.
ParseResult parseOptionalLess() override {
return success(parser.consumeIf(Token::less));
}
/// Parse a '>' token.
ParseResult parseGreater() override {
return parser.parseToken(Token::greater, "expected '>'");
}
/// Parse a `>` token if present.
ParseResult parseOptionalGreater() override {
return success(parser.consumeIf(Token::greater));
}
/// Parse a `(` token.
ParseResult parseLParen() override {
return parser.parseToken(Token::l_paren, "expected '('");
}
/// Parses a '(' if present.
ParseResult parseOptionalLParen() override {
return success(parser.consumeIf(Token::l_paren));
}
/// Parse a `)` token.
ParseResult parseRParen() override {
return parser.parseToken(Token::r_paren, "expected ')'");
}
/// Parses a ')' if present.
ParseResult parseOptionalRParen() override {
return success(parser.consumeIf(Token::r_paren));
}
/// Parse a `[` token.
ParseResult parseLSquare() override {
return parser.parseToken(Token::l_square, "expected '['");
}
/// Parses a '[' if present.
ParseResult parseOptionalLSquare() override {
return success(parser.consumeIf(Token::l_square));
}
/// Parse a `]` token.
ParseResult parseRSquare() override {
return parser.parseToken(Token::r_square, "expected ']'");
}
/// Parses a ']' if present.
ParseResult parseOptionalRSquare() override {
return success(parser.consumeIf(Token::r_square));
}
/// Parses a '?' if present.
ParseResult parseOptionalQuestion() override {
return success(parser.consumeIf(Token::question));
}
/// Parses a '*' if present.
ParseResult parseOptionalStar() override {
return success(parser.consumeIf(Token::star));
}
/// Parses a quoted string token if present.
ParseResult parseOptionalString(std::string *string) override {
if (!parser.getToken().is(Token::string))
return failure();
if (string)
*string = parser.getToken().getStringValue();
parser.consumeToken();
return success();
}
/// Returns true if the current token corresponds to a keyword.
bool isCurrentTokenAKeyword() const {
return parser.getToken().isAny(Token::bare_identifier, Token::inttype) ||
parser.getToken().isKeyword();
}
/// Parse the given keyword if present.
ParseResult parseOptionalKeyword(StringRef keyword) override {
// Check that the current token has the same spelling.
if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword)
return failure();
parser.consumeToken();
return success();
}
/// Parse a keyword, if present, into 'keyword'.
ParseResult parseOptionalKeyword(StringRef *keyword) override {
// Check that the current token is a keyword.
if (!isCurrentTokenAKeyword())
return failure();
*keyword = parser.getTokenSpelling();
parser.consumeToken();
return success();
}
//===--------------------------------------------------------------------===//
// Attribute Parsing
//===--------------------------------------------------------------------===//
/// Parse an arbitrary attribute and return it in result.
ParseResult parseAttribute(Attribute &result, Type type) override {
result = parser.parseAttribute(type);
return success(static_cast<bool>(result));
}
/// Parse an affine map instance into 'map'.
ParseResult parseAffineMap(AffineMap &map) override {
return parser.parseAffineMapReference(map);
}
/// Parse an integer set instance into 'set'.
ParseResult printIntegerSet(IntegerSet &set) override {
return parser.parseIntegerSetReference(set);
}
//===--------------------------------------------------------------------===//
// Type Parsing
//===--------------------------------------------------------------------===//
ParseResult parseType(Type &result) override {
result = parser.parseType();
return success(static_cast<bool>(result));
}
ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
bool allowDynamic) override {
return parser.parseDimensionListRanked(dimensions, allowDynamic);
}
ParseResult parseXInDimensionList() override {
return parser.parseXInDimensionList();
}
OptionalParseResult parseOptionalType(Type &result) override {
return parser.parseOptionalType(result);
}
private:
/// The full symbol specification.
StringRef fullSpec;
/// The source location of the dialect symbol.
SMLoc nameLoc;
/// The main parser.
Parser &parser;
};
} // namespace

View File

@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "Parser.h"
#include "AsmParserImpl.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
@ -1093,15 +1094,15 @@ Operation *OperationParser::parseGenericOperation(Block *insertBlock,
}
namespace {
class CustomOpAsmParser : public OpAsmParser {
class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
public:
CustomOpAsmParser(
SMLoc nameLoc, ArrayRef<OperationParser::ResultRecord> resultIDs,
function_ref<ParseResult(OpAsmParser &, OperationState &)> parseAssembly,
bool isIsolatedFromAbove, StringRef opName, OperationParser &parser)
: nameLoc(nameLoc), resultIDs(resultIDs), parseAssembly(parseAssembly),
isIsolatedFromAbove(isIsolatedFromAbove), opName(opName),
parser(parser) {
: AsmParserImpl<OpAsmParser>(nameLoc, parser), resultIDs(resultIDs),
parseAssembly(parseAssembly), isIsolatedFromAbove(isIsolatedFromAbove),
opName(opName), parser(parser) {
(void)isIsolatedFromAbove; // Only used in assert, silence unused warning.
}
@ -1131,21 +1132,6 @@ public:
// Utilities
//===--------------------------------------------------------------------===//
/// Return if any errors were emitted during parsing.
bool didEmitError() const { return emittedError; }
/// Emit a diagnostic at the specified location and return failure.
InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override {
emittedError = true;
return parser.emitError(loc, "custom op '" + opName + "' " + message);
}
llvm::SMLoc getCurrentLocation() override {
return parser.getToken().getLoc();
}
Builder &getBuilder() const override { return parser.builder; }
/// Return the name of the specified result in the specified syntax, as well
/// as the subelement in the name. For example, in this operation:
///
@ -1181,331 +1167,10 @@ public:
return count;
}
llvm::SMLoc getNameLoc() const override { return nameLoc; }
/// Re-encode the given source location as an MLIR location and return it.
Location getEncodedSourceLoc(llvm::SMLoc loc) override {
return parser.getEncodedSourceLocation(loc);
}
//===--------------------------------------------------------------------===//
// Token Parsing
//===--------------------------------------------------------------------===//
/// Parse a `->` token.
ParseResult parseArrow() override {
return parser.parseToken(Token::arrow, "expected '->'");
}
/// Parses a `->` if present.
ParseResult parseOptionalArrow() override {
return success(parser.consumeIf(Token::arrow));
}
/// Parse a '{' token.
ParseResult parseLBrace() override {
return parser.parseToken(Token::l_brace, "expected '{'");
}
/// Parse a '{' token if present
ParseResult parseOptionalLBrace() override {
return success(parser.consumeIf(Token::l_brace));
}
/// Parse a `}` token.
ParseResult parseRBrace() override {
return parser.parseToken(Token::r_brace, "expected '}'");
}
/// Parse a `}` token if present
ParseResult parseOptionalRBrace() override {
return success(parser.consumeIf(Token::r_brace));
}
/// Parse a `:` token.
ParseResult parseColon() override {
return parser.parseToken(Token::colon, "expected ':'");
}
/// Parse a `:` token if present.
ParseResult parseOptionalColon() override {
return success(parser.consumeIf(Token::colon));
}
/// Parse a `,` token.
ParseResult parseComma() override {
return parser.parseToken(Token::comma, "expected ','");
}
/// Parse a `,` token if present.
ParseResult parseOptionalComma() override {
return success(parser.consumeIf(Token::comma));
}
/// Parses a `...` if present.
ParseResult parseOptionalEllipsis() override {
return success(parser.consumeIf(Token::ellipsis));
}
/// Parse a `=` token.
ParseResult parseEqual() override {
return parser.parseToken(Token::equal, "expected '='");
}
/// Parse a `=` token if present.
ParseResult parseOptionalEqual() override {
return success(parser.consumeIf(Token::equal));
}
/// Parse a '<' token.
ParseResult parseLess() override {
return parser.parseToken(Token::less, "expected '<'");
}
/// Parse a '<' token if present.
ParseResult parseOptionalLess() override {
return success(parser.consumeIf(Token::less));
}
/// Parse a '>' token.
ParseResult parseGreater() override {
return parser.parseToken(Token::greater, "expected '>'");
}
/// Parse a '>' token if present.
ParseResult parseOptionalGreater() override {
return success(parser.consumeIf(Token::greater));
}
/// Parse a `(` token.
ParseResult parseLParen() override {
return parser.parseToken(Token::l_paren, "expected '('");
}
/// Parses a '(' if present.
ParseResult parseOptionalLParen() override {
return success(parser.consumeIf(Token::l_paren));
}
/// Parse a `)` token.
ParseResult parseRParen() override {
return parser.parseToken(Token::r_paren, "expected ')'");
}
/// Parses a ')' if present.
ParseResult parseOptionalRParen() override {
return success(parser.consumeIf(Token::r_paren));
}
/// Parse a `[` token.
ParseResult parseLSquare() override {
return parser.parseToken(Token::l_square, "expected '['");
}
/// Parses a '[' if present.
ParseResult parseOptionalLSquare() override {
return success(parser.consumeIf(Token::l_square));
}
/// Parse a `]` token.
ParseResult parseRSquare() override {
return parser.parseToken(Token::r_square, "expected ']'");
}
/// Parses a ']' if present.
ParseResult parseOptionalRSquare() override {
return success(parser.consumeIf(Token::r_square));
}
/// Parses a '?' token.
ParseResult parseQuestion() override {
return parser.parseToken(Token::question, "expected '?'");
}
/// Parses a '?' token if present.
ParseResult parseOptionalQuestion() override {
return success(parser.consumeIf(Token::question));
}
/// Parses a '+' token.
ParseResult parsePlus() override {
return parser.parseToken(Token::plus, "expected '+'");
}
/// Parses a '+' token if present.
ParseResult parseOptionalPlus() override {
return success(parser.consumeIf(Token::plus));
}
/// Parses a '*' token.
ParseResult parseStar() override {
return parser.parseToken(Token::star, "expected '*'");
}
/// Parses a '*' token if present.
ParseResult parseOptionalStar() override {
return success(parser.consumeIf(Token::star));
}
/// Parse an optional integer value from the stream.
OptionalParseResult parseOptionalInteger(APInt &result) override {
return parser.parseOptionalInteger(result);
}
/// Parse a list of comma-separated items with an optional delimiter. If a
/// delimiter is provided, then an empty list is allowed. If not, then at
/// least one element will be parsed.
ParseResult parseCommaSeparatedList(Delimiter delimiter,
function_ref<ParseResult()> parseElt,
StringRef contextMessage) override {
return parser.parseCommaSeparatedList(delimiter, parseElt, contextMessage);
}
//===--------------------------------------------------------------------===//
// Attribute Parsing
//===--------------------------------------------------------------------===//
/// Parse an arbitrary attribute of a given type and return it in result.
ParseResult parseAttribute(Attribute &result, Type type) override {
result = parser.parseAttribute(type);
return success(static_cast<bool>(result));
}
/// Parse an optional attribute.
template <typename AttrT>
OptionalParseResult
parseOptionalAttributeAndAddToList(AttrT &result, Type type,
StringRef attrName, NamedAttrList &attrs) {
OptionalParseResult parseResult =
parser.parseOptionalAttribute(result, type);
if (parseResult.hasValue() && succeeded(*parseResult))
attrs.push_back(parser.builder.getNamedAttr(attrName, result));
return parseResult;
}
OptionalParseResult parseOptionalAttribute(Attribute &result, Type type,
StringRef attrName,
NamedAttrList &attrs) override {
return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
}
OptionalParseResult parseOptionalAttribute(ArrayAttr &result, Type type,
StringRef attrName,
NamedAttrList &attrs) override {
return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
}
OptionalParseResult parseOptionalAttribute(StringAttr &result, Type type,
StringRef attrName,
NamedAttrList &attrs) override {
return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
}
/// Parse a named dictionary into 'result' if it is present.
ParseResult parseOptionalAttrDict(NamedAttrList &result) override {
if (parser.getToken().isNot(Token::l_brace))
return success();
return parser.parseAttributeDict(result);
}
/// Parse a named dictionary into 'result' if the `attributes` keyword is
/// present.
ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result) override {
if (failed(parseOptionalKeyword("attributes")))
return success();
return parser.parseAttributeDict(result);
}
/// Parse an affine map instance into 'map'.
ParseResult parseAffineMap(AffineMap &map) override {
return parser.parseAffineMapReference(map);
}
/// Parse an integer set instance into 'set'.
ParseResult printIntegerSet(IntegerSet &set) override {
return parser.parseIntegerSetReference(set);
}
//===--------------------------------------------------------------------===//
// Identifier Parsing
//===--------------------------------------------------------------------===//
/// Returns true if the current token corresponds to a keyword.
bool isCurrentTokenAKeyword() const {
return parser.getToken().is(Token::bare_identifier) ||
parser.getToken().isKeyword();
}
/// Parse the given keyword if present.
ParseResult parseOptionalKeyword(StringRef keyword) override {
// Check that the current token has the same spelling.
if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword)
return failure();
parser.consumeToken();
return success();
}
/// Parse a keyword, if present, into 'keyword'.
ParseResult parseOptionalKeyword(StringRef *keyword) override {
// Check that the current token is a keyword.
if (!isCurrentTokenAKeyword())
return failure();
*keyword = parser.getTokenSpelling();
parser.consumeToken();
return success();
}
/// Parse a keyword if it is one of the 'allowedKeywords'.
ParseResult
parseOptionalKeyword(StringRef *keyword,
ArrayRef<StringRef> allowedKeywords) override {
// Check that the current token is a keyword.
if (!isCurrentTokenAKeyword())
return failure();
StringRef currentKeyword = parser.getTokenSpelling();
if (llvm::is_contained(allowedKeywords, currentKeyword)) {
*keyword = currentKeyword;
parser.consumeToken();
return success();
}
return failure();
}
/// Parse an optional @-identifier and store it (without the '@' symbol) in a
/// string attribute named 'attrName'.
ParseResult parseOptionalSymbolName(StringAttr &result, StringRef attrName,
NamedAttrList &attrs) override {
Token atToken = parser.getToken();
if (atToken.isNot(Token::at_identifier))
return failure();
result = getBuilder().getStringAttr(atToken.getSymbolReference());
attrs.push_back(getBuilder().getNamedAttr(attrName, result));
parser.consumeToken();
// If we are populating the assembly parser state, record this as a symbol
// reference.
if (parser.getState().asmState) {
parser.getState().asmState->addUses(SymbolRefAttr::get(result),
atToken.getLocRange());
}
return success();
}
/// Parse a loc(...) specifier if present, filling in result if so.
ParseResult
parseOptionalLocationSpecifier(Optional<Location> &result) override {
// If there is a 'loc' we parse a trailing location.
if (!parser.consumeIf(Token::kw_loc))
return success();
LocationAttr directLoc;
if (parser.parseToken(Token::l_paren, "expected '(' in location") ||
parser.parseLocationInstance(directLoc) ||
parser.parseToken(Token::r_paren, "expected ')' in location"))
return failure();
result = directLoc;
return success();
/// Emit a diagnostic at the specified location and return failure.
InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override {
return AsmParserImpl<OpAsmParser>::emitError(loc, "custom op '" + opName +
"' " + message);
}
//===--------------------------------------------------------------------===//
@ -1779,53 +1444,6 @@ public:
// Type Parsing
//===--------------------------------------------------------------------===//
/// Parse a type.
ParseResult parseType(Type &result) override {
return failure(!(result = parser.parseType()));
}
/// Parse an optional type.
OptionalParseResult parseOptionalType(Type &result) override {
return parser.parseOptionalType(result);
}
/// Parse an arrow followed by a type list.
ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) override {
if (parseArrow() || parser.parseFunctionResultTypes(result))
return failure();
return success();
}
/// Parse an optional arrow followed by a type list.
ParseResult
parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) override {
if (!parser.consumeIf(Token::arrow))
return success();
return parser.parseFunctionResultTypes(result);
}
/// Parse a colon followed by a type.
ParseResult parseColonType(Type &result) override {
return failure(parser.parseToken(Token::colon, "expected ':'") ||
!(result = parser.parseType()));
}
/// Parse a colon followed by a type list, which must have at least one type.
ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) override {
if (parser.parseToken(Token::colon, "expected ':'"))
return failure();
return parser.parseTypeListNoParens(result);
}
/// Parse an optional colon followed by a type list, which if present must
/// have at least one type.
ParseResult
parseOptionalColonTypeList(SmallVectorImpl<Type> &result) override {
if (!parser.consumeIf(Token::colon))
return success();
return parser.parseTypeListNoParens(result);
}
/// Parse a list of assignments of the form
/// (%x1 = %y1, %x2 = %y2, ...).
OptionalParseResult
@ -1870,9 +1488,6 @@ public:
}
private:
/// The source location of the operation name.
SMLoc nameLoc;
/// Information about the result name specifiers.
ArrayRef<OperationParser::ResultRecord> resultIDs;
@ -1881,11 +1496,8 @@ private:
bool isIsolatedFromAbove;
StringRef opName;
/// The main operation parser.
/// The backing operation parser.
OperationParser &parser;
/// A flag that indicates if any errors were emitted during parsing.
bool emittedError = false;
};
} // end anonymous namespace.