[mlir] Add extensible dialects

Add support for extensible dialects, which are dialects that can be
extended at runtime with new operations and types.

These operations and types cannot at the moment implement traits
or interfaces.

Differential Revision: https://reviews.llvm.org/D104554
This commit is contained in:
Mathieu Fehr 2022-03-02 12:17:06 -08:00 committed by River Riddle
parent 507f7317a0
commit dbe9f0914f
18 changed files with 1941 additions and 40 deletions

View File

@ -0,0 +1,369 @@
# Extensible dialects
This file documents the design and API of the extensible dialects. Extensible
dialects are dialects that can be extended with new operations and types defined
at runtime. This allows for users to define dialects via with meta-programming,
or from another language, without having to recompile C++ code.
[TOC]
## Usage
### Defining an extensible dialect
Dialects defined in C++ can be extended with new operations, types, etc., at
runtime by inheriting from `mlir::ExtensibleDialect` instead of `mlir::Dialect`
(note that `ExtensibleDialect` inherits from `Dialect`). The `ExtensibleDialect`
class contains the necessary fields and methods to extend the dialect at
runtime.
```c++
class MyDialect : public mlir::ExtensibleDialect {
...
}
```
For dialects defined in TableGen, this is done by setting the `isExtensible`
flag to `1`.
```tablegen
def Test_Dialect : Dialect {
let isExtensible = 1;
...
}
```
An extensible `Dialect` can be casted back to `ExtensibleDialect` using
`llvm::dyn_cast`, or `llvm::cast`:
```c++
if (auto extensibleDialect = llvm::dyn_cast<ExtensibleDialect>(dialect)) {
...
}
```
### Defining an operation at runtime
The `DynamicOpDefinition` class represents the definition of an operation
defined at runtime. It is created using the `DynamicOpDefinition::get`
functions. An operation defined at runtime must provide a name, a dialect in
which the operation will be registered in, an operation verifier. It may also
optionally define a custom parser and a printer, fold hook, and more.
```c++
// The operation name, without the dialect name prefix.
StringRef name = "my_operation_name";
// The dialect defining the operation.
Dialect* dialect = ctx->getOrLoadDialect<MyDialect>();
// Operation verifier definition.
AbstractOperation::VerifyInvariantsFn verifyFn = [](Operation* op) {
// Logic for the operation verification.
...
}
// Parser function definition.
AbstractOperation::ParseAssemblyFn parseFn =
[](OpAsmParser &parser, OperationState &state) {
// Parse the operation, given that the name is already parsed.
...
};
// Printer function
auto printFn = [](Operation *op, OpAsmPrinter &printer) {
printer << op->getName();
// Print the operation, given that the name is already printed.
...
};
// General folder implementation, see AbstractOperation::foldHook for more
// information.
auto foldHookFn = [](Operation * op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &result) {
...
};
// Returns any canonicalization pattern rewrites that the operation
// supports, for use by the canonicalization pass.
auto getCanonicalizationPatterns =
[](RewritePatternSet &results, MLIRContext *context) {
...
}
// Definition of the operation.
std::unique_ptr<DynamicOpDefinition> opDef =
DynamicOpDefinition::get(name, dialect, std::move(verifyFn),
std::move(parseFn), std::move(printFn), std::move(foldHookFn),
std::move(getCanonicalizationPatterns));
```
Once the operation is defined, it can be registered by an `ExtensibleDialect`:
```c++
extensibleDialect->registerDynamicOperation(std::move(opDef));
```
Note that the `Dialect` given to the operation should be the one registering
the operation.
### Using an operation defined at runtime
It is possible to match on an operation defined at runtime using their names:
```c++
if (op->getName().getStringRef() == "my_dialect.my_dynamic_op") {
...
}
```
An operation defined at runtime can be created by instantiating an
`OperationState` with the operation name, and using it with a rewriter
(for instance a `PatternRewriter`) to create the operation.
```c++
OperationState state(location, "my_dialect.my_dynamic_op",
operands, resultTypes, attributes);
rewriter.createOperation(state);
```
### Defining a type at runtime
Contrary to types defined in C++ or in TableGen, types defined at runtime can
only have as argument a list of `Attribute`.
Similarily to operations, a type is defined at runtime using the class
`DynamicTypeDefinition`, which is created using the `DynamicTypeDefinition::get`
functions. A type definition requires a name, the dialect that will register the
type, and a parameter verifier. It can also define optionally a custom parser
and printer for the arguments (the type name is assumed to be already
parsed/printed).
```c++
// The type name, without the dialect name prefix.
StringRef name = "my_type_name";
// The dialect defining the type.
Dialect* dialect = ctx->getOrLoadDialect<MyDialect>();
// The type verifier.
// A type defined at runtime has a list of attributes as parameters.
auto verifier = [](function_ref<InFlightDiagnostic()> emitError,
ArrayRef<Attribute> args) {
...
};
// The type parameters parser.
auto parser = [](DialectAsmParser &parser,
llvm::SmallVectorImpl<Attribute> &parsedParams) {
...
};
// The type parameters printer.
auto printer =[](DialectAsmPrinter &printer, ArrayRef<Attribute> params) {
...
};
std::unique_ptr<DynamicTypeDefinition> typeDef =
DynamicTypeDefinition::get(std::move(name), std::move(dialect),
std::move(verifier), std::move(printer),
std::move(parser));
```
If the printer and the parser are ommited, a default parser and printer is
generated with the format `!dialect.typename<arg1, arg2, ..., argN>`.
The type can then be registered by the `ExtensibleDialect`:
```c++
dialect->registerDynamicType(std::move(typeDef));
```
### Parsing types defined at runtime in an extensible dialect
`parseType` methods generated by TableGen can parse types defined at runtime,
though overriden `parseType` methods need to add the necessary support for them.
```c++
Type MyDialect::parseType(DialectAsmParser &parser) const {
...
// The type name.
StringRef typeTag;
if (failed(parser.parseKeyword(&typeTag)))
return Type();
// Try to parse a dynamic type with 'typeTag' name.
Type dynType;
auto parseResult = parseOptionalDynamicType(typeTag, parser, dynType);
if (parseResult.hasValue()) {
if (succeeded(parseResult.getValue()))
return dynType;
return Type();
}
...
}
```
### Using a type defined at runtime
Dynamic types are instances of `DynamicType`. It is possible to get a dynamic
type with `DynamicType::get` and `ExtensibleDialect::lookupTypeDefinition`.
```c++
auto typeDef = extensibleDialect->lookupTypeDefinition("my_dynamic_type");
ArrayRef<Attribute> params = ...;
auto type = DynamicType::get(typeDef, params);
```
It is also possible to cast a `Type` known to be defined at runtime to a
`DynamicType`.
```c++
auto dynType = type.cast<DynamicType>();
auto typeDef = dynType.getTypeDef();
auto args = dynType.getParams();
```
### Defining an attribute at runtime
Similar to types defined at runtime, attributes defined at runtime can only have
as argument a list of `Attribute`.
Similarily to types, an attribute is defined at runtime using the class
`DynamicAttrDefinition`, which is created using the `DynamicAttrDefinition::get`
functions. An attribute definition requires a name, the dialect that will
register the attribute, and a parameter verifier. It can also define optionally
a custom parser and printer for the arguments (the attribute name is assumed to
be already parsed/printed).
```c++
// The attribute name, without the dialect name prefix.
StringRef name = "my_attribute_name";
// The dialect defining the attribute.
Dialect* dialect = ctx->getOrLoadDialect<MyDialect>();
// The attribute verifier.
// An attribute defined at runtime has a list of attributes as parameters.
auto verifier = [](function_ref<InFlightDiagnostic()> emitError,
ArrayRef<Attribute> args) {
...
};
// The attribute parameters parser.
auto parser = [](DialectAsmParser &parser,
llvm::SmallVectorImpl<Attribute> &parsedParams) {
...
};
// The attribute parameters printer.
auto printer =[](DialectAsmPrinter &printer, ArrayRef<Attribute> params) {
...
};
std::unique_ptr<DynamicAttrDefinition> attrDef =
DynamicAttrDefinition::get(std::move(name), std::move(dialect),
std::move(verifier), std::move(printer),
std::move(parser));
```
If the printer and the parser are ommited, a default parser and printer is
generated with the format `!dialect.attrname<arg1, arg2, ..., argN>`.
The attribute can then be registered by the `ExtensibleDialect`:
```c++
dialect->registerDynamicAttr(std::move(typeDef));
```
### Parsing attributes defined at runtime in an extensible dialect
`parseAttribute` methods generated by TableGen can parse attributes defined at
runtime, though overriden `parseAttribute` methods need to add the necessary
support for them.
```c++
Attribute MyDialect::parseAttribute(DialectAsmParser &parser,
Type type) const override {
...
// The attribute name.
StringRef attrTag;
if (failed(parser.parseKeyword(&attrTag)))
return Attribute();
// Try to parse a dynamic attribute with 'attrTag' name.
Attribute dynAttr;
auto parseResult = parseOptionalDynamicAttr(attrTag, parser, dynAttr);
if (parseResult.hasValue()) {
if (succeeded(parseResult.getValue()))
return dynAttr;
return Attribute();
}
```
### Using an attribute defined at runtime
Similar to types, attributes defined at runtime are instances of `DynamicAttr`.
It is possible to get a dynamic attribute with `DynamicAttr::get` and
`ExtensibleDialect::lookupAttrDefinition`.
```c++
auto attrDef = extensibleDialect->lookupAttrDefinition("my_dynamic_attr");
ArrayRef<Attribute> params = ...;
auto attr = DynamicAttr::get(attrDef, params);
```
It is also possible to cast an `Attribute` known to be defined at runtime to a
`DynamicAttr`.
```c++
auto dynAttr = attr.cast<DynamicAttr>();
auto attrDef = dynAttr.getAttrDef();
auto args = dynAttr.getParams();
```
## Implementation details
### Extensible dialect
The role of extensible dialects is to own the necessary data for defined
operations and types. They also contain the necessary accessors to easily
access them.
In order to cast a `Dialect` back to an `ExtensibleDialect`, we implement the
`IsExtensibleDialect` interface to all `ExtensibleDialect`. The casting is done
by checking if the `Dialect` implements `IsExtensibleDialect` or not.
### Operation representation and registration
Operations are represented in mlir using the `AbstractOperation` class. They are
registered in dialects the same way operations defined in C++ are registered,
which is by calling `AbstractOperation::insert`.
The only difference is that a new `TypeID` needs to be created for each
operation, since operations are not represented by a C++ class. This is done
using a `TypeIDAllocator`, which can allocate a new unique `TypeID` at runtime.
### Type representation and registration
Unlike operations, types need to define a C++ storage class that takes care of
type parameters. They also need to define another C++ class to access that
storage. `DynamicTypeStorage` defines the storage of types defined at runtime,
and `DynamicType` gives access to the storage, as well as defining useful
functions. A `DynamicTypeStorage` contains a list of `Attribute` type
parameters, as well as a pointer to the type definition.
Types are registered using the `Dialect::addType` method, which expect a
`TypeID` that is generated using a `TypeIDAllocator`. The type uniquer also
register the type with the given `TypeID`. This mean that we can reuse our
single `DynamicType` with different `TypeID` to represent the different types
defined at runtime.
Since the different types defined at runtime have different `TypeID`, it is not
possible to use `TypeID` to cast a `Type` into a `DynamicType`. Thus, similar to
`Dialect`, all `DynamicType` define a `IsDynamicTypeTrait`, so casting a `Type`
to a `DynamicType` boils down to querying the `IsDynamicTypeTrait` trait.

View File

@ -45,6 +45,17 @@ public:
T::getTypeID());
}
/// This method is used by Dialect objects to register attributes with
/// custom TypeIDs.
/// The use of this method is in general discouraged in favor of
/// 'get<CustomAttribute>(dialect)'.
static AbstractAttribute get(Dialect &dialect,
detail::InterfaceMap &&interfaceMap,
HasTraitFn &&hasTrait, TypeID typeID) {
return AbstractAttribute(dialect, std::move(interfaceMap),
std::move(hasTrait), typeID);
}
/// Return the dialect this attribute was registered to.
Dialect &getDialect() const { return const_cast<Dialect &>(dialect); }
@ -175,14 +186,22 @@ namespace detail {
// MLIRContext. This class manages all creation and uniquing of attributes.
class AttributeUniquer {
public:
/// Get an uniqued instance of an attribute T.
template <typename T, typename... Args>
static T get(MLIRContext *ctx, Args &&... args) {
return getWithTypeID<T, Args...>(ctx, T::getTypeID(),
std::forward<Args>(args)...);
}
/// Get an uniqued instance of a parametric attribute T.
/// The use of this method is in general discouraged in favor of
/// 'get<T, Args>(ctx, args)'.
template <typename T, typename... Args>
static typename std::enable_if_t<
!std::is_same<typename T::ImplType, AttributeStorage>::value, T>
get(MLIRContext *ctx, Args &&...args) {
getWithTypeID(MLIRContext *ctx, TypeID typeID, Args &&... args) {
#ifndef NDEBUG
if (!ctx->getAttributeUniquer().isParametricStorageInitialized(
T::getTypeID()))
if (!ctx->getAttributeUniquer().isParametricStorageInitialized(typeID))
llvm::report_fatal_error(
llvm::Twine("can't create Attribute '") + llvm::getTypeName<T>() +
"' because storage uniquer isn't initialized: the dialect was likely "
@ -190,57 +209,68 @@ public:
"in the Dialect::initialize() method.");
#endif
return ctx->getAttributeUniquer().get<typename T::ImplType>(
[ctx](AttributeStorage *storage) {
initializeAttributeStorage(storage, ctx, T::getTypeID());
[typeID, ctx](AttributeStorage *storage) {
initializeAttributeStorage(storage, ctx, typeID);
// Execute any additional attribute storage initialization with the
// context.
static_cast<typename T::ImplType *>(storage)->initialize(ctx);
},
T::getTypeID(), std::forward<Args>(args)...);
typeID, std::forward<Args>(args)...);
}
/// Get an uniqued instance of a singleton attribute T.
/// The use of this method is in general discouraged in favor of
/// 'get<T, Args>(ctx, args)'.
template <typename T>
static typename std::enable_if_t<
std::is_same<typename T::ImplType, AttributeStorage>::value, T>
get(MLIRContext *ctx) {
getWithTypeID(MLIRContext *ctx, TypeID typeID) {
#ifndef NDEBUG
if (!ctx->getAttributeUniquer().isSingletonStorageInitialized(
T::getTypeID()))
if (!ctx->getAttributeUniquer().isSingletonStorageInitialized(typeID))
llvm::report_fatal_error(
llvm::Twine("can't create Attribute '") + llvm::getTypeName<T>() +
"' because storage uniquer isn't initialized: the dialect was likely "
"not loaded, or the attribute wasn't added with addAttributes<...>() "
"in the Dialect::initialize() method.");
#endif
return ctx->getAttributeUniquer().get<typename T::ImplType>(T::getTypeID());
return ctx->getAttributeUniquer().get<typename T::ImplType>(typeID);
}
template <typename T, typename... Args>
static LogicalResult mutate(MLIRContext *ctx, typename T::ImplType *impl,
Args &&...args) {
Args &&... args) {
assert(impl && "cannot mutate null attribute");
return ctx->getAttributeUniquer().mutate(T::getTypeID(), impl,
std::forward<Args>(args)...);
}
/// Register an attribute instance T with the uniquer.
template <typename T>
static void registerAttribute(MLIRContext *ctx) {
registerAttribute<T>(ctx, T::getTypeID());
}
/// Register a parametric attribute instance T with the uniquer.
/// The use of this method is in general discouraged in favor of
/// 'registerAttribute<T>(ctx)'.
template <typename T>
static typename std::enable_if_t<
!std::is_same<typename T::ImplType, AttributeStorage>::value>
registerAttribute(MLIRContext *ctx) {
registerAttribute(MLIRContext *ctx, TypeID typeID) {
ctx->getAttributeUniquer()
.registerParametricStorageType<typename T::ImplType>(T::getTypeID());
.registerParametricStorageType<typename T::ImplType>(typeID);
}
/// Register a singleton attribute instance T with the uniquer.
/// The use of this method is in general discouraged in favor of
/// 'registerAttribute<T>(ctx)'.
template <typename T>
static typename std::enable_if_t<
std::is_same<typename T::ImplType, AttributeStorage>::value>
registerAttribute(MLIRContext *ctx) {
registerAttribute(MLIRContext *ctx, TypeID typeID) {
ctx->getAttributeUniquer()
.registerSingletonStorageType<typename T::ImplType>(
T::getTypeID(), [ctx](AttributeStorage *storage) {
initializeAttributeStorage(storage, ctx, T::getTypeID());
typeID, [ctx, typeID](AttributeStorage *storage) {
initializeAttributeStorage(storage, ctx, typeID);
});
}

View File

@ -212,6 +212,11 @@ protected:
(void)std::initializer_list<int>{0, (addAttribute<Args>(), 0)...};
}
/// Register an attribute instance with this dialect.
/// The use of this method is in general discouraged in favor of
/// 'addAttributes<CustomAttr>()'.
void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo);
/// Enable support for unregistered operations.
void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; }
@ -237,7 +242,6 @@ private:
addAttribute(T::getTypeID(), AbstractAttribute::get<T>(*this));
detail::AttributeUniquer::registerAttribute<T>(context);
}
void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo);
/// Register a type instance with this dialect.
template <typename T> void addType() {

View File

@ -0,0 +1,542 @@
//===- ExtensibleDialect.h - Extensible dialect -----------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the DynamicOpDefinition class, the DynamicTypeDefinition
// class, and the DynamicAttrDefinition class, which represent respectively
// operations, types, and attributes that can be defined at runtime. They can
// be registered at runtime to an extensible dialect, using the
// ExtensibleDialect class defined in this file.
//
// For a more complete documentation, see
// https://mlir.llvm.org/docs/ExtensibleDialects/ .
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_EXTENSIBLEDIALECT_H
#define MLIR_IR_EXTENSIBLEDIALECT_H
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/StringMap.h"
namespace mlir {
class AsmParser;
class AsmPrinter;
class DynamicAttr;
class DynamicType;
class ExtensibleDialect;
class MLIRContext;
class OptionalParseResult;
class ParseResult;
namespace detail {
struct DynamicAttrStorage;
struct DynamicTypeStorage;
} // namespace detail
//===----------------------------------------------------------------------===//
// Dynamic attribute
//===----------------------------------------------------------------------===//
/// The definition of a dynamic attribute. A dynamic attribute is an attribute
/// that is defined at runtime, and that can be registered at runtime by an
/// extensible dialect (a dialect inheriting ExtensibleDialect). This class
/// stores the parser, the printer, and the verifier of the attribute. Each
/// dynamic attribute definition refers to one instance of this class.
class DynamicAttrDefinition : SelfOwningTypeID {
public:
using VerifierFn = llvm::unique_function<LogicalResult(
function_ref<InFlightDiagnostic()>, ArrayRef<Attribute>) const>;
using ParserFn = llvm::unique_function<ParseResult(
AsmParser &parser, llvm::SmallVectorImpl<Attribute> &parsedAttributes)
const>;
using PrinterFn = llvm::unique_function<void(
AsmPrinter &printer, ArrayRef<Attribute> params) const>;
/// Create a new attribute definition at runtime. The attribute is registered
/// only after passing it to the dialect using registerDynamicAttr.
static std::unique_ptr<DynamicAttrDefinition>
get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier);
static std::unique_ptr<DynamicAttrDefinition>
get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier,
ParserFn &&parser, PrinterFn &&printer);
/// Check that the attribute parameters are valid.
LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<Attribute> params) const {
return verifier(emitError, params);
}
/// Return the MLIRContext in which the dynamic attributes are uniqued.
MLIRContext &getContext() const { return *ctx; }
/// Return the name of the attribute, in the format 'attrname' and
/// not 'dialectname.attrname'.
StringRef getName() const { return name; }
/// Return the dialect defining the attribute.
ExtensibleDialect *getDialect() const { return dialect; }
private:
DynamicAttrDefinition(StringRef name, ExtensibleDialect *dialect,
VerifierFn &&verifier, ParserFn &&parser,
PrinterFn &&printer);
/// This constructor should only be used when we need a pointer to
/// the DynamicAttrDefinition in the verifier, the parser, or the printer.
/// The verifier, parser, and printer need thus to be initialized after the
/// constructor.
DynamicAttrDefinition(ExtensibleDialect *dialect, StringRef name);
/// Register the concrete attribute in the attribute Uniquer.
void registerInAttrUniquer();
/// The name should be prefixed with the dialect name followed by '.'.
std::string name;
/// Dialect in which this attribute is defined.
ExtensibleDialect *dialect;
/// The attribute verifier. It checks that the attribute parameters satisfy
/// the invariants.
VerifierFn verifier;
/// The attribute parameters parser. It parses only the parameters, and
/// expects the attribute name to have already been parsed.
ParserFn parser;
/// The attribute parameters printer. It prints only the parameters, and
/// expects the attribute name to have already been printed.
PrinterFn printer;
/// Context in which the concrete attributes are uniqued.
MLIRContext *ctx;
friend ExtensibleDialect;
friend DynamicAttr;
};
/// This trait is used to determine if an attribute is a dynamic attribute or
/// not; it should only be implemented by dynamic attributes.
/// Note: This is only required because dynamic attributes do not have a
/// static/single TypeID.
template <typename ConcreteType>
class IsDynamicAttrTrait
: public AttributeTrait::TraitBase<ConcreteType, IsDynamicAttrTrait> {};
/// A dynamic attribute instance. This is an attribute whose definition is
/// defined at runtime.
/// It is possible to check if an attribute is a dynamic attribute using
/// `my_attr.isa<DynamicAttr>()`, and getting the attribute definition of a
/// dynamic attribute using the `DynamicAttr::getAttrDef` method.
/// All dynamic attributes have the same storage, which is an array of
/// attributes.
class DynamicAttr : public Attribute::AttrBase<DynamicAttr, Attribute,
detail::DynamicAttrStorage,
IsDynamicAttrTrait> {
public:
// Inherit Base constructors.
using Base::Base;
/// Return an instance of a dynamic attribute given a dynamic attribute
/// definition and attribute parameters.
/// This asserts that the attribute verifier succeeded.
static DynamicAttr get(DynamicAttrDefinition *attrDef,
ArrayRef<Attribute> params = {});
/// Return an instance of a dynamic attribute given a dynamic attribute
/// definition and attribute parameters. If the parameters provided are
/// invalid, errors are emitted using the provided location and a null object
/// is returned.
static DynamicAttr getChecked(function_ref<InFlightDiagnostic()> emitError,
DynamicAttrDefinition *attrDef,
ArrayRef<Attribute> params = {});
/// Return the attribute definition of the concrete attribute.
DynamicAttrDefinition *getAttrDef();
/// Return the attribute parameters.
ArrayRef<Attribute> getParams();
/// Check if an attribute is a specific dynamic attribute.
static bool isa(Attribute attr, DynamicAttrDefinition *attrDef) {
return attr.getTypeID() == attrDef->getTypeID();
}
/// Check if an attribute is a dynamic attribute.
static bool classof(Attribute attr);
/// Parse the dynamic attribute parameters and construct the attribute.
/// The parameters are either empty, and nothing is parsed,
/// or they are in the format '<>' or '<attr (,attr)*>'.
static ParseResult parse(AsmParser &parser, DynamicAttrDefinition *attrDef,
DynamicAttr &parsedAttr);
/// Print the dynamic attribute with the format 'attrname' if there is no
/// parameters, or 'attrname<attr (,attr)*>'.
void print(AsmPrinter &printer);
};
//===----------------------------------------------------------------------===//
// Dynamic type
//===----------------------------------------------------------------------===//
/// The definition of a dynamic type. A dynamic type is a type that is
/// defined at runtime, and that can be registered at runtime by an
/// extensible dialect (a dialect inheriting ExtensibleDialect). This class
/// stores the parser, the printer, and the verifier of the type. Each dynamic
/// type definition refers to one instance of this class.
class DynamicTypeDefinition : SelfOwningTypeID {
public:
using VerifierFn = llvm::unique_function<LogicalResult(
function_ref<InFlightDiagnostic()>, ArrayRef<Attribute>) const>;
using ParserFn = llvm::unique_function<ParseResult(
AsmParser &parser, llvm::SmallVectorImpl<Attribute> &parsedAttributes)
const>;
using PrinterFn = llvm::unique_function<void(
AsmPrinter &printer, ArrayRef<Attribute> params) const>;
/// Create a new dynamic type definition. The type is registered only after
/// passing it to the dialect using registerDynamicType.
static std::unique_ptr<DynamicTypeDefinition>
get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier);
static std::unique_ptr<DynamicTypeDefinition>
get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier,
ParserFn &&parser, PrinterFn &&printer);
/// Check that the type parameters are valid.
LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<Attribute> params) const {
return verifier(emitError, params);
}
/// Return the MLIRContext in which the dynamic types is uniqued.
MLIRContext &getContext() const { return *ctx; }
/// Return the name of the type, in the format 'typename' and
/// not 'dialectname.typename'.
StringRef getName() const { return name; }
/// Return the dialect defining the type.
ExtensibleDialect *getDialect() const { return dialect; }
private:
DynamicTypeDefinition(StringRef name, ExtensibleDialect *dialect,
VerifierFn &&verifier, ParserFn &&parser,
PrinterFn &&printer);
/// This constructor should only be used when we need a pointer to
/// the DynamicTypeDefinition in the verifier, the parser, or the printer.
/// The verifier, parser, and printer need thus to be initialized after the
/// constructor.
DynamicTypeDefinition(ExtensibleDialect *dialect, StringRef name);
/// Register the concrete type in the type Uniquer.
void registerInTypeUniquer();
/// The name should be prefixed with the dialect name followed by '.'.
std::string name;
/// Dialect in which this type is defined.
ExtensibleDialect *dialect;
/// The type verifier. It checks that the type parameters satisfy the
/// invariants.
VerifierFn verifier;
/// The type parameters parser. It parses only the parameters, and expects the
/// type name to have already been parsed.
ParserFn parser;
/// The type parameters printer. It prints only the parameters, and expects
/// the type name to have already been printed.
PrinterFn printer;
/// Context in which the concrete types are uniqued.
MLIRContext *ctx;
friend ExtensibleDialect;
friend DynamicType;
};
/// This trait is used to determine if a type is a dynamic type or not;
/// it should only be implemented by dynamic types.
/// Note: This is only required because dynamic type do not have a
/// static/single TypeID.
template <typename ConcreteType>
class IsDynamicTypeTrait
: public TypeTrait::TraitBase<ConcreteType, IsDynamicTypeTrait> {};
/// A dynamic type instance. This is a type whose definition is defined at
/// runtime.
/// It is possible to check if a type is a dynamic type using
/// `my_type.isa<DynamicType>()`, and getting the type definition of a dynamic
/// type using the `DynamicType::getTypeDef` method.
/// All dynamic types have the same storage, which is an array of attributes.
class DynamicType
: public Type::TypeBase<DynamicType, Type, detail::DynamicTypeStorage,
IsDynamicTypeTrait> {
public:
// Inherit Base constructors.
using Base::Base;
/// Return an instance of a dynamic type given a dynamic type definition and
/// type parameters.
/// This asserts that the type verifier succeeded.
static DynamicType get(DynamicTypeDefinition *typeDef,
ArrayRef<Attribute> params = {});
/// Return an instance of a dynamic type given a dynamic type definition and
/// type parameters. If the parameters provided are invalid, errors are
/// emitted using the provided location and a null object is returned.
static DynamicType getChecked(function_ref<InFlightDiagnostic()> emitError,
DynamicTypeDefinition *typeDef,
ArrayRef<Attribute> params = {});
/// Return the type definition of the concrete type.
DynamicTypeDefinition *getTypeDef();
/// Return the type parameters.
ArrayRef<Attribute> getParams();
/// Check if a type is a specific dynamic type.
static bool isa(Type type, DynamicTypeDefinition *typeDef) {
return type.getTypeID() == typeDef->getTypeID();
}
/// Check if a type is a dynamic type.
static bool classof(Type type);
/// Parse the dynamic type parameters and construct the type.
/// The parameters are either empty, and nothing is parsed,
/// or they are in the format '<>' or '<attr (,attr)*>'.
static ParseResult parse(AsmParser &parser, DynamicTypeDefinition *typeDef,
DynamicType &parsedType);
/// Print the dynamic type with the format
/// 'type' or 'type<>' if there is no parameters, or 'type<attr (,attr)*>'.
void print(AsmPrinter &printer);
};
//===----------------------------------------------------------------------===//
// Dynamic operation
//===----------------------------------------------------------------------===//
/// The definition of a dynamic op. A dynamic op is an op that is defined at
/// runtime, and that can be registered at runtime by an extensible dialect (a
/// dialect inheriting ExtensibleDialect). This class stores the functions that
/// are in the OperationName class, and in addition defines the TypeID of the op
/// that will be defined.
/// Each dynamic operation definition refers to one instance of this class.
class DynamicOpDefinition {
public:
/// Create a new op at runtime. The op is registered only after passing it to
/// the dialect using registerDynamicOp.
static std::unique_ptr<DynamicOpDefinition>
get(StringRef name, ExtensibleDialect *dialect,
OperationName::VerifyInvariantsFn &&verifyFn,
OperationName::VerifyRegionInvariantsFn &&verifyRegionFn);
static std::unique_ptr<DynamicOpDefinition>
get(StringRef name, ExtensibleDialect *dialect,
OperationName::VerifyInvariantsFn &&verifyFn,
OperationName::VerifyRegionInvariantsFn &&verifyRegionFn,
OperationName::ParseAssemblyFn &&parseFn,
OperationName::PrintAssemblyFn &&printFn);
static std::unique_ptr<DynamicOpDefinition>
get(StringRef name, ExtensibleDialect *dialect,
OperationName::VerifyInvariantsFn &&verifyFn,
OperationName::VerifyRegionInvariantsFn &&verifyRegionFn,
OperationName::ParseAssemblyFn &&parseFn,
OperationName::PrintAssemblyFn &&printFn,
OperationName::FoldHookFn &&foldHookFn,
OperationName::GetCanonicalizationPatternsFn
&&getCanonicalizationPatternsFn);
/// Returns the op typeID.
TypeID getTypeID() { return typeID; }
/// Sets the verifier function for this operation. It should emits an error
/// message and returns failure if a problem is detected, or returns success
/// if everything is ok.
void setVerifyFn(OperationName::VerifyInvariantsFn &&verify) {
verifyFn = std::move(verify);
}
/// Sets the region verifier function for this operation. It should emits an
/// error message and returns failure if a problem is detected, or returns
/// success if everything is ok.
void setVerifyRegionFn(OperationName::VerifyRegionInvariantsFn &&verify) {
verifyRegionFn = std::move(verify);
}
/// Sets the static hook for parsing this op assembly.
void setParseFn(OperationName::ParseAssemblyFn &&parse) {
parseFn = std::move(parse);
}
/// Sets the static hook for printing this op assembly.
void setPrintFn(OperationName::PrintAssemblyFn &&print) {
printFn = std::move(print);
}
/// Sets the hook implementing a generalized folder for the op. See
/// `RegisteredOperationName::foldHook` for more details
void setFoldHookFn(OperationName::FoldHookFn &&foldHook) {
foldHookFn = std::move(foldHook);
}
/// Set the hook returning any canonicalization pattern rewrites that the op
/// supports, for use by the canonicalization pass.
void
setGetCanonicalizationPatternsFn(OperationName::GetCanonicalizationPatternsFn
&&getCanonicalizationPatterns) {
getCanonicalizationPatternsFn = std::move(getCanonicalizationPatterns);
}
private:
DynamicOpDefinition(StringRef name, ExtensibleDialect *dialect,
OperationName::VerifyInvariantsFn &&verifyFn,
OperationName::VerifyRegionInvariantsFn &&verifyRegionFn,
OperationName::ParseAssemblyFn &&parseFn,
OperationName::PrintAssemblyFn &&printFn,
OperationName::FoldHookFn &&foldHookFn,
OperationName::GetCanonicalizationPatternsFn
&&getCanonicalizationPatternsFn);
/// Unique identifier for this operation.
TypeID typeID;
/// Name of the operation.
/// The name is prefixed with the dialect name.
std::string name;
/// Dialect defining this operation.
ExtensibleDialect *dialect;
OperationName::VerifyInvariantsFn verifyFn;
OperationName::VerifyRegionInvariantsFn verifyRegionFn;
OperationName::ParseAssemblyFn parseFn;
OperationName::PrintAssemblyFn printFn;
OperationName::FoldHookFn foldHookFn;
OperationName::GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
friend ExtensibleDialect;
};
//===----------------------------------------------------------------------===//
// Extensible dialect
//===----------------------------------------------------------------------===//
/// A dialect that can be extended with new operations/types/attributes at
/// runtime.
class ExtensibleDialect : public mlir::Dialect {
public:
ExtensibleDialect(StringRef name, MLIRContext *ctx, TypeID typeID);
/// Add a new type defined at runtime to the dialect.
void registerDynamicType(std::unique_ptr<DynamicTypeDefinition> &&type);
/// Add a new attribute defined at runtime to the dialect.
void registerDynamicAttr(std::unique_ptr<DynamicAttrDefinition> &&attr);
/// Add a new operation defined at runtime to the dialect.
void registerDynamicOp(std::unique_ptr<DynamicOpDefinition> &&type);
/// Check if the dialect is an extensible dialect.
static bool classof(const Dialect *dialect);
/// Returns nullptr if the definition was not found.
DynamicTypeDefinition *lookupTypeDefinition(StringRef name) const {
auto it = nameToDynTypes.find(name);
if (it == nameToDynTypes.end())
return nullptr;
return it->second;
}
/// Returns nullptr if the definition was not found.
DynamicTypeDefinition *lookupTypeDefinition(TypeID id) const {
auto it = dynTypes.find(id);
if (it == dynTypes.end())
return nullptr;
return it->second.get();
}
/// Returns nullptr if the definition was not found.
DynamicAttrDefinition *lookupAttrDefinition(StringRef name) const {
auto it = nameToDynAttrs.find(name);
if (it == nameToDynAttrs.end())
return nullptr;
return it->second;
}
/// Returns nullptr if the definition was not found.
DynamicAttrDefinition *lookupAttrDefinition(TypeID id) const {
auto it = dynAttrs.find(id);
if (it == dynAttrs.end())
return nullptr;
return it->second.get();
}
protected:
/// Parse the dynamic type 'typeName' in the dialect 'dialect'.
/// typename should not be prefixed with the dialect name.
/// If the dynamic type does not exist, return no value.
/// Otherwise, parse it, and return the parse result.
/// If the parsing succeed, put the resulting type in 'resultType'.
OptionalParseResult parseOptionalDynamicType(StringRef typeName,
AsmParser &parser,
Type &resultType) const;
/// If 'type' is a dynamic type, print it.
/// Returns success if the type was printed, and failure if the type was not a
/// dynamic type.
static LogicalResult printIfDynamicType(Type type, AsmPrinter &printer);
/// Parse the dynamic attribute 'attrName' in the dialect 'dialect'.
/// attrname should not be prefixed with the dialect name.
/// If the dynamic attribute does not exist, return no value.
/// Otherwise, parse it, and return the parse result.
/// If the parsing succeed, put the resulting attribute in 'resultAttr'.
OptionalParseResult parseOptionalDynamicAttr(StringRef attrName,
AsmParser &parser,
Attribute &resultAttr) const;
/// If 'attr' is a dynamic attribute, print it.
/// Returns success if the attribute was printed, and failure if the
/// attribute was not a dynamic attribute.
static LogicalResult printIfDynamicAttr(Attribute attr, AsmPrinter &printer);
private:
/// The set of all dynamic types registered.
DenseMap<TypeID, std::unique_ptr<DynamicTypeDefinition>> dynTypes;
/// This structure allows to get in O(1) a dynamic type given its name.
llvm::StringMap<DynamicTypeDefinition *> nameToDynTypes;
/// The set of all dynamic attributes registered.
DenseMap<TypeID, std::unique_ptr<DynamicAttrDefinition>> dynAttrs;
/// This structure allows to get in O(1) a dynamic attribute given its name.
llvm::StringMap<DynamicAttrDefinition *> nameToDynAttrs;
/// Give DynamicOpDefinition access to allocateTypeID.
friend DynamicOpDefinition;
/// Allocates a type ID to uniquify operations.
TypeID allocateTypeID() { return typeIDAllocator.allocate(); }
/// Owns the TypeID generated at runtime for operations.
TypeIDAllocator typeIDAllocator;
};
} // namespace mlir
#endif // MLIR_IR_EXTENSIBLEDIALECT_H

View File

@ -333,6 +333,9 @@ class Dialect {
// UpperCamel) and prefixed with `get` or `set` depending on if it is a getter
// or setter.
int emitAccessorPrefix = kEmitAccessorPrefix_Raw;
// If this dialect can be extended at runtime with new operations or types.
bit isExtensible = 0;
}
//===----------------------------------------------------------------------===//

View File

@ -163,13 +163,22 @@ namespace detail {
/// A utility class to get, or create, unique instances of types within an
/// MLIRContext. This class manages all creation and uniquing of types.
struct TypeUniquer {
/// Get an uniqued instance of a type T.
template <typename T, typename... Args>
static T get(MLIRContext *ctx, Args &&... args) {
return getWithTypeID<T, Args...>(ctx, T::getTypeID(),
std::forward<Args>(args)...);
}
/// Get an uniqued instance of a parametric type T.
/// The use of this method is in general discouraged in favor of
/// 'get<T, Args>(ctx, args)'.
template <typename T, typename... Args>
static typename std::enable_if_t<
!std::is_same<typename T::ImplType, TypeStorage>::value, T>
get(MLIRContext *ctx, Args &&...args) {
getWithTypeID(MLIRContext *ctx, TypeID typeID, Args &&... args) {
#ifndef NDEBUG
if (!ctx->getTypeUniquer().isParametricStorageInitialized(T::getTypeID()))
if (!ctx->getTypeUniquer().isParametricStorageInitialized(typeID))
llvm::report_fatal_error(
llvm::Twine("can't create type '") + llvm::getTypeName<T>() +
"' because storage uniquer isn't initialized: the dialect was likely "
@ -177,25 +186,27 @@ struct TypeUniquer {
"in the Dialect::initialize() method.");
#endif
return ctx->getTypeUniquer().get<typename T::ImplType>(
[&](TypeStorage *storage) {
storage->initialize(AbstractType::lookup(T::getTypeID(), ctx));
[&, typeID](TypeStorage *storage) {
storage->initialize(AbstractType::lookup(typeID, ctx));
},
T::getTypeID(), std::forward<Args>(args)...);
typeID, std::forward<Args>(args)...);
}
/// Get an uniqued instance of a singleton type T.
/// The use of this method is in general discouraged in favor of
/// 'get<T, Args>(ctx, args)'.
template <typename T>
static typename std::enable_if_t<
std::is_same<typename T::ImplType, TypeStorage>::value, T>
get(MLIRContext *ctx) {
getWithTypeID(MLIRContext *ctx, TypeID typeID) {
#ifndef NDEBUG
if (!ctx->getTypeUniquer().isSingletonStorageInitialized(T::getTypeID()))
if (!ctx->getTypeUniquer().isSingletonStorageInitialized(typeID))
llvm::report_fatal_error(
llvm::Twine("can't create type '") + llvm::getTypeName<T>() +
"' because storage uniquer isn't initialized: the dialect was likely "
"not loaded, or the type wasn't added with addTypes<...>() "
"in the Dialect::initialize() method.");
#endif
return ctx->getTypeUniquer().get<typename T::ImplType>(T::getTypeID());
return ctx->getTypeUniquer().get<typename T::ImplType>(typeID);
}
/// Change the mutable component of the given type instance in the provided
@ -208,22 +219,32 @@ struct TypeUniquer {
std::forward<Args>(args)...);
}
/// Register a type instance T with the uniquer.
template <typename T>
static void registerType(MLIRContext *ctx) {
registerType<T>(ctx, T::getTypeID());
}
/// Register a parametric type instance T with the uniquer.
/// The use of this method is in general discouraged in favor of
/// 'registerType<T>(ctx)'.
template <typename T>
static typename std::enable_if_t<
!std::is_same<typename T::ImplType, TypeStorage>::value>
registerType(MLIRContext *ctx) {
registerType(MLIRContext *ctx, TypeID typeID) {
ctx->getTypeUniquer().registerParametricStorageType<typename T::ImplType>(
T::getTypeID());
typeID);
}
/// Register a singleton type instance T with the uniquer.
/// The use of this method is in general discouraged in favor of
/// 'registerType<T>(ctx)'.
template <typename T>
static typename std::enable_if_t<
std::is_same<typename T::ImplType, TypeStorage>::value>
registerType(MLIRContext *ctx) {
registerType(MLIRContext *ctx, TypeID typeID) {
ctx->getTypeUniquer().registerSingletonStorageType<TypeStorage>(
T::getTypeID(), [&](TypeStorage *storage) {
storage->initialize(AbstractType::lookup(T::getTypeID(), ctx));
typeID, [&ctx, typeID](TypeStorage *storage) {
storage->initialize(AbstractType::lookup(typeID, ctx));
});
}
};

View File

@ -82,6 +82,10 @@ public:
/// type printing/parsing.
bool useDefaultTypePrinterParser() const;
/// Returns true if this dialect can be extended at runtime with new
/// operations or types.
bool isExtensible() const;
// Returns whether two dialects are equal by checking the equality of the
// underlying record.
bool operator==(const Dialect &other) const;

View File

@ -13,6 +13,7 @@ add_mlir_library(MLIRIR
Diagnostics.cpp
Dialect.cpp
Dominance.cpp
ExtensibleDialect.cpp
FunctionImplementation.cpp
FunctionInterfaces.cpp
IntegerSet.cpp

View File

@ -0,0 +1,513 @@
//===- ExtensibleDialect.cpp - Extensible dialect ---------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/IR/AttributeSupport.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/StorageUniquerSupport.h"
#include "mlir/Support/LogicalResult.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Dynamic types and attributes shared functions
//===----------------------------------------------------------------------===//
/// Default parser for dynamic attribute or type parameters.
/// Parse in the format '(<>)?' or '<attr (,attr)*>'.
static LogicalResult
typeOrAttrParser(AsmParser &parser, SmallVectorImpl<Attribute> &parsedParams) {
// No parameters
if (parser.parseOptionalLess() || !parser.parseOptionalGreater())
return success();
Attribute attr;
if (parser.parseAttribute(attr))
return failure();
parsedParams.push_back(attr);
while (parser.parseOptionalGreater()) {
Attribute attr;
if (parser.parseComma() || parser.parseAttribute(attr))
return failure();
parsedParams.push_back(attr);
}
return success();
}
/// Default printer for dynamic attribute or type parameters.
/// Print in the format '(<>)?' or '<attr (,attr)*>'.
static void typeOrAttrPrinter(AsmPrinter &printer, ArrayRef<Attribute> params) {
if (params.empty())
return;
printer << "<";
interleaveComma(params, printer.getStream());
printer << ">";
}
//===----------------------------------------------------------------------===//
// Dynamic type
//===----------------------------------------------------------------------===//
std::unique_ptr<DynamicTypeDefinition>
DynamicTypeDefinition::get(StringRef name, ExtensibleDialect *dialect,
VerifierFn &&verifier) {
return DynamicTypeDefinition::get(name, dialect, std::move(verifier),
typeOrAttrParser, typeOrAttrPrinter);
}
std::unique_ptr<DynamicTypeDefinition>
DynamicTypeDefinition::get(StringRef name, ExtensibleDialect *dialect,
VerifierFn &&verifier, ParserFn &&parser,
PrinterFn &&printer) {
return std::unique_ptr<DynamicTypeDefinition>(
new DynamicTypeDefinition(name, dialect, std::move(verifier),
std::move(parser), std::move(printer)));
}
DynamicTypeDefinition::DynamicTypeDefinition(StringRef nameRef,
ExtensibleDialect *dialect,
VerifierFn &&verifier,
ParserFn &&parser,
PrinterFn &&printer)
: name(nameRef), dialect(dialect), verifier(std::move(verifier)),
parser(std::move(parser)), printer(std::move(printer)),
ctx(dialect->getContext()) {
assert(!nameRef.contains('.') &&
"name should not be prefixed by the dialect name");
}
DynamicTypeDefinition::DynamicTypeDefinition(ExtensibleDialect *dialect,
StringRef nameRef)
: name(nameRef), dialect(dialect), ctx(dialect->getContext()) {
assert(!nameRef.contains('.') &&
"name should not be prefixed by the dialect name");
}
void DynamicTypeDefinition::registerInTypeUniquer() {
detail::TypeUniquer::registerType<DynamicType>(&getContext(), getTypeID());
}
namespace mlir {
namespace detail {
/// Storage of DynamicType.
/// Contains a pointer to the type definition and type parameters.
struct DynamicTypeStorage : public TypeStorage {
using KeyTy = std::pair<DynamicTypeDefinition *, ArrayRef<Attribute>>;
explicit DynamicTypeStorage(DynamicTypeDefinition *typeDef,
ArrayRef<Attribute> params)
: typeDef(typeDef), params(params) {}
bool operator==(const KeyTy &key) const {
return typeDef == key.first && params == key.second;
}
static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_value(key);
}
static DynamicTypeStorage *construct(TypeStorageAllocator &alloc,
const KeyTy &key) {
return new (alloc.allocate<DynamicTypeStorage>())
DynamicTypeStorage(key.first, alloc.copyInto(key.second));
}
/// Definition of the type.
DynamicTypeDefinition *typeDef;
/// The type parameters.
ArrayRef<Attribute> params;
};
} // namespace detail
} // namespace mlir
DynamicType DynamicType::get(DynamicTypeDefinition *typeDef,
ArrayRef<Attribute> params) {
auto &ctx = typeDef->getContext();
auto emitError = detail::getDefaultDiagnosticEmitFn(&ctx);
assert(succeeded(typeDef->verify(emitError, params)));
return detail::TypeUniquer::getWithTypeID<DynamicType>(
&ctx, typeDef->getTypeID(), typeDef, params);
}
DynamicType
DynamicType::getChecked(function_ref<InFlightDiagnostic()> emitError,
DynamicTypeDefinition *typeDef,
ArrayRef<Attribute> params) {
if (failed(typeDef->verify(emitError, params)))
return {};
auto &ctx = typeDef->getContext();
return detail::TypeUniquer::getWithTypeID<DynamicType>(
&ctx, typeDef->getTypeID(), typeDef, params);
}
DynamicTypeDefinition *DynamicType::getTypeDef() { return getImpl()->typeDef; }
ArrayRef<Attribute> DynamicType::getParams() { return getImpl()->params; }
bool DynamicType::classof(Type type) {
return type.hasTrait<IsDynamicTypeTrait>();
}
ParseResult DynamicType::parse(AsmParser &parser,
DynamicTypeDefinition *typeDef,
DynamicType &parsedType) {
SmallVector<Attribute> params;
if (failed(typeDef->parser(parser, params)))
return failure();
parsedType = parser.getChecked<DynamicType>(typeDef, params);
if (!parsedType)
return failure();
return success();
}
void DynamicType::print(AsmPrinter &printer) {
printer << getTypeDef()->getName();
getTypeDef()->printer(printer, getParams());
}
//===----------------------------------------------------------------------===//
// Dynamic attribute
//===----------------------------------------------------------------------===//
std::unique_ptr<DynamicAttrDefinition>
DynamicAttrDefinition::get(StringRef name, ExtensibleDialect *dialect,
VerifierFn &&verifier) {
return DynamicAttrDefinition::get(name, dialect, std::move(verifier),
typeOrAttrParser, typeOrAttrPrinter);
}
std::unique_ptr<DynamicAttrDefinition>
DynamicAttrDefinition::get(StringRef name, ExtensibleDialect *dialect,
VerifierFn &&verifier, ParserFn &&parser,
PrinterFn &&printer) {
return std::unique_ptr<DynamicAttrDefinition>(
new DynamicAttrDefinition(name, dialect, std::move(verifier),
std::move(parser), std::move(printer)));
}
DynamicAttrDefinition::DynamicAttrDefinition(StringRef nameRef,
ExtensibleDialect *dialect,
VerifierFn &&verifier,
ParserFn &&parser,
PrinterFn &&printer)
: name(nameRef), dialect(dialect), verifier(std::move(verifier)),
parser(std::move(parser)), printer(std::move(printer)),
ctx(dialect->getContext()) {
assert(!nameRef.contains('.') &&
"name should not be prefixed by the dialect name");
}
DynamicAttrDefinition::DynamicAttrDefinition(ExtensibleDialect *dialect,
StringRef nameRef)
: name(nameRef), dialect(dialect), ctx(dialect->getContext()) {
assert(!nameRef.contains('.') &&
"name should not be prefixed by the dialect name");
}
void DynamicAttrDefinition::registerInAttrUniquer() {
detail::AttributeUniquer::registerAttribute<DynamicAttr>(&getContext(),
getTypeID());
}
namespace mlir {
namespace detail {
/// Storage of DynamicAttr.
/// Contains a pointer to the attribute definition and attribute parameters.
struct DynamicAttrStorage : public AttributeStorage {
using KeyTy = std::pair<DynamicAttrDefinition *, ArrayRef<Attribute>>;
explicit DynamicAttrStorage(DynamicAttrDefinition *attrDef,
ArrayRef<Attribute> params)
: attrDef(attrDef), params(params) {}
bool operator==(const KeyTy &key) const {
return attrDef == key.first && params == key.second;
}
static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_value(key);
}
static DynamicAttrStorage *construct(AttributeStorageAllocator &alloc,
const KeyTy &key) {
return new (alloc.allocate<DynamicAttrStorage>())
DynamicAttrStorage(key.first, alloc.copyInto(key.second));
}
/// Definition of the type.
DynamicAttrDefinition *attrDef;
/// The type parameters.
ArrayRef<Attribute> params;
};
} // namespace detail
} // namespace mlir
DynamicAttr DynamicAttr::get(DynamicAttrDefinition *attrDef,
ArrayRef<Attribute> params) {
auto &ctx = attrDef->getContext();
return detail::AttributeUniquer::getWithTypeID<DynamicAttr>(
&ctx, attrDef->getTypeID(), attrDef, params);
}
DynamicAttr
DynamicAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
DynamicAttrDefinition *attrDef,
ArrayRef<Attribute> params) {
if (failed(attrDef->verify(emitError, params)))
return {};
return get(attrDef, params);
}
DynamicAttrDefinition *DynamicAttr::getAttrDef() { return getImpl()->attrDef; }
ArrayRef<Attribute> DynamicAttr::getParams() { return getImpl()->params; }
bool DynamicAttr::classof(Attribute attr) {
return attr.hasTrait<IsDynamicAttrTrait>();
}
ParseResult DynamicAttr::parse(AsmParser &parser,
DynamicAttrDefinition *attrDef,
DynamicAttr &parsedAttr) {
SmallVector<Attribute> params;
if (failed(attrDef->parser(parser, params)))
return failure();
parsedAttr = parser.getChecked<DynamicAttr>(attrDef, params);
if (!parsedAttr)
return failure();
return success();
}
void DynamicAttr::print(AsmPrinter &printer) {
printer << getAttrDef()->getName();
getAttrDef()->printer(printer, getParams());
}
//===----------------------------------------------------------------------===//
// Dynamic operation
//===----------------------------------------------------------------------===//
DynamicOpDefinition::DynamicOpDefinition(
StringRef name, ExtensibleDialect *dialect,
OperationName::VerifyInvariantsFn &&verifyFn,
OperationName::VerifyRegionInvariantsFn &&verifyRegionFn,
OperationName::ParseAssemblyFn &&parseFn,
OperationName::PrintAssemblyFn &&printFn,
OperationName::FoldHookFn &&foldHookFn,
OperationName::GetCanonicalizationPatternsFn
&&getCanonicalizationPatternsFn)
: typeID(dialect->allocateTypeID()),
name((dialect->getNamespace() + "." + name).str()), dialect(dialect),
verifyFn(std::move(verifyFn)), verifyRegionFn(std::move(verifyRegionFn)),
parseFn(std::move(parseFn)), printFn(std::move(printFn)),
foldHookFn(std::move(foldHookFn)),
getCanonicalizationPatternsFn(std::move(getCanonicalizationPatternsFn)) {
assert(!name.contains('.') &&
"name should not be prefixed by the dialect name");
}
std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
StringRef name, ExtensibleDialect *dialect,
OperationName::VerifyInvariantsFn &&verifyFn,
OperationName::VerifyRegionInvariantsFn &&verifyRegionFn) {
auto parseFn = [](OpAsmParser &parser, OperationState &result) {
return parser.emitError(
parser.getCurrentLocation(),
"dynamic operation do not define any parser function");
};
auto printFn = [](Operation *op, OpAsmPrinter &printer, StringRef) {
printer.printGenericOp(op);
};
return DynamicOpDefinition::get(name, dialect, std::move(verifyFn),
std::move(verifyRegionFn), std::move(parseFn),
std::move(printFn));
}
std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
StringRef name, ExtensibleDialect *dialect,
OperationName::VerifyInvariantsFn &&verifyFn,
OperationName::VerifyRegionInvariantsFn &&verifyRegionFn,
OperationName::ParseAssemblyFn &&parseFn,
OperationName::PrintAssemblyFn &&printFn) {
auto foldHookFn = [](Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
return failure();
};
auto getCanonicalizationPatternsFn = [](RewritePatternSet &, MLIRContext *) {
};
return DynamicOpDefinition::get(name, dialect, std::move(verifyFn),
std::move(verifyRegionFn), std::move(parseFn),
std::move(printFn), std::move(foldHookFn),
std::move(getCanonicalizationPatternsFn));
}
std::unique_ptr<DynamicOpDefinition>
DynamicOpDefinition::get(StringRef name, ExtensibleDialect *dialect,
OperationName::VerifyInvariantsFn &&verifyFn,
OperationName::VerifyInvariantsFn &&verifyRegionFn,
OperationName::ParseAssemblyFn &&parseFn,
OperationName::PrintAssemblyFn &&printFn,
OperationName::FoldHookFn &&foldHookFn,
OperationName::GetCanonicalizationPatternsFn
&&getCanonicalizationPatternsFn) {
return std::unique_ptr<DynamicOpDefinition>(new DynamicOpDefinition(
name, dialect, std::move(verifyFn), std::move(verifyRegionFn),
std::move(parseFn), std::move(printFn), std::move(foldHookFn),
std::move(getCanonicalizationPatternsFn)));
}
//===----------------------------------------------------------------------===//
// Extensible dialect
//===----------------------------------------------------------------------===//
namespace {
/// Interface that can only be implemented by extensible dialects.
/// The interface is used to check if a dialect is extensible or not.
class IsExtensibleDialect : public DialectInterface::Base<IsExtensibleDialect> {
public:
IsExtensibleDialect(Dialect *dialect) : Base(dialect) {}
};
} // namespace
ExtensibleDialect::ExtensibleDialect(StringRef name, MLIRContext *ctx,
TypeID typeID)
: Dialect(name, ctx, typeID) {
addInterfaces<IsExtensibleDialect>();
}
void ExtensibleDialect::registerDynamicType(
std::unique_ptr<DynamicTypeDefinition> &&type) {
DynamicTypeDefinition *typePtr = type.get();
TypeID typeID = type->getTypeID();
StringRef name = type->getName();
ExtensibleDialect *dialect = type->getDialect();
assert(dialect == this &&
"trying to register a dynamic type in the wrong dialect");
// If a type with the same name is already defined, fail.
auto registered = dynTypes.try_emplace(typeID, std::move(type)).second;
(void)registered;
assert(registered && "type TypeID was not unique");
registered = nameToDynTypes.insert({name, typePtr}).second;
(void)registered;
assert(registered &&
"Trying to create a new dynamic type with an existing name");
auto abstractType =
AbstractType::get(*dialect, DynamicAttr::getInterfaceMap(),
DynamicType::getHasTraitFn(), typeID);
/// Add the type to the dialect and the type uniquer.
addType(typeID, std::move(abstractType));
typePtr->registerInTypeUniquer();
}
void ExtensibleDialect::registerDynamicAttr(
std::unique_ptr<DynamicAttrDefinition> &&attr) {
auto *attrPtr = attr.get();
auto typeID = attr->getTypeID();
auto name = attr->getName();
auto *dialect = attr->getDialect();
assert(dialect == this &&
"trying to register a dynamic attribute in the wrong dialect");
// If an attribute with the same name is already defined, fail.
auto registered = dynAttrs.try_emplace(typeID, std::move(attr)).second;
(void)registered;
assert(registered && "attribute TypeID was not unique");
registered = nameToDynAttrs.insert({name, attrPtr}).second;
(void)registered;
assert(registered &&
"Trying to create a new dynamic attribute with an existing name");
auto abstractAttr =
AbstractAttribute::get(*dialect, DynamicAttr::getInterfaceMap(),
DynamicAttr::getHasTraitFn(), typeID);
/// Add the type to the dialect and the type uniquer.
addAttribute(typeID, std::move(abstractAttr));
attrPtr->registerInAttrUniquer();
}
void ExtensibleDialect::registerDynamicOp(
std::unique_ptr<DynamicOpDefinition> &&op) {
assert(op->dialect == this &&
"trying to register a dynamic op in the wrong dialect");
auto hasTraitFn = [](TypeID traitId) { return false; };
RegisteredOperationName::insert(
op->name, *op->dialect, op->typeID, std::move(op->parseFn),
std::move(op->printFn), std::move(op->verifyFn),
std::move(op->verifyRegionFn), std::move(op->foldHookFn),
std::move(op->getCanonicalizationPatternsFn),
detail::InterfaceMap::get<>(), std::move(hasTraitFn), {});
}
bool ExtensibleDialect::classof(const Dialect *dialect) {
return const_cast<Dialect *>(dialect)
->getRegisteredInterface<IsExtensibleDialect>();
}
OptionalParseResult ExtensibleDialect::parseOptionalDynamicType(
StringRef typeName, AsmParser &parser, Type &resultType) const {
DynamicTypeDefinition *typeDef = lookupTypeDefinition(typeName);
if (!typeDef)
return llvm::None;
DynamicType dynType;
if (DynamicType::parse(parser, typeDef, dynType))
return failure();
resultType = dynType;
return success();
}
LogicalResult ExtensibleDialect::printIfDynamicType(Type type,
AsmPrinter &printer) {
if (auto dynType = type.dyn_cast<DynamicType>()) {
dynType.print(printer);
return success();
}
return failure();
}
OptionalParseResult ExtensibleDialect::parseOptionalDynamicAttr(
StringRef attrName, AsmParser &parser, Attribute &resultAttr) const {
DynamicAttrDefinition *attrDef = lookupAttrDefinition(attrName);
if (!attrDef)
return llvm::None;
DynamicAttr dynAttr;
if (DynamicAttr::parse(parser, attrDef, dynAttr))
return failure();
resultAttr = dynAttr;
return success();
}
LogicalResult ExtensibleDialect::printIfDynamicAttr(Attribute attribute,
AsmPrinter &printer) {
if (auto dynAttr = attribute.dyn_cast<DynamicAttr>()) {
dynAttr.print(printer);
return success();
}
return failure();
}

View File

@ -102,9 +102,14 @@ Dialect::EmitPrefix Dialect::getEmitAccessorPrefix() const {
int prefix = def->getValueAsInt("emitAccessorPrefix");
if (prefix < 0 || prefix > static_cast<int>(EmitPrefix::Both))
PrintFatalError(def->getLoc(), "Invalid accessor prefix value");
return static_cast<EmitPrefix>(prefix);
}
bool Dialect::isExtensible() const {
return def->getValueAsBit("isExtensible");
}
bool Dialect::operator==(const Dialect &other) const {
return def == other.def;
}

126
mlir/test/IR/dynamic.mlir Normal file
View File

@ -0,0 +1,126 @@
// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics | FileCheck %s
// Verify that extensible dialects can register dynamic operations and types.
//===----------------------------------------------------------------------===//
// Dynamic type
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @succeededDynamicTypeVerifier
func @succeededDynamicTypeVerifier() {
// CHECK: %{{.*}} = "unregistered_op"() : () -> !test.singleton_dyntype
"unregistered_op"() : () -> !test.singleton_dyntype
// CHECK-NEXT: "unregistered_op"() : () -> !test.pair_dyntype<i32, f64>
"unregistered_op"() : () -> !test.pair_dyntype<i32, f64>
// CHECK_NEXT: %{{.*}} = "unregistered_op"() : () -> !test.pair_dyntype<!test.pair_dyntype<i32, f64>, !test.singleton_dyntype>
"unregistered_op"() : () -> !test.pair_dyntype<!test.pair_dyntype<i32, f64>, !test.singleton_dyntype>
return
}
// -----
func @failedDynamicTypeVerifier() {
// expected-error@+1 {{expected 0 type arguments, but had 1}}
"unregistered_op"() : () -> !test.singleton_dyntype<f64>
return
}
// -----
func @failedDynamicTypeVerifier2() {
// expected-error@+1 {{expected 2 type arguments, but had 1}}
"unregistered_op"() : () -> !test.pair_dyntype<f64>
return
}
// -----
// CHECK-LABEL: func @customTypeParserPrinter
func @customTypeParserPrinter() {
// CHECK: "unregistered_op"() : () -> !test.custom_assembly_format_dyntype<f32:f64>
"unregistered_op"() : () -> !test.custom_assembly_format_dyntype<f32 : f64>
return
}
// -----
//===----------------------------------------------------------------------===//
// Dynamic attribute
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @succeededDynamicAttributeVerifier
func @succeededDynamicAttributeVerifier() {
// CHECK: "unregistered_op"() {test_attr = #test.singleton_dynattr} : () -> ()
"unregistered_op"() {test_attr = #test.singleton_dynattr} : () -> ()
// CHECK-NEXT: "unregistered_op"() {test_attr = #test.pair_dynattr<3 : i32, 5 : i32>} : () -> ()
"unregistered_op"() {test_attr = #test.pair_dynattr<3 : i32, 5 : i32>} : () -> ()
// CHECK_NEXT: "unregistered_op"() {test_attr = #test.pair_dynattr<3 : i32, 5 : i32>} : () -> ()
"unregistered_op"() {test_attr = #test.pair_dynattr<#test.pair_dynattr<3 : i32, 5 : i32>, f64>} : () -> ()
return
}
// -----
func @failedDynamicAttributeVerifier() {
// expected-error@+1 {{expected 0 attribute arguments, but had 1}}
"unregistered_op"() {test_attr = #test.singleton_dynattr<f64>} : () -> ()
return
}
// -----
func @failedDynamicAttributeVerifier2() {
// expected-error@+1 {{expected 2 attribute arguments, but had 1}}
"unregistered_op"() {test_attr = #test.pair_dynattr<f64> : () -> ()
return
}
// -----
// CHECK-LABEL: func @customAttributeParserPrinter
func @customAttributeParserPrinter() {
// CHECK: "unregistered_op"() {test_attr = #test.custom_assembly_format_dynattr<f32:f64>} : () -> ()
"unregistered_op"() {test_attr = #test.custom_assembly_format_dynattr<f32:f64>} : () -> ()
return
}
//===----------------------------------------------------------------------===//
// Dynamic op
//===----------------------------------------------------------------------===//
// -----
// CHECK-LABEL: func @succeededDynamicOpVerifier
func @succeededDynamicOpVerifier(%a: f32) {
// CHECK: "test.generic_dynamic_op"() : () -> ()
// CHECK-NEXT: %{{.*}} = "test.generic_dynamic_op"(%{{.*}}) : (f32) -> f64
// CHECK-NEXT: %{{.*}}:2 = "test.one_operand_two_results"(%{{.*}}) : (f32) -> (f64, f64)
"test.generic_dynamic_op"() : () -> ()
"test.generic_dynamic_op"(%a) : (f32) -> f64
"test.one_operand_two_results"(%a) : (f32) -> (f64, f64)
return
}
// -----
func @failedDynamicOpVerifier() {
// expected-error@+1 {{expected 1 operand, but had 0}}
"test.one_operand_two_results"() : () -> (f64, f64)
return
}
// -----
func @failedDynamicOpVerifier2(%a: f32) {
// expected-error@+1 {{expected 2 results, but had 0}}
"test.one_operand_two_results"(%a) : (f32) -> ()
return
}
// -----
// CHECK-LABEL: func @customOpParserPrinter
func @customOpParserPrinter() {
// CHECK: test.custom_parser_printer_dynamic_op custom_keyword
test.custom_parser_printer_dynamic_op custom_keyword
return
}

View File

@ -15,12 +15,14 @@
#include "TestDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"
#include "llvm/Support/ErrorHandling.h"
using namespace mlir;
using namespace test;
@ -216,6 +218,74 @@ SubElementAttrInterface TestSubElementsAccessAttr::replaceImmediateSubAttribute(
#define GET_ATTRDEF_CLASSES
#include "TestAttrDefs.cpp.inc"
//===----------------------------------------------------------------------===//
// Dynamic Attributes
//===----------------------------------------------------------------------===//
/// Define a singleton dynamic attribute.
static std::unique_ptr<DynamicAttrDefinition>
getSingletonDynamicAttr(TestDialect *testDialect) {
return DynamicAttrDefinition::get(
"singleton_dynattr", testDialect,
[](function_ref<InFlightDiagnostic()> emitError,
ArrayRef<Attribute> args) {
if (!args.empty()) {
emitError() << "expected 0 attribute arguments, but had "
<< args.size();
return failure();
}
return success();
});
}
/// Define a dynamic attribute representing a pair or attributes.
static std::unique_ptr<DynamicAttrDefinition>
getPairDynamicAttr(TestDialect *testDialect) {
return DynamicAttrDefinition::get(
"pair_dynattr", testDialect,
[](function_ref<InFlightDiagnostic()> emitError,
ArrayRef<Attribute> args) {
if (args.size() != 2) {
emitError() << "expected 2 attribute arguments, but had "
<< args.size();
return failure();
}
return success();
});
}
static std::unique_ptr<DynamicAttrDefinition>
getCustomAssemblyFormatDynamicAttr(TestDialect *testDialect) {
auto verifier = [](function_ref<InFlightDiagnostic()> emitError,
ArrayRef<Attribute> args) {
if (args.size() != 2) {
emitError() << "expected 2 attribute arguments, but had " << args.size();
return failure();
}
return success();
};
auto parser = [](AsmParser &parser,
llvm::SmallVectorImpl<Attribute> &parsedParams) {
Attribute leftAttr, rightAttr;
if (parser.parseLess() || parser.parseAttribute(leftAttr) ||
parser.parseColon() || parser.parseAttribute(rightAttr) ||
parser.parseGreater())
return failure();
parsedParams.push_back(leftAttr);
parsedParams.push_back(rightAttr);
return success();
};
auto printer = [](AsmPrinter &printer, ArrayRef<Attribute> params) {
printer << "<" << params[0] << ":" << params[1] << ">";
};
return DynamicAttrDefinition::get("custom_assembly_format_dynattr",
testDialect, std::move(verifier),
std::move(parser), std::move(printer));
}
//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//
@ -225,4 +295,7 @@ void TestDialect::registerAttributes() {
#define GET_ATTRDEF_LIST
#include "TestAttrDefs.cpp.inc"
>();
registerDynamicAttr(getSingletonDynamicAttr(this));
registerDynamicAttr(getPairDynamicAttr(this));
registerDynamicAttr(getCustomAssemblyFormatDynamicAttr(this));
}

View File

@ -16,6 +16,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
@ -206,6 +207,58 @@ public:
} // namespace
//===----------------------------------------------------------------------===//
// Dynamic operations
//===----------------------------------------------------------------------===//
std::unique_ptr<DynamicOpDefinition> getGenericDynamicOp(TestDialect *dialect) {
return DynamicOpDefinition::get(
"generic_dynamic_op", dialect, [](Operation *op) { return success(); },
[](Operation *op) { return success(); });
}
std::unique_ptr<DynamicOpDefinition>
getOneOperandTwoResultsDynamicOp(TestDialect *dialect) {
return DynamicOpDefinition::get(
"one_operand_two_results", dialect,
[](Operation *op) {
if (op->getNumOperands() != 1) {
op->emitOpError()
<< "expected 1 operand, but had " << op->getNumOperands();
return failure();
}
if (op->getNumResults() != 2) {
op->emitOpError()
<< "expected 2 results, but had " << op->getNumResults();
return failure();
}
return success();
},
[](Operation *op) { return success(); });
}
std::unique_ptr<DynamicOpDefinition>
getCustomParserPrinterDynamicOp(TestDialect *dialect) {
auto verifier = [](Operation *op) {
if (op->getNumOperands() == 0 && op->getNumResults() == 0)
return success();
op->emitError() << "operation should have no operands and no results";
return failure();
};
auto regionVerifier = [](Operation *op) { return success(); };
auto parser = [](OpAsmParser &parser, OperationState &state) {
return parser.parseKeyword("custom_keyword");
};
auto printer = [](Operation *op, OpAsmPrinter &printer, llvm::StringRef) {
printer << op->getName() << " custom_keyword";
};
return DynamicOpDefinition::get("custom_parser_printer_dynamic_op", dialect,
verifier, regionVerifier, parser, printer);
}
//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//
@ -240,6 +293,10 @@ void TestDialect::initialize() {
#define GET_OP_LIST
#include "TestOps.cpp.inc"
>();
registerDynamicOp(getGenericDynamicOp(this));
registerDynamicOp(getOneOperandTwoResultsDynamicOp(this));
registerDynamicOp(getCustomParserPrinterDynamicOp(this));
addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
TestInlinerInterface, TestReductionPatternInterface>();
allowUnknownOperations();

View File

@ -23,6 +23,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/RegionKindInterface.h"

View File

@ -23,6 +23,8 @@ def Test_Dialect : Dialect {
let hasOperationInterfaceFallback = 1;
let hasNonDefaultDestructor = 1;
let useDefaultTypePrinterParser = 0;
let useDefaultAttributePrinterParser = 1;
let isExtensible = 1;
let dependentDialects = ["::mlir::DLTIDialect"];
let extraClassDeclaration = [{
@ -43,6 +45,10 @@ def Test_Dialect : Dialect {
// Storage for a custom fallback interface.
void *fallbackEffectOpInterfaces;
::mlir::Type parseTestType(::mlir::AsmParser &parser,
::llvm::SetVector<::mlir::Type> &stack) const;
void printTestType(::mlir::Type type, ::mlir::AsmPrinter &printer,
::llvm::SetVector<::mlir::Type> &stack) const;
}];
}

View File

@ -16,6 +16,7 @@
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/SetVector.h"
@ -215,6 +216,72 @@ unsigned TestTypeWithLayoutType::extractKind(DataLayoutEntryListRef params,
#define GET_TYPEDEF_CLASSES
#include "TestTypeDefs.cpp.inc"
//===----------------------------------------------------------------------===//
// Dynamic Types
//===----------------------------------------------------------------------===//
/// Define a singleton dynamic type.
static std::unique_ptr<DynamicTypeDefinition>
getSingletonDynamicType(TestDialect *testDialect) {
return DynamicTypeDefinition::get(
"singleton_dyntype", testDialect,
[](function_ref<InFlightDiagnostic()> emitError,
ArrayRef<Attribute> args) {
if (!args.empty()) {
emitError() << "expected 0 type arguments, but had " << args.size();
return failure();
}
return success();
});
}
/// Define a dynamic type representing a pair.
static std::unique_ptr<DynamicTypeDefinition>
getPairDynamicType(TestDialect *testDialect) {
return DynamicTypeDefinition::get(
"pair_dyntype", testDialect,
[](function_ref<InFlightDiagnostic()> emitError,
ArrayRef<Attribute> args) {
if (args.size() != 2) {
emitError() << "expected 2 type arguments, but had " << args.size();
return failure();
}
return success();
});
}
static std::unique_ptr<DynamicTypeDefinition>
getCustomAssemblyFormatDynamicType(TestDialect *testDialect) {
auto verifier = [](function_ref<InFlightDiagnostic()> emitError,
ArrayRef<Attribute> args) {
if (args.size() != 2) {
emitError() << "expected 2 type arguments, but had " << args.size();
return failure();
}
return success();
};
auto parser = [](AsmParser &parser,
llvm::SmallVectorImpl<Attribute> &parsedParams) {
Attribute leftAttr, rightAttr;
if (parser.parseLess() || parser.parseAttribute(leftAttr) ||
parser.parseColon() || parser.parseAttribute(rightAttr) ||
parser.parseGreater())
return failure();
parsedParams.push_back(leftAttr);
parsedParams.push_back(rightAttr);
return success();
};
auto printer = [](AsmPrinter &printer, ArrayRef<Attribute> params) {
printer << "<" << params[0] << ":" << params[1] << ">";
};
return DynamicTypeDefinition::get("custom_assembly_format_dyntype",
testDialect, std::move(verifier),
std::move(parser), std::move(printer));
}
//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//
@ -232,9 +299,14 @@ void TestDialect::registerTypes() {
#include "TestTypeDefs.cpp.inc"
>();
SimpleAType::attachInterface<PtrElementModel>(*getContext());
registerDynamicType(getSingletonDynamicType(this));
registerDynamicType(getPairDynamicType(this));
registerDynamicType(getCustomAssemblyFormatDynamicType(this));
}
static Type parseTestType(AsmParser &parser, SetVector<Type> &stack) {
Type TestDialect::parseTestType(AsmParser &parser,
SetVector<Type> &stack) const {
StringRef typeTag;
if (failed(parser.parseKeyword(&typeTag)))
return Type();
@ -246,6 +318,16 @@ static Type parseTestType(AsmParser &parser, SetVector<Type> &stack) {
return genType;
}
{
Type dynType;
auto parseResult = parseOptionalDynamicType(typeTag, parser, dynType);
if (parseResult.hasValue()) {
if (succeeded(parseResult.getValue()))
return dynType;
return Type();
}
}
if (typeTag != "test_rec") {
parser.emitError(parser.getNameLoc()) << "unknown type!";
return Type();
@ -281,11 +363,14 @@ Type TestDialect::parseType(DialectAsmParser &parser) const {
return parseTestType(parser, stack);
}
static void printTestType(Type type, AsmPrinter &printer,
SetVector<Type> &stack) {
void TestDialect::printTestType(Type type, AsmPrinter &printer,
SetVector<Type> &stack) const {
if (succeeded(generatedTypePrinter(type, printer)))
return;
if (succeeded(printIfDynamicType(type, printer)))
return;
auto rec = type.cast<TestRecursiveType>();
printer << "test_rec<" << rec.getName();
if (!stack.contains(rec)) {

View File

@ -734,6 +734,8 @@ void DefGenerator::emitTypeDefList(ArrayRef<AttrOrTypeDef> defs) {
/// The code block for default attribute parser/printer dispatch boilerplate.
/// {0}: the dialect fully qualified class name.
/// {1}: the optional code for the dynamic attribute parser dispatch.
/// {2}: the optional code for the dynamic attribute printer dispatch.
static const char *const dialectDefaultAttrPrinterParserDispatch = R"(
/// Parse an attribute registered to this dialect.
::mlir::Attribute {0}::parseAttribute(::mlir::DialectAsmParser &parser,
@ -748,6 +750,7 @@ static const char *const dialectDefaultAttrPrinterParserDispatch = R"(
if (parseResult.hasValue())
return attr;
}
{1}
parser.emitError(typeLoc) << "unknown attribute `"
<< attrTag << "` in dialect `" << getNamespace() << "`";
return {{};
@ -757,11 +760,33 @@ void {0}::printAttribute(::mlir::Attribute attr,
::mlir::DialectAsmPrinter &printer) const {{
if (::mlir::succeeded(generatedAttributePrinter(attr, printer)))
return;
{2}
}
)";
/// The code block for dynamic attribute parser dispatch boilerplate.
static const char *const dialectDynamicAttrParserDispatch = R"(
{
::mlir::Attribute genAttr;
auto parseResult = parseOptionalDynamicAttr(attrTag, parser, genAttr);
if (parseResult.hasValue()) {
if (::mlir::succeeded(parseResult.getValue()))
return genAttr;
return Attribute();
}
}
)";
/// The code block for dynamic type printer dispatch boilerplate.
static const char *const dialectDynamicAttrPrinterDispatch = R"(
if (::mlir::succeeded(printIfDynamicAttr(attr, printer)))
return;
)";
/// The code block for default type parser/printer dispatch boilerplate.
/// {0}: the dialect fully qualified class name.
/// {1}: the optional code for the dynamic type parser dispatch.
/// {2}: the optional code for the dynamic type printer dispatch.
static const char *const dialectDefaultTypePrinterParserDispatch = R"(
/// Parse a type registered to this dialect.
::mlir::Type {0}::parseType(::mlir::DialectAsmParser &parser) const {{
@ -773,6 +798,7 @@ static const char *const dialectDefaultTypePrinterParserDispatch = R"(
auto parseResult = generatedTypeParser(parser, mnemonic, genType);
if (parseResult.hasValue())
return genType;
{1}
parser.emitError(typeLoc) << "unknown type `"
<< mnemonic << "` in dialect `" << getNamespace() << "`";
return {{};
@ -782,9 +808,28 @@ void {0}::printType(::mlir::Type type,
::mlir::DialectAsmPrinter &printer) const {{
if (::mlir::succeeded(generatedTypePrinter(type, printer)))
return;
{2}
}
)";
/// The code block for dynamic type parser dispatch boilerplate.
static const char *const dialectDynamicTypeParserDispatch = R"(
{
auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType);
if (parseResult.hasValue()) {
if (::mlir::succeeded(parseResult.getValue()))
return genType;
return Type();
}
}
)";
/// The code block for dynamic type printer dispatch boilerplate.
static const char *const dialectDynamicTypePrinterDispatch = R"(
if (::mlir::succeeded(printIfDynamicType(type, printer)))
return;
)";
/// Emit the dialect printer/parser dispatcher. User's code should call these
/// functions from their dialect's print/parse methods.
void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) {
@ -880,16 +925,30 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {
if (valueType == "Attribute" && needsDialectParserPrinter &&
firstDialect.useDefaultAttributePrinterParser()) {
NamespaceEmitter nsEmitter(os, firstDialect);
os << llvm::formatv(dialectDefaultAttrPrinterParserDispatch,
firstDialect.getCppClassName());
if (firstDialect.isExtensible()) {
os << llvm::formatv(dialectDefaultAttrPrinterParserDispatch,
firstDialect.getCppClassName(),
dialectDynamicAttrParserDispatch,
dialectDynamicAttrPrinterDispatch);
} else {
os << llvm::formatv(dialectDefaultAttrPrinterParserDispatch,
firstDialect.getCppClassName(), "", "");
}
}
// Emit the default parser/printer for Types if the dialect asked for it.
if (valueType == "Type" && needsDialectParserPrinter &&
firstDialect.useDefaultTypePrinterParser()) {
NamespaceEmitter nsEmitter(os, firstDialect);
os << llvm::formatv(dialectDefaultTypePrinterParserDispatch,
firstDialect.getCppClassName());
if (firstDialect.isExtensible()) {
os << llvm::formatv(dialectDefaultTypePrinterParserDispatch,
firstDialect.getCppClassName(),
dialectDynamicTypeParserDispatch,
dialectDynamicTypePrinterDispatch);
} else {
os << llvm::formatv(dialectDefaultTypePrinterParserDispatch,
firstDialect.getCppClassName(), "", "");
}
}
return false;

View File

@ -90,9 +90,9 @@ findSelectedDialect(ArrayRef<const llvm::Record *> dialectDefs) {
/// {2}: initialization code that is emitted in the ctor body before calling
/// initialize()
static const char *const dialectDeclBeginStr = R"(
class {0} : public ::mlir::Dialect {
class {0} : public ::mlir::{3} {
explicit {0}(::mlir::MLIRContext *context)
: ::mlir::Dialect(getDialectNamespace(), context,
: ::mlir::{3}(getDialectNamespace(), context,
::mlir::TypeID::get<{0}>()) {{
{2}
initialize();
@ -205,8 +205,10 @@ emitDialectDecl(Dialect &dialect,
// Emit the start of the decl.
std::string cppName = dialect.getCppClassName();
StringRef superClassName =
dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(),
dependentDialectRegistrations);
dependentDialectRegistrations, superClassName);
// Check for any attributes/types registered to this dialect. If there are,
// add the hooks for parsing/printing.