[mlir][IR] Define the singleton builtin types in ODS instead of C++

This exposes several issues with the current generation that this revision also fixes.
 * TypeDef now allows specifying the base class to use when generating.
 * TypeDef now inherits from DialectType, which allows for using it as a TypeConstraint
 * Parser/Printers are now no longer generated in the header(removing duplicate symbols), and are now only generated when necessary.
    - Now that generatedTypeParser/Printer are only generated in the definition file,
      existing users will need to manually expose this functionality when necessary.
 * ::get() is no longer generated for singleton types, because it isn't necessary.

Differential Revision: https://reviews.llvm.org/D93270
This commit is contained in:
River Riddle 2020-12-15 13:39:09 -08:00
parent e113317958
commit 95019de8a1
16 changed files with 423 additions and 304 deletions

View File

@ -1370,10 +1370,10 @@ llvm::Optional<MyBitEnum> symbolizeMyBitEnum(uint32_t value) {
## Type Definitions
MLIR defines the TypeDef class hierarchy to enable generation of data types
from their specifications. A type is defined by specializing the TypeDef
class with concrete contents for all the fields it requires. For example, an
integer type could be defined as:
MLIR defines the TypeDef class hierarchy to enable generation of data types from
their specifications. A type is defined by specializing the TypeDef class with
concrete contents for all the fields it requires. For example, an integer type
could be defined as:
```tablegen
// All of the types will extend this class.
@ -1414,45 +1414,43 @@ def IntegerType : Test_Type<"TestInteger"> {
### Type name
The name of the C++ class which gets generated defaults to
`<classParamName>Type` (e.g. `TestIntegerType` in the above example). This
can be overridden via the `cppClassName` field. The field `mnemonic` is
to specify the asm name for parsing. It is optional and not specifying it
will imply that no parser or printer methods are attached to this class.
`<classParamName>Type` (e.g. `TestIntegerType` in the above example). This can
be overridden via the `cppClassName` field. The field `mnemonic` is to specify
the asm name for parsing. It is optional and not specifying it will imply that
no parser or printer methods are attached to this class.
### Type documentation
The `summary` and `description` fields exist and are to be used the same way
as in Operations. Namely, the summary should be a one-liner and `description`
The `summary` and `description` fields exist and are to be used the same way as
in Operations. Namely, the summary should be a one-liner and `description`
should be a longer explanation.
### Type parameters
The `parameters` field is a list of the types parameters. If no parameters
are specified (the default), this type is considered a singleton type.
Parameters are in the `"c++Type":$paramName` format.
To use C++ types as parameters which need allocation in the storage
constructor, there are two options:
The `parameters` field is a list of the types parameters. If no parameters are
specified (the default), this type is considered a singleton type. Parameters
are in the `"c++Type":$paramName` format. To use C++ types as parameters which
need allocation in the storage constructor, there are two options:
- Set `hasCustomStorageConstructor` to generate the TypeStorage class with
a constructor which is just declared -- no definition -- so you can write it
yourself.
- Use the `TypeParameter` tablegen class instead of the "c++Type" string.
- Set `hasCustomStorageConstructor` to generate the TypeStorage class with a
constructor which is just declared -- no definition -- so you can write it
yourself.
- Use the `TypeParameter` tablegen class instead of the "c++Type" string.
### TypeParameter tablegen class
This is used to further specify attributes about each of the types
parameters. It includes documentation (`description` and `syntax`), the C++
type to use, and a custom allocator to use in the storage constructor method.
This is used to further specify attributes about each of the types parameters.
It includes documentation (`description` and `syntax`), the C++ type to use, and
a custom allocator to use in the storage constructor method.
```tablegen
// DO NOT DO THIS!
let parameters = (ins
"ArrayRef<int>":$dims);
let parameters = (ins "ArrayRef<int>":$dims);
```
The default storage constructor blindly copies fields by value. It does not
know anything about the types. In this case, the ArrayRef<int> requires
allocation with `dims = allocator.copyInto(dims)`.
The default storage constructor blindly copies fields by value. It does not know
anything about the types. In this case, the ArrayRef<int> requires allocation
with `dims = allocator.copyInto(dims)`.
You can specify the necessary constructor by specializing the `TypeParameter`
tblgen class:
@ -1460,28 +1458,29 @@ tblgen class:
```tablegen
class ArrayRefIntParam :
TypeParameter<"::llvm::ArrayRef<int>", "Array of ints"> {
let allocator = [{$_dst = $_allocator.copyInto($_self);}];
let allocator = "$_dst = $_allocator.copyInto($_self);";
}
...
let parameters = (ins
ArrayRefIntParam:$dims);
let parameters = (ins ArrayRefIntParam:$dims);
```
The `allocator` code block has the following substitutions:
- `$_allocator` is the TypeStorageAllocator in which to allocate objects.
- `$_dst` is the variable in which to place the allocated data.
- `$_allocator` is the TypeStorageAllocator in which to allocate objects.
- `$_dst` is the variable in which to place the allocated data.
MLIR includes several specialized classes for common situations:
- `StringRefParameter<descriptionOfParam>` for StringRefs.
- `ArrayRefParameter<arrayOf, descriptionOfParam>` for ArrayRefs of value
types
- `SelfAllocationParameter<descriptionOfParam>` for C++ classes which contain
a method called `allocateInto(StorageAllocator &allocator)` to allocate
itself into `allocator`.
- `ArrayRefOfSelfAllocationParameter<arrayOf, descriptionOfParam>` for arrays
of objects which self-allocate as per the last specialization.
- `StringRefParameter<descriptionOfParam>` for StringRefs.
- `ArrayRefParameter<arrayOf, descriptionOfParam>` for ArrayRefs of value
types
- `SelfAllocationParameter<descriptionOfParam>` for C++ classes which contain
a method called `allocateInto(StorageAllocator &allocator)` to allocate
itself into `allocator`.
- `ArrayRefOfSelfAllocationParameter<arrayOf, descriptionOfParam>` for arrays
of objects which self-allocate as per the last specialization.
If we were to use one of these included specializations:
@ -1495,45 +1494,46 @@ let parameters = (ins
If a mnemonic is specified, the `printer` and `parser` code fields are active.
The rules for both are:
- If null, generate just the declaration.
- If non-null and non-empty, use the code in the definition. The `$_printer`
or `$_parser` substitutions are valid and should be used.
- It is an error to have an empty code block.
For each dialect, two "dispatch" functions will be created: one for parsing
and one for printing. You should add calls to these in your
`Dialect::printType` and `Dialect::parseType` methods. They are created in
the dialect's namespace and their function signatures are:
- If null, generate just the declaration.
- If non-null and non-empty, use the code in the definition. The `$_printer`
or `$_parser` substitutions are valid and should be used.
- It is an error to have an empty code block.
For each dialect, two "dispatch" functions will be created: one for parsing and
one for printing. You should add calls to these in your `Dialect::printType` and
`Dialect::parseType` methods. They are static functions placed alongside the
type class definitions and have the following function signatures:
```c++
Type generatedTypeParser(MLIRContext* ctxt, DialectAsmParser& parser,
StringRef mnemonic);
static Type generatedTypeParser(MLIRContext* ctxt, DialectAsmParser& parser, StringRef mnemonic);
LogicalResult generatedTypePrinter(Type type, DialectAsmPrinter& printer);
```
The mnemonic, parser, and printer fields are optional. If they're not
defined, the generated code will not include any parsing or printing code and
omit the type from the dispatch functions above. In this case, the dialect
author is responsible for parsing/printing the types in `Dialect::printType`
and `Dialect::parseType`.
The mnemonic, parser, and printer fields are optional. If they're not defined,
the generated code will not include any parsing or printing code and omit the
type from the dispatch functions above. In this case, the dialect author is
responsible for parsing/printing the types in `Dialect::printType` and
`Dialect::parseType`.
### Other fields
- If the `genStorageClass` field is set to 1 (the default) a storage class is
generated with member variables corresponding to each of the specified
`parameters`.
- If the `genAccessors` field is 1 (the default) accessor methods will be
generated on the Type class (e.g. `int getWidth() const` in the example
above).
- If the `genVerifyInvariantsDecl` field is set, a declaration for a method
`static LogicalResult verifyConstructionInvariants(Location, parameters...)`
is added to the class as well as a `getChecked(Location, parameters...)`
method which gets the result of `verifyConstructionInvariants` before calling
`get`.
- The `storageClass` field can be used to set the name of the storage class.
- The `storageNamespace` field is used to set the namespace where the storage
class should sit. Defaults to "detail".
- The `extraClassDeclaration` field is used to include extra code in the
class declaration.
- If the `genStorageClass` field is set to 1 (the default) a storage class is
generated with member variables corresponding to each of the specified
`parameters`.
- If the `genAccessors` field is 1 (the default) accessor methods will be
generated on the Type class (e.g. `int getWidth() const` in the example
above).
- If the `genVerifyInvariantsDecl` field is set, a declaration for a method
`static LogicalResult verifyConstructionInvariants(Location, parameters...)`
is added to the class as well as a `getChecked(Location, parameters...)`
method which gets the result of `verifyConstructionInvariants` before
calling `get`.
- The `storageClass` field can be used to set the name of the storage class.
- The `storageNamespace` field is used to set the namespace where the storage
class should sit. Defaults to "detail".
- The `extraClassDeclaration` field is used to include extra code in the class
declaration.
## Debugging Tips

View File

@ -0,0 +1,27 @@
//===-- BuiltinDialect.td - Builtin dialect definition -----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the definition of the Builtin dialect. This dialect
// contains all of the attributes, operations, and types that are core to MLIR.
//
//===----------------------------------------------------------------------===//
#ifndef BUILTIN_BASE
#define BUILTIN_BASE
include "mlir/IR/OpBase.td"
def Builtin_Dialect : Dialect {
let summary =
"A dialect containing the builtin Attributes, Operations, and Types";
let name = "";
let cppNamespace = "::mlir";
}
#endif // BUILTIN_BASE

View File

@ -14,17 +14,10 @@
#ifndef BUILTIN_OPS
#define BUILTIN_OPS
include "mlir/IR/BuiltinDialect.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
def Builtin_Dialect : Dialect {
let summary =
"A dialect containing the builtin Attributes, Operations, and Types";
let name = "";
let cppNamespace = "::mlir";
}
// Base class for Builtin dialect ops.
class Builtin_Op<string mnemonic, list<OpTrait> traits = []> :
Op<Builtin_Dialect, mnemonic, traits>;

View File

@ -72,23 +72,6 @@ public:
Type getElementType();
};
//===----------------------------------------------------------------------===//
// IndexType
//===----------------------------------------------------------------------===//
/// Index is a special integer-like type with unknown platform-dependent bit
/// width.
class IndexType : public Type::TypeBase<IndexType, Type, TypeStorage> {
public:
using Base::Base;
/// Get an instance of the IndexType.
static IndexType get(MLIRContext *context);
/// Storage bit width used for IndexType by internal compiler data structures.
static constexpr unsigned kInternalStorageBitWidth = 64;
};
//===----------------------------------------------------------------------===//
// IntegerType
//===----------------------------------------------------------------------===//
@ -187,67 +170,6 @@ public:
const llvm::fltSemantics &getFloatSemantics();
};
//===----------------------------------------------------------------------===//
// BFloat16Type
class BFloat16Type
: public Type::TypeBase<BFloat16Type, FloatType, TypeStorage> {
public:
using Base::Base;
/// Return an instance of the bfloat16 type.
static BFloat16Type get(MLIRContext *context);
};
inline FloatType FloatType::getBF16(MLIRContext *ctx) {
return BFloat16Type::get(ctx);
}
//===----------------------------------------------------------------------===//
// Float16Type
class Float16Type : public Type::TypeBase<Float16Type, FloatType, TypeStorage> {
public:
using Base::Base;
/// Return an instance of the float16 type.
static Float16Type get(MLIRContext *context);
};
inline FloatType FloatType::getF16(MLIRContext *ctx) {
return Float16Type::get(ctx);
}
//===----------------------------------------------------------------------===//
// Float32Type
class Float32Type : public Type::TypeBase<Float32Type, FloatType, TypeStorage> {
public:
using Base::Base;
/// Return an instance of the float32 type.
static Float32Type get(MLIRContext *context);
};
inline FloatType FloatType::getF32(MLIRContext *ctx) {
return Float32Type::get(ctx);
}
//===----------------------------------------------------------------------===//
// Float64Type
class Float64Type : public Type::TypeBase<Float64Type, FloatType, TypeStorage> {
public:
using Base::Base;
/// Return an instance of the float64 type.
static Float64Type get(MLIRContext *context);
};
inline FloatType FloatType::getF64(MLIRContext *ctx) {
return Float64Type::get(ctx);
}
//===----------------------------------------------------------------------===//
// FunctionType
//===----------------------------------------------------------------------===//
@ -276,20 +198,6 @@ public:
ArrayRef<unsigned> resultIndices);
};
//===----------------------------------------------------------------------===//
// NoneType
//===----------------------------------------------------------------------===//
/// NoneType is a unit type, i.e. a type with exactly one possible value, where
/// its value does not have a defined dynamic representation.
class NoneType : public Type::TypeBase<NoneType, Type, TypeStorage> {
public:
using Base::Base;
/// Get an instance of the NoneType.
static NoneType get(MLIRContext *context);
};
//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//
@ -720,11 +628,20 @@ public:
return getTypes()[index];
}
};
} // end namespace mlir
//===----------------------------------------------------------------------===//
// Tablegen Type Declarations
//===----------------------------------------------------------------------===//
#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.h.inc"
//===----------------------------------------------------------------------===//
// Deferred Method Definitions
//===----------------------------------------------------------------------===//
namespace mlir {
inline bool BaseMemRefType::classof(Type type) {
return type.isa<MemRefType, UnrankedMemRefType>();
}
@ -733,6 +650,22 @@ inline bool FloatType::classof(Type type) {
return type.isa<BFloat16Type, Float16Type, Float32Type, Float64Type>();
}
inline FloatType FloatType::getBF16(MLIRContext *ctx) {
return BFloat16Type::get(ctx);
}
inline FloatType FloatType::getF16(MLIRContext *ctx) {
return Float16Type::get(ctx);
}
inline FloatType FloatType::getF32(MLIRContext *ctx) {
return Float32Type::get(ctx);
}
inline FloatType FloatType::getF64(MLIRContext *ctx) {
return Float64Type::get(ctx);
}
inline bool ShapedType::classof(Type type) {
return type.isa<RankedTensorType, VectorType, UnrankedTensorType,
UnrankedMemRefType, MemRefType>();

View File

@ -0,0 +1,114 @@
//===- BuiltinTypes.td - Builtin type definitions ----------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the set of builtin MLIR types, or the set of types necessary for the
// validity of and defining the IR.
//
//===----------------------------------------------------------------------===//
#ifndef BUILTIN_TYPES
#define BUILTIN_TYPES
include "mlir/IR/BuiltinDialect.td"
// TODO: Currently the types defined in this file are prefixed with `Builtin_`.
// This is to differentiate the types here with the ones in OpBase.td. We should
// remove the definitions in OpBase.td, and repoint users to this file instead.
// Base class for Builtin dialect types.
class Builtin_Type<string name> : TypeDef<Builtin_Dialect, name> {
let mnemonic = ?;
}
//===----------------------------------------------------------------------===//
// FloatType
//===----------------------------------------------------------------------===//
// Base class for Builtin dialect float types.
class Builtin_FloatType<string name> : TypeDef<Builtin_Dialect, name,
"::mlir::FloatType"> {
let extraClassDeclaration = [{
static }] # name # [{Type get(MLIRContext *context);
}];
}
//===----------------------------------------------------------------------===//
// BFloat16Type
def Builtin_BFloat16 : Builtin_FloatType<"BFloat16"> {
let summary = "bfloat16 floating-point type";
}
//===----------------------------------------------------------------------===//
// Float16Type
def Builtin_Float16 : Builtin_FloatType<"Float16"> {
let summary = "16-bit floating-point type";
}
//===----------------------------------------------------------------------===//
// Float32Type
def Builtin_Float32 : Builtin_FloatType<"Float32"> {
let summary = "32-bit floating-point type";
}
//===----------------------------------------------------------------------===//
// Float64Type
def Builtin_Float64 : Builtin_FloatType<"Float64"> {
let summary = "64-bit floating-point type";
}
//===----------------------------------------------------------------------===//
// IndexType
//===----------------------------------------------------------------------===//
def Builtin_Index : Builtin_Type<"Index"> {
let summary = "Integer-like type with unknown platform-dependent bit width";
let description = [{
Syntax:
```
// Target word-sized integer.
index-type ::= `index`
```
The index type is a signless integer whose size is equal to the natural
machine word of the target ( [rationale](https://mlir.llvm.org/docs/Rationale/Rationale/#integer-signedness-semantics) )
and is used by the affine constructs in MLIR. Unlike fixed-size integers,
it cannot be used as an element of vector ( [rationale](https://mlir.llvm.org/docs/Rationale/Rationale/#index-type-disallowed-in-vector-types) ).
**Rationale:** integers of platform-specific bit widths are practical to
express sizes, dimensionalities and subscripts.
}];
let extraClassDeclaration = [{
static IndexType get(MLIRContext *context);
/// Storage bit width used for IndexType by internal compiler data
/// structures.
static constexpr unsigned kInternalStorageBitWidth = 64;
}];
}
//===----------------------------------------------------------------------===//
// NoneType
//===----------------------------------------------------------------------===//
def Builtin_None : Builtin_Type<"None"> {
let summary = "A unit type";
let description = [{
NoneType is a unit type, i.e. a type with exactly one possible value, where
its value does not have a defined dynamic representation.
}];
let extraClassDeclaration = [{
static NoneType get(MLIRContext *context);
}];
}
#endif // BUILTIN_TYPES

View File

@ -2,10 +2,18 @@ add_mlir_interface(OpAsmInterface)
add_mlir_interface(SymbolInterfaces)
add_mlir_interface(RegionKindInterface)
set(LLVM_TARGET_DEFINITIONS BuiltinDialect.td)
mlir_tablegen(BuiltinDialect.h.inc -gen-dialect-decls)
add_public_tablegen_target(MLIRBuiltinDialectIncGen)
set(LLVM_TARGET_DEFINITIONS BuiltinOps.td)
mlir_tablegen(BuiltinOps.h.inc -gen-op-decls)
mlir_tablegen(BuiltinOps.cpp.inc -gen-op-defs)
mlir_tablegen(BuiltinDialect.h.inc -gen-dialect-decls)
add_public_tablegen_target(MLIRBuiltinOpsIncGen)
set(LLVM_TARGET_DEFINITIONS BuiltinTypes.td)
mlir_tablegen(BuiltinTypes.h.inc -gen-typedef-decls)
mlir_tablegen(BuiltinTypes.cpp.inc -gen-typedef-defs)
add_public_tablegen_target(MLIRBuiltinTypesIncGen)
add_mlir_doc(BuiltinOps -gen-op-doc Builtin Dialects/)

View File

@ -2415,15 +2415,18 @@ def replaceWithValue;
// Data type generation
//===----------------------------------------------------------------------===//
// Define a new type belonging to a dialect and called 'name'.
class TypeDef<Dialect owningdialect, string name> {
Dialect dialect = owningdialect;
// Define a new type, named `name`, belonging to `dialect` that inherits from
// the given C++ base class.
class TypeDef<Dialect dialect, string name,
string baseCppClass = "::mlir::Type">
: DialectType<dialect, CPred<"">> {
// The name of the C++ Type class.
string cppClassName = name # "Type";
// The name of the C++ base class to use for this Type.
string cppBaseClassName = baseCppClass;
// Short summary of the type.
string summary = ?;
// The longer description of this type.
string description = ?;
// Name of storage class to generate or use.
string storageClass = name # "TypeStorage";
@ -2477,6 +2480,15 @@ class TypeDef<Dialect owningdialect, string name> {
bit genVerifyInvariantsDecl = 0;
// Extra code to include in the class declaration.
code extraClassDeclaration = [{}];
// The predicate for when this type is used as a type constraint.
let predicate = CPred<"$_self.isa<" # dialect.cppNamespace #
"::" # cppClassName # ">()">;
// A constant builder provided when the type has no parameters.
let builderCall = !if(!empty(parameters),
"$_builder.getType<" # dialect.cppNamespace #
"::" # cppClassName # ">()",
"");
}
// 'Parameters' should be subclasses of this or simple strings (which is a

View File

@ -48,6 +48,9 @@ public:
// Returns the name of the C++ class to generate.
StringRef getCppClassName() const;
// Returns the name of the C++ base class to use when generating this type.
StringRef getCppBaseClassName() const;
// Returns the name of the storage class for this type.
StringRef getStorageClassName() const;

View File

@ -20,6 +20,13 @@
using namespace mlir;
using namespace mlir::detail;
//===----------------------------------------------------------------------===//
/// Tablegen Type Definitions
//===----------------------------------------------------------------------===//
#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"
//===----------------------------------------------------------------------===//
/// ComplexType
//===----------------------------------------------------------------------===//

View File

@ -33,7 +33,9 @@ add_mlir_library(MLIRIR
${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
DEPENDS
MLIRBuiltinDialectIncGen
MLIRBuiltinOpsIncGen
MLIRBuiltinTypesIncGen
MLIRCallInterfacesIncGen
MLIROpAsmInterfaceIncGen
MLIRRegionKindInterfaceIncGen

View File

@ -13,6 +13,7 @@
#include "mlir/TableGen/Constraint.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace mlir::tblgen;
Constraint::Constraint(const llvm::Record *record)
@ -56,11 +57,18 @@ std::string Constraint::getConditionTemplate() const {
return getPredicate().getCondition();
}
llvm::StringRef Constraint::getDescription() const {
auto doc = def->getValueAsString("description");
if (doc.empty())
return def->getName();
return doc;
StringRef Constraint::getDescription() const {
// If a summary is found, we use that given that it is a focused single line
// comment.
if (Optional<StringRef> summary = def->getValueAsOptionalString("summary"))
return *summary;
// If a summary can't be found, look for a specific description field to use
// for the constraint.
StringRef desc = def->getValueAsString("description");
if (!desc.empty())
return desc;
// Otherwise, fallback to the name of the constraint definition.
return def->getName();
}
AppliedConstraint::AppliedConstraint(Constraint &&constraint,

View File

@ -31,6 +31,10 @@ StringRef TypeDef::getCppClassName() const {
return def->getValueAsString("cppClassName");
}
StringRef TypeDef::getCppBaseClassName() const {
return def->getValueAsString("cppBaseClassName");
}
bool TypeDef::hasDescription() const {
const llvm::RecordVal *s = def->getValue("description");
return s != nullptr && isa<llvm::StringInit>(s->getValue());

View File

@ -15,7 +15,6 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringSwitch.h"
using namespace mlir;
@ -183,77 +182,6 @@ Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
return builder.create<TestOpConstant>(loc, type, value);
}
static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser,
llvm::SetVector<Type> &stack) {
StringRef typeTag;
if (failed(parser.parseKeyword(&typeTag)))
return Type();
auto genType = generatedTypeParser(ctxt, parser, typeTag);
if (genType != Type())
return genType;
if (typeTag == "test_type")
return TestType::get(parser.getBuilder().getContext());
if (typeTag != "test_rec")
return Type();
StringRef name;
if (parser.parseLess() || parser.parseKeyword(&name))
return Type();
auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name);
// If this type already has been parsed above in the stack, expect just the
// name.
if (stack.contains(rec)) {
if (failed(parser.parseGreater()))
return Type();
return rec;
}
// Otherwise, parse the body and update the type.
if (failed(parser.parseComma()))
return Type();
stack.insert(rec);
Type subtype = parseTestType(ctxt, parser, stack);
stack.pop_back();
if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype)))
return Type();
return rec;
}
Type TestDialect::parseType(DialectAsmParser &parser) const {
llvm::SetVector<Type> stack;
return parseTestType(getContext(), parser, stack);
}
static void printTestType(Type type, DialectAsmPrinter &printer,
llvm::SetVector<Type> &stack) {
if (succeeded(generatedTypePrinter(type, printer)))
return;
if (type.isa<TestType>()) {
printer << "test_type";
return;
}
auto rec = type.cast<TestRecursiveType>();
printer << "test_rec<" << rec.getName();
if (!stack.contains(rec)) {
printer << ", ";
stack.insert(rec);
printTestType(rec.getBody(), printer, stack);
stack.pop_back();
}
printer << ">";
}
void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
llvm::SetVector<Type> stack;
printTestType(type, printer, stack);
}
LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
NamedAttribute namedAttr) {
if (namedAttr.first == "test.invalid_attr")

View File

@ -12,9 +12,12 @@
//===----------------------------------------------------------------------===//
#include "TestTypes.h"
#include "TestDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
@ -116,5 +119,84 @@ LogicalResult TestIntegerType::verifyConstructionInvariants(
return success();
}
//===----------------------------------------------------------------------===//
// Tablegen Generated Definitions
//===----------------------------------------------------------------------===//
#define GET_TYPEDEF_CLASSES
#include "TestTypeDefs.cpp.inc"
//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//
static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser,
llvm::SetVector<Type> &stack) {
StringRef typeTag;
if (failed(parser.parseKeyword(&typeTag)))
return Type();
auto genType = generatedTypeParser(ctxt, parser, typeTag);
if (genType != Type())
return genType;
if (typeTag == "test_type")
return TestType::get(parser.getBuilder().getContext());
if (typeTag != "test_rec")
return Type();
StringRef name;
if (parser.parseLess() || parser.parseKeyword(&name))
return Type();
auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name);
// If this type already has been parsed above in the stack, expect just the
// name.
if (stack.contains(rec)) {
if (failed(parser.parseGreater()))
return Type();
return rec;
}
// Otherwise, parse the body and update the type.
if (failed(parser.parseComma()))
return Type();
stack.insert(rec);
Type subtype = parseTestType(ctxt, parser, stack);
stack.pop_back();
if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype)))
return Type();
return rec;
}
Type TestDialect::parseType(DialectAsmParser &parser) const {
llvm::SetVector<Type> stack;
return parseTestType(getContext(), parser, stack);
}
static void printTestType(Type type, DialectAsmPrinter &printer,
llvm::SetVector<Type> &stack) {
if (succeeded(generatedTypePrinter(type, printer)))
return;
if (type.isa<TestType>()) {
printer << "test_type";
return;
}
auto rec = type.cast<TestRecursiveType>();
printer << "test_rec<" << rec.getName();
if (!stack.contains(rec)) {
printer << ", ";
stack.insert(rec);
printTestType(rec.getBody(), printer, stack);
stack.pop_back();
}
printer << ">";
}
void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
llvm::SetVector<Type> stack;
printTestType(type, printer, stack);
}

View File

@ -11,9 +11,6 @@ include "mlir/IR/OpBase.td"
// DECL: class DialectAsmPrinter;
// DECL: } // namespace mlir
// DECL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser, ::llvm::StringRef mnenomic);
// DECL: ::mlir::LogicalResult generatedTypePrinter(::mlir::Type type, ::mlir::DialectAsmPrinter& printer);
// DEF: #ifdef GET_TYPEDEF_LIST
// DEF: #undef GET_TYPEDEF_LIST
// DEF: ::mlir::test::SimpleAType,

View File

@ -92,7 +92,7 @@ public:
/// llvm::formatv will call this function when using an instance as a
/// replacement value.
void format(raw_ostream &os, StringRef options) override {
if (params.size() && prependComma)
if (!params.empty() && prependComma)
os << ", ";
switch (emitFormat) {
@ -146,8 +146,9 @@ class DialectAsmPrinter;
/// case.
///
/// {0}: The name of the typeDef class.
/// {1}: The name of the type base class.
static const char *const typeDefDeclSingletonBeginStr = R"(
class {0}: public ::mlir::Type::TypeBase<{0}, ::mlir::Type, ::mlir::TypeStorage> {{
class {0}: public ::mlir::Type::TypeBase<{0}, {1}, ::mlir::TypeStorage> {{
public:
/// Inherit some necessary constructors from 'TypeBase'.
using Base::Base;
@ -158,15 +159,16 @@ static const char *const typeDefDeclSingletonBeginStr = R"(
/// case.
///
/// {0}: The name of the typeDef class.
/// {1}: The typeDef storage class namespace.
/// {2}: The storage class name.
/// {3}: The list of parameters with types.
/// {1}: The name of the type base class.
/// {2}: The typeDef storage class namespace.
/// {3}: The storage class name.
/// {4}: The list of parameters with types.
static const char *const typeDefDeclParametricBeginStr = R"(
namespace {1} {
struct {2};
namespace {2} {
struct {3};
}
class {0}: public ::mlir::Type::TypeBase<{0}, ::mlir::Type,
{1}::{2}> {{
class {0}: public ::mlir::Type::TypeBase<{0}, {1},
{2}::{3}> {{
public:
/// Inherit some necessary constructors from 'TypeBase'.
using Base::Base;
@ -196,10 +198,11 @@ static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) {
// template.
if (typeDef.getNumParameters() == 0)
os << formatv(typeDefDeclSingletonBeginStr, typeDef.getCppClassName(),
typeDef.getStorageNamespace(), typeDef.getStorageClassName());
typeDef.getCppBaseClassName());
else
os << formatv(typeDefDeclParametricBeginStr, typeDef.getCppClassName(),
typeDef.getStorageNamespace(), typeDef.getStorageClassName());
typeDef.getCppBaseClassName(), typeDef.getStorageNamespace(),
typeDef.getStorageClassName());
// Emit the extra declarations first in case there's a type definition in
// there.
@ -208,8 +211,10 @@ static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) {
TypeParamCommaFormatter emitTypeNamePairsAfterComma(
TypeParamCommaFormatter::EmitFormat::TypeNamePairs, params);
os << llvm::formatv(" static {0} get(::mlir::MLIRContext* ctxt{1});\n",
typeDef.getCppClassName(), emitTypeNamePairsAfterComma);
if (!params.empty()) {
os << llvm::formatv(" static {0} get(::mlir::MLIRContext* ctxt{1});\n",
typeDef.getCppClassName(), emitTypeNamePairsAfterComma);
}
// Emit the verify invariants declaration.
if (typeDef.genVerifyInvariantsDecl())
@ -252,17 +257,9 @@ static bool emitTypeDefDecls(const llvm::RecordKeeper &recordKeeper,
// Output the common "header".
os << typeDefDeclHeader;
if (typeDefs.size() > 0) {
if (!typeDefs.empty()) {
NamespaceEmitter nsEmitter(os, typeDefs.begin()->getDialect());
// Well known print/parse dispatch function declarations. These are called
// from Dialect::parseType() and Dialect::printType() methods.
os << " ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, "
"::mlir::DialectAsmParser& parser, ::llvm::StringRef mnenomic);\n";
os << " ::mlir::LogicalResult generatedTypePrinter(::mlir::Type type, "
"::mlir::DialectAsmPrinter& printer);\n";
os << "\n";
// Declare all the type classes first (in case they reference each other).
for (const TypeDef &typeDef : typeDefs)
os << " class " << typeDef.getCppClassName() << ";\n";
@ -488,14 +485,16 @@ static void emitTypeDefDef(TypeDef typeDef, raw_ostream &os) {
if (typeDef.genStorageClass() && typeDef.getNumParameters() > 0)
emitStorageClass(typeDef, os);
os << llvm::formatv(
"{0} {0}::get(::mlir::MLIRContext* ctxt{1}) {{\n"
" return Base::get(ctxt{2});\n}\n",
typeDef.getCppClassName(),
TypeParamCommaFormatter(
TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters),
TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams,
parameters));
if (!parameters.empty()) {
os << llvm::formatv(
"{0} {0}::get(::mlir::MLIRContext* ctxt{1}) {{\n"
" return Base::get(ctxt{2});\n}\n",
typeDef.getCppClassName(),
TypeParamCommaFormatter(
TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters),
TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams,
parameters));
}
// Emit the parameter accessors.
if (typeDef.genAccessors())
@ -526,38 +525,40 @@ static void emitTypeDefDef(TypeDef typeDef, raw_ostream &os) {
/// Emit the dialect printer/parser dispatcher. User's code should call these
/// functions from their dialect's print/parse methods.
static void emitParsePrintDispatch(SmallVectorImpl<TypeDef> &typeDefs,
raw_ostream &os) {
if (typeDefs.size() == 0)
static void emitParsePrintDispatch(ArrayRef<TypeDef> types, raw_ostream &os) {
if (llvm::none_of(types, [](const TypeDef &type) {
return type.getMnemonic().hasValue();
})) {
return;
const Dialect &dialect = typeDefs.begin()->getDialect();
NamespaceEmitter ns(os, dialect);
}
// The parser dispatch is just a list of if-elses, matching on the mnemonic
// and calling the class's parse function.
os << "::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, "
// The parser dispatch is just a list of if-elses, matching on the
// mnemonic and calling the class's parse function.
os << "static ::mlir::Type generatedTypeParser(::mlir::MLIRContext* "
"ctxt, "
"::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic) {\n";
for (const TypeDef &typeDef : typeDefs)
if (typeDef.getMnemonic())
for (const TypeDef &type : types)
if (type.getMnemonic())
os << formatv(" if (mnemonic == {0}::{1}::getMnemonic()) return "
"{0}::{1}::parse(ctxt, parser);\n",
typeDef.getDialect().getCppNamespace(),
typeDef.getCppClassName());
type.getDialect().getCppNamespace(),
type.getCppClassName());
os << " return ::mlir::Type();\n";
os << "}\n\n";
// The printer dispatch uses llvm::TypeSwitch to find and call the correct
// printer.
os << "::mlir::LogicalResult generatedTypePrinter(::mlir::Type type, "
os << "static ::mlir::LogicalResult generatedTypePrinter(::mlir::Type "
"type, "
"::mlir::DialectAsmPrinter& printer) {\n"
<< " ::mlir::LogicalResult found = ::mlir::success();\n"
<< " ::llvm::TypeSwitch<::mlir::Type>(type)\n";
for (auto typeDef : typeDefs)
if (typeDef.getMnemonic())
for (const TypeDef &type : types)
if (type.getMnemonic())
os << formatv(" .Case<{0}::{1}>([&](::mlir::Type t) {{ "
"t.dyn_cast<{0}::{1}>().print(printer); })\n",
typeDef.getDialect().getCppNamespace(),
typeDef.getCppClassName());
type.getDialect().getCppNamespace(),
type.getCppClassName());
os << " .Default([&found](::mlir::Type) { found = ::mlir::failure(); "
"});\n"
<< " return found;\n"