[mlir][TypeDefGen] Add support for adding builders when generating a TypeDef

This allows for specifying additional get/getChecked methods that should be generated on the type, and acts similarly to how OpBuilders work. TypeBuilders have two additional components though:
* InferredContextParam
  - Bit indicating that the context parameter of a get method is inferred from one of the builder parameters
* checkedBody
  - A code block representing the body of the equivalent getChecked method.

Differential Revision: https://reviews.llvm.org/D94274
This commit is contained in:
River Riddle 2021-01-11 11:55:00 -08:00
parent 2074177301
commit 948be58258
9 changed files with 541 additions and 95 deletions

View File

@ -1536,6 +1536,171 @@ responsible for parsing/printing the types in `Dialect::printType` and
- The `extraClassDeclaration` field is used to include extra code in the class
declaration.
### Type builder methods
For each type, there are a few builders(`get`/`getChecked`) automatically
generated based on the parameters of the type. For example, given the following
type definition:
```tablegen
def MyType : ... {
let parameters = (ins "int":$intParam);
}
```
The following builders are generated:
```c++
// Type builders are named `get`, and return a new instance of a type for a
// given set of parameters.
static MyType get(MLIRContext *context, int intParam);
// If `genVerifyInvariantsDecl` is set to 1, the following method is also
// generated.
static MyType getChecked(Location loc, int intParam);
```
If these autogenerated methods are not desired, such as when they conflict with
a custom builder method, a type can set `skipDefaultBuilders` to 1 to signal
that they should not be generated.
#### Custom type builder methods
The default build methods may cover a majority of the simple cases related to
type construction, but when they cannot satisfy a type's needs, you can define
additional convenience get methods in the `builders` field as follows:
```tablegen
def MyType : ... {
let parameters = (ins "int":$intParam);
let builders = [
TypeBuilder<(ins "int":$intParam)>,
TypeBuilder<(ins CArg<"int", "0">:$intParam)>,
TypeBuilder<(ins CArg<"int", "0">:$intParam), [{
// Write the body of the `get` builder inline here.
return Base::get($_ctxt, intParam);
}]>,
TypeBuilderWithInferredContext<(ins "Type":$typeParam), [{
// This builder states that it can infer an MLIRContext instance from
// its arguments.
return Base::get(typeParam.getContext(), ...);
}]>,
];
}
```
The `builders` field is a list of custom builders that are added to the type
class. In this example, we provide a several different convenience builders that
are useful in different scenarios. The `ins` prefix is common to many function
declarations in ODS, which use a TableGen [`dag`](#tablegen-syntax). What
follows is a comma-separated list of types (quoted string or CArg) and names
prefixed with the `$` sign. The use of `CArg` allows for providing a default
value to that argument. Let's take a look at each of these builders individually
The first builder will generate the declaration of a builder method that looks
like:
```tablegen
let builders = [
TypeBuilder<(ins "int":$intParam)>,
];
```
```c++
class MyType : /*...*/ {
/*...*/
static MyType get(::mlir::MLIRContext *context, int intParam);
};
```
This builder is identical to the one that will be automatically generated for
`MyType`. The `context` parameter is implicitly added by the generator, and is
used when building the file Type instance (with `Base::get`). The distinction
here is that we can provide the implementation of this `get` method. With this
style of builder definition only the declaration is generated, the implementor
of MyType will need to provide a definition of `MyType::get`.
The second builder will generate the declaration of a builder method that looks
like:
```tablegen
let builders = [
TypeBuilder<(ins CArg<"int", "0">:$intParam)>,
];
```
```c++
class MyType : /*...*/ {
/*...*/
static MyType get(::mlir::MLIRContext *context, int intParam = 0);
};
```
The constraints here are identical to the first builder example except for the
fact that `intParam` now has a default value attached.
The third builder will generate the declaration of a builder method that looks
like:
```tablegen
let builders = [
TypeBuilder<(ins CArg<"int", "0">:$intParam), [{
// Write the body of the `get` builder inline here.
return Base::get($_ctxt, intParam);
}]>,
];
```
```c++
class MyType : /*...*/ {
/*...*/
static MyType get(::mlir::MLIRContext *context, int intParam = 0);
};
MyType MyType::get(::mlir::MLIRContext *context, int intParam) {
// Write the body of the `get` builder inline here.
return Base::get(context, intParam);
}
```
This is identical to the second builder example. The difference is that now, a
definition for the builder method will be generated automatically using the
provided code block as the body. When specifying the body inline, `$_ctxt` may
be used to access the `MLIRContext *` parameter.
The fourth builder will generate the declaration of a builder method that looks
like:
```tablegen
let builders = [
TypeBuilderWithInferredContext<(ins "Type":$typeParam), [{
// This builder states that it can infer an MLIRContext instance from
// its arguments.
return Base::get(typeParam.getContext(), ...);
}]>,
];
```
```c++
class MyType : /*...*/ {
/*...*/
static MyType get(Type typeParam);
};
MyType MyType::get(Type typeParam) {
// This builder states that it can infer an MLIRContext instance from its
// arguments.
return Base::get(typeParam.getContext(), ...);
}
```
In this builder example, the main difference from the third builder example
three is that the `MLIRContext` parameter is no longer added. This is because
the builder type used `TypeBuilderWithInferredContext` implies that the context
parameter is not necessary as it can be inferred from the arguments to the
builder.
## Debugging Tips
### Run `mlir-tblgen` to see the generated content

View File

@ -74,7 +74,7 @@ def ScalableVectorType : ArmSVE_Type<"ScalableVector"> {
VectorType vector;
if ($_parser.parseType(vector))
return Type();
return get(ctxt, vector.getShape(), vector.getElementType());
return get($_ctxt, vector.getShape(), vector.getElementType());
}];
let extraClassDeclaration = [{

View File

@ -2430,6 +2430,73 @@ def replaceWithValue;
// Data type generation
//===----------------------------------------------------------------------===//
// Class for defining a custom type getter.
//
// TableGen generates several generic getter methods for each type by default,
// corresponding to the specified dag parameters. If the default generated ones
// cannot cover some use case, custom getters can be defined using instances of
// this class.
//
// The signature of the `get` is always either:
//
// ```c++
// static <Type-Name> get(MLIRContext *context, <other-parameters>...) {
// <body>...
// }
// ```
//
// or:
//
// ```c++
// static <TypeName> get(MLIRContext *context, <parameters>...);
// ```
//
// To define a custom getter, the parameter list and body should be passed
// in as separate template arguments to this class. The parameter list is a
// TableGen DAG with `ins` operation with named arguments, which has either:
// - string initializers ("Type":$name) to represent a typed parameter, or
// - CArg-typed initializers (CArg<"Type", "default">:$name) to represent a
// typed parameter that may have a default value.
// The type string is used verbatim to produce code and, therefore, must be a
// valid C++ type. It is used inside the C++ namespace of the parent Type's
// dialect; explicit namespace qualification like `::mlir` may be necessary if
// Types are not placed inside the `mlir` namespace. The default value string is
// used verbatim to produce code and must be a valid C++ initializer the given
// type. For example, the following signature specification
//
// ```
// TypeBuilder<(ins "int":$integerArg, CArg<"float", "3.0f">:$floatArg)>
// ```
//
// has an integer parameter and a float parameter with a default value.
//
// If an empty string is passed in for `body`, then *only* the builder
// declaration will be generated; this provides a way to define complicated
// builders entirely in C++.
//
// `checkedBody` is similar to `body`, but is the code block used when
// generating a `getChecked` method.
class TypeBuilder<dag parameters, code bodyCode = "",
code checkedBodyCode = ""> {
dag dagParams = parameters;
code body = bodyCode;
code checkedBody = checkedBodyCode;
// The context parameter can be inferred from one of the other parameters and
// is not implicitly added to the parameter list.
bit hasInferredContextParam = 0;
}
// A class of TypeBuilder that is able to infer the MLIRContext parameter from
// one of the other builder parameters. Instances of this builder do not have
// `MLIRContext *` implicitly added to the parameter list.
class TypeBuilderWithInferredContext<dag parameters, code bodyCode = "",
code checkedBodyCode = "">
: TypeBuilder<parameters, bodyCode> {
code checkedBody = checkedBodyCode;
let hasInferredContextParam = 1;
}
// Define a new type, named `name`, belonging to `dialect` that inherits from
// the given C++ base class.
class TypeDef<Dialect dialect, string name,
@ -2475,6 +2542,18 @@ class TypeDef<Dialect dialect, string name,
// for re-allocating ArrayRefs. It is defined below.)
dag parameters = (ins);
// Custom type builder methods.
// In addition to the custom builders provided here, and unless
// skipDefaultBuilders is set, a default builder is generated with the
// following signature:
//
// ```c++
// static <TypeName> get(MLIRContext *, <parameters>);
// ```
//
// Note that builders should only be provided when a type has parameters.
list<TypeBuilder> builders = ?;
// Use the lowercased name as the keyword for parsing/printing. Specify only
// if you want tblgen to generate declarations and/or definitions of
// printer/parser for this type.
@ -2488,6 +2567,9 @@ class TypeDef<Dialect dialect, string name,
// If set, generate accessors for each Type parameter.
bit genAccessors = 1;
// Avoid generating default get/getChecked functions. Custom get methods must
// be provided.
bit skipDefaultBuilders = 0;
// Generate the verifyConstructionInvariants declaration and getChecked
// method.
bit genVerifyInvariantsDecl = 0;

View File

@ -14,24 +14,45 @@
#define MLIR_TABLEGEN_TYPEDEF_H
#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Dialect.h"
#include "mlir/TableGen/Builder.h"
namespace llvm {
class Record;
class DagInit;
class Record;
class SMLoc;
} // namespace llvm
namespace mlir {
namespace tblgen {
class Dialect;
class TypeParameter;
//===----------------------------------------------------------------------===//
// TypeBuilder
//===----------------------------------------------------------------------===//
/// Wrapper class that represents a Tablegen TypeBuilder.
class TypeBuilder : public Builder {
public:
using Builder::Builder;
/// Return an optional code body used for the `getChecked` variant of this
/// builder.
Optional<StringRef> getCheckedBody() const;
/// Returns true if this builder is able to infer the MLIRContext parameter.
bool hasInferredContextParameter() const;
};
//===----------------------------------------------------------------------===//
// TypeDef
//===----------------------------------------------------------------------===//
/// Wrapper class that contains a TableGen TypeDef's record and provides helper
/// methods for accessing them.
class TypeDef {
public:
explicit TypeDef(const llvm::Record *def) : def(def) {}
explicit TypeDef(const llvm::Record *def);
// Get the dialect for which this type belongs.
Dialect getDialect() const;
@ -95,6 +116,13 @@ public:
// Get the code location (for error printing).
ArrayRef<llvm::SMLoc> getLoc() const;
// Returns true if the default get/getChecked methods should be skipped during
// generation.
bool skipDefaultBuilders() const;
// Returns the builders of this type.
ArrayRef<TypeBuilder> getBuilders() const { return builders; }
// Returns whether two TypeDefs are equal by checking the equality of the
// underlying record.
bool operator==(const TypeDef &other) const;
@ -107,8 +135,15 @@ public:
private:
const llvm::Record *def;
// The builders of this type definition.
SmallVector<TypeBuilder> builders;
};
//===----------------------------------------------------------------------===//
// TypeParameter
//===----------------------------------------------------------------------===//
// A wrapper class for tblgen TypeParameter, arrays of which belong to TypeDefs
// to parameterize them.
class TypeParameter {

View File

@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/TypeDef.h"
#include "mlir/TableGen/Dialect.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
@ -18,6 +19,26 @@
using namespace mlir;
using namespace mlir::tblgen;
//===----------------------------------------------------------------------===//
// TypeBuilder
//===----------------------------------------------------------------------===//
/// Return an optional code body used for the `getChecked` variant of this
/// builder.
Optional<StringRef> TypeBuilder::getCheckedBody() const {
Optional<StringRef> body = def->getValueAsOptionalString("checkedBody");
return body && !body->empty() ? body : llvm::None;
}
/// Returns true if this builder is able to infer the MLIRContext parameter.
bool TypeBuilder::hasInferredContextParameter() const {
return def->getValueAsBit("hasInferredContextParam");
}
//===----------------------------------------------------------------------===//
// TypeDef
//===----------------------------------------------------------------------===//
Dialect TypeDef::getDialect() const {
auto *dialectDef =
dyn_cast<llvm::DefInit>(def->getValue("dialect")->getValue());
@ -98,6 +119,11 @@ llvm::Optional<StringRef> TypeDef::getExtraDecls() const {
return value.empty() ? llvm::Optional<StringRef>() : value;
}
llvm::ArrayRef<llvm::SMLoc> TypeDef::getLoc() const { return def->getLoc(); }
bool TypeDef::skipDefaultBuilders() const {
return def->getValueAsBit("skipDefaultBuilders");
}
bool TypeDef::operator==(const TypeDef &other) const {
return def == other.def;
}
@ -106,6 +132,33 @@ bool TypeDef::operator<(const TypeDef &other) const {
return getName() < other.getName();
}
//===----------------------------------------------------------------------===//
// TypeParameter
//===----------------------------------------------------------------------===//
TypeDef::TypeDef(const llvm::Record *def) : def(def) {
// Populate the builders.
auto *builderList =
dyn_cast_or_null<llvm::ListInit>(def->getValueInit("builders"));
if (builderList && !builderList->empty()) {
for (llvm::Init *init : builderList->getValues()) {
TypeBuilder builder(cast<llvm::DefInit>(init)->getDef(), def->getLoc());
// Ensure that all parameters have names.
for (const TypeBuilder::Parameter &param : builder.getParameters()) {
if (!param.getName())
PrintFatalError(def->getLoc(),
"type builder parameters must have a name");
}
builders.emplace_back(builder);
}
} else if (skipDefaultBuilders()) {
PrintFatalError(
def->getLoc(),
"default builders are skipped and no custom builders provided");
}
}
StringRef TypeParameter::getName() const {
return def->getArgName(num)->getValue();
}

View File

@ -51,9 +51,9 @@ def IntegerType : Test_Type<"TestInteger"> {
let genVerifyInvariantsDecl = 1;
let parameters = (
ins
"unsigned":$width,
// SignednessSemantics is defined below.
"::mlir::test::TestIntegerType::SignednessSemantics":$signedness,
"unsigned":$width
"::mlir::test::TestIntegerType::SignednessSemantics":$signedness
);
// We define the printer inline.
@ -63,6 +63,17 @@ def IntegerType : Test_Type<"TestInteger"> {
$_printer << ", " << getImpl()->width << ">";
}];
// Define custom builder methods.
let builders = [
TypeBuilder<(ins "unsigned":$width,
CArg<"SignednessSemantics", "Signless">:$signedness), [{
return Base::get($_ctxt, width, signedness);
}], [{
return Base::getChecked($_loc, width, signedness);
}]>
];
let skipDefaultBuilders = 1;
// The parser is defined here also.
let parser = [{
if (parser.parseLess()) return Type();
@ -73,7 +84,7 @@ def IntegerType : Test_Type<"TestInteger"> {
if ($_parser.parseInteger(width)) return Type();
if ($_parser.parseGreater()) return Type();
Location loc = $_parser.getEncodedSourceLoc($_parser.getNameLoc());
return getChecked(loc, signedness, width);
return getChecked(loc, width, signedness);
}];
// Any extra code one wants in the type's class declaration.
@ -85,9 +96,6 @@ def IntegerType : Test_Type<"TestInteger"> {
Unsigned, /// Unsigned integer
};
/// This extra function is necessary since it doesn't include signedness
static IntegerType getChecked(unsigned width, Location location);
/// Return true if this is a signless integer type.
bool isSignless() const { return getSignedness() == Signless; }
/// Return true if this is a signed integer type.

View File

@ -113,7 +113,7 @@ static llvm::hash_code mlir::test::hash_value(const FieldInfo &fi) { // NOLINT
// Example type validity checker.
LogicalResult TestIntegerType::verifyConstructionInvariants(
Location loc, TestIntegerType::SignednessSemantics ss, unsigned int width) {
Location loc, unsigned width, TestIntegerType::SignednessSemantics ss) {
if (width > 8)
return failure();
return success();

View File

@ -19,8 +19,8 @@ include "mlir/IR/OpBase.td"
// DEF: ::mlir::test::SingleParameterType,
// DEF: ::mlir::test::IntegerType
// DEF-LABEL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic)
// DEF: if (mnemonic == ::mlir::test::CompoundAType::getMnemonic()) return ::mlir::test::CompoundAType::parse(ctxt, parser);
// DEF-LABEL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext *context, ::mlir::DialectAsmParser &parser, ::llvm::StringRef mnemonic)
// DEF: if (mnemonic == ::mlir::test::CompoundAType::getMnemonic()) return ::mlir::test::CompoundAType::parse(context, parser);
// DEF return ::mlir::Type();
def Test_Dialect: Dialect {
@ -57,10 +57,11 @@ def B_CompoundTypeA : TestType<"CompoundA"> {
let genVerifyInvariantsDecl = 1;
// DECL-LABEL: class CompoundAType : public ::mlir::Type
// DECL: static CompoundAType getChecked(::mlir::Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
// DECL: static ::mlir::LogicalResult verifyConstructionInvariants(::mlir::Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
// DECL: static ::mlir::Type getChecked(::mlir::Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
// DECL: static ::llvm::StringRef getMnemonic() { return "cmpnd_a"; }
// DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
// DECL: static ::mlir::Type parse(::mlir::MLIRContext *context,
// DECL-NEXT: ::mlir::DialectAsmParser &parser);
// DECL: void print(::mlir::DialectAsmPrinter &printer) const;
// DECL: int getWidthOfSomething() const;
// DECL: ::mlir::test::SimpleTypeA getExampleTdType() const;
@ -77,7 +78,8 @@ def C_IndexType : TestType<"Index"> {
// DECL-LABEL: class IndexType : public ::mlir::Type
// DECL: static ::llvm::StringRef getMnemonic() { return "index"; }
// DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
// DECL: static ::mlir::Type parse(::mlir::MLIRContext *context,
// DECL-NEXT: ::mlir::DialectAsmParser &parser);
// DECL: void print(::mlir::DialectAsmPrinter &printer) const;
}

View File

@ -166,7 +166,7 @@ static const char *const typeDefDeclSingletonBeginStr = R"(
static const char *const typeDefDeclParametricBeginStr = R"(
namespace {2} {
struct {3};
}
} // end namespace {2}
class {0} : public ::mlir::Type::TypeBase<{0}, {1},
{2}::{3}> {{
public:
@ -177,18 +177,68 @@ static const char *const typeDefDeclParametricBeginStr = R"(
/// The snippet for print/parse.
static const char *const typeDefParsePrint = R"(
static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
static ::mlir::Type parse(::mlir::MLIRContext *context,
::mlir::DialectAsmParser &parser);
void print(::mlir::DialectAsmPrinter &printer) const;
)";
/// The code block for the verifyConstructionInvariants and getChecked.
///
/// {0}: List of parameters, parameters style.
/// {0}: The name of the typeDef class.
/// {1}: List of parameters, parameters style.
static const char *const typeDefDeclVerifyStr = R"(
static ::mlir::LogicalResult verifyConstructionInvariants(::mlir::Location loc{0});
static ::mlir::Type getChecked(::mlir::Location loc{0});
static ::mlir::LogicalResult verifyConstructionInvariants(::mlir::Location loc{1});
)";
/// Emit the builders for the given type.
static void emitTypeBuilderDecls(const TypeDef &typeDef, raw_ostream &os,
TypeParamCommaFormatter &paramTypes) {
StringRef typeClass = typeDef.getCppClassName();
bool genCheckedMethods = typeDef.genVerifyInvariantsDecl();
if (!typeDef.skipDefaultBuilders()) {
os << llvm::formatv(
" static {0} get(::mlir::MLIRContext *context{1});\n", typeClass,
paramTypes);
if (genCheckedMethods) {
os << llvm::formatv(
" static {0} getChecked(::mlir::Location loc{1});\n", typeClass,
paramTypes);
}
}
// Generate the builders specified by the user.
for (const TypeBuilder &builder : typeDef.getBuilders()) {
std::string paramStr;
llvm::raw_string_ostream paramOS(paramStr);
llvm::interleaveComma(
builder.getParameters(), paramOS,
[&](const TypeBuilder::Parameter &param) {
// Note: TypeBuilder parameters are guaranteed to have names.
paramOS << param.getCppType() << " " << *param.getName();
if (Optional<StringRef> defaultParamValue = param.getDefaultValue())
paramOS << " = " << *defaultParamValue;
});
paramOS.flush();
// Generate the `get` variant of the builder.
os << " static " << typeClass << " get(";
if (!builder.hasInferredContextParameter()) {
os << "::mlir::MLIRContext *context";
if (!paramStr.empty())
os << ", ";
}
os << paramStr << ");\n";
// Generate the `getChecked` variant of the builder.
if (genCheckedMethods) {
os << " static " << typeClass << " getChecked(::mlir::Location loc";
if (!paramStr.empty())
os << ", " << paramStr;
os << ");\n";
}
}
}
/// Generate the declaration for the given typeDef class.
static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) {
SmallVector<TypeParameter, 4> params;
@ -212,13 +262,13 @@ static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) {
TypeParamCommaFormatter emitTypeNamePairsAfterComma(
TypeParamCommaFormatter::EmitFormat::TypeNamePairs, params);
if (!params.empty()) {
os << llvm::formatv(" static {0} get(::mlir::MLIRContext* ctxt{1});\n",
typeDef.getCppClassName(), emitTypeNamePairsAfterComma);
}
emitTypeBuilderDecls(typeDef, os, emitTypeNamePairsAfterComma);
// Emit the verify invariants declaration.
if (typeDef.genVerifyInvariantsDecl())
os << llvm::formatv(typeDefDeclVerifyStr, emitTypeNamePairsAfterComma);
os << llvm::formatv(typeDefDeclVerifyStr, typeDef.getCppClassName(),
emitTypeNamePairsAfterComma);
}
// Emit the mnenomic, if specified.
if (auto mnenomic = typeDef.getMnemonic()) {
@ -226,6 +276,7 @@ static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) {
<< "\"; }\n";
// If mnemonic specified, emit print/parse declarations.
if (typeDef.getParserCode() || typeDef.getPrinterCode() || !params.empty())
os << typeDefParsePrint;
}
@ -330,17 +381,6 @@ static const char *const typeDefStorageClassConstructorReturn = R"(
}
)";
/// The code block for the getChecked definition.
///
/// {0}: List of parameters, parameters style.
/// {1}: C++ type class name.
/// {2}: Comma separated list of parameter names.
static const char *const typeDefDefGetCheckeStr = R"(
::mlir::Type {1}::getChecked(Location loc{0}) {{
return Base::getChecked(loc{2});
}
)";
/// Use tgfmt to emit custom allocation code for each parameter, if necessary.
static void emitParameterAllocationCode(TypeDef &typeDef, raw_ostream &os) {
SmallVector<TypeParameter, 4> parameters;
@ -403,13 +443,13 @@ static void emitStorageClass(TypeDef typeDef, raw_ostream &os) {
parameters, /* prependComma */ false));
// 3) Emit the construct method.
if (typeDef.hasStorageCustomConstructor())
if (typeDef.hasStorageCustomConstructor()) {
// If user wants to build the storage constructor themselves, declare it
// here and then they can write the definition elsewhere.
os << " static " << typeDef.getStorageClassName()
<< " *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy "
"&key);\n";
else {
} else {
// If not, autogenerate one.
// First, unbox the parameters.
@ -460,7 +500,7 @@ void emitParserPrinter(TypeDef typeDef, raw_ostream &os) {
if (auto parserCode = typeDef.getParserCode()) {
// The mnenomic must be defined so the dispatcher knows how to dispatch.
os << "::mlir::Type " << typeDef.getCppClassName()
<< "::parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& "
<< "::parse(::mlir::MLIRContext *context, ::mlir::DialectAsmParser &"
"parser) "
"{\n";
if (*parserCode == "") {
@ -470,34 +510,104 @@ void emitParserPrinter(TypeDef typeDef, raw_ostream &os) {
": parser (if specified) must have non-empty code");
}
auto fmtCtxt =
FmtContext().addSubst("_parser", "parser").addSubst("_ctxt", "ctxt");
FmtContext().addSubst("_parser", "parser").addSubst("_ctxt", "context");
os << tgfmt(*parserCode, &fmtCtxt) << "\n}\n";
}
}
/// Print all the typedef-specific definition code.
static void emitTypeDefDef(TypeDef typeDef, raw_ostream &os) {
NamespaceEmitter ns(os, typeDef.getDialect());
SmallVector<TypeParameter, 4> parameters;
typeDef.getParameters(parameters);
// Emit the storage class, if requested and necessary.
if (typeDef.genStorageClass() && typeDef.getNumParameters() > 0)
emitStorageClass(typeDef, os);
if (!parameters.empty()) {
/// Emit the builders for the given type.
static void emitTypeBuilderDefs(const TypeDef &typeDef, raw_ostream &os,
ArrayRef<TypeParameter> typeDefParams) {
bool genCheckedMethods = typeDef.genVerifyInvariantsDecl();
StringRef typeClass = typeDef.getCppClassName();
if (!typeDef.skipDefaultBuilders()) {
os << llvm::formatv(
"{0} {0}::get(::mlir::MLIRContext* ctxt{1}) {{\n"
" return Base::get(ctxt{2});\n}\n",
typeDef.getCppClassName(),
"{0} {0}::get(::mlir::MLIRContext *context{1}) {{\n"
" return Base::get(context{2});\n}\n",
typeClass,
TypeParamCommaFormatter(
TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters),
TypeParamCommaFormatter::EmitFormat::TypeNamePairs, typeDefParams),
TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams,
parameters));
typeDefParams));
if (genCheckedMethods) {
os << llvm::formatv(
"{0} {0}::getChecked(::mlir::Location loc{1}) {{\n"
" return Base::getChecked(loc{2});\n}\n",
typeClass,
TypeParamCommaFormatter(
TypeParamCommaFormatter::EmitFormat::TypeNamePairs,
typeDefParams),
TypeParamCommaFormatter(
TypeParamCommaFormatter::EmitFormat::JustParams, typeDefParams));
}
}
// Generate the builders specified by the user.
auto builderFmtCtx = FmtContext().addSubst("_ctxt", "context");
auto checkedBuilderFmtCtx = FmtContext()
.addSubst("_loc", "loc")
.addSubst("_ctxt", "loc.getContext()");
for (const TypeBuilder &builder : typeDef.getBuilders()) {
Optional<StringRef> body = builder.getBody();
Optional<StringRef> checkedBody =
genCheckedMethods ? builder.getCheckedBody() : llvm::None;
if (!body && !checkedBody)
continue;
std::string paramStr;
llvm::raw_string_ostream paramOS(paramStr);
llvm::interleaveComma(builder.getParameters(), paramOS,
[&](const TypeBuilder::Parameter &param) {
// Note: TypeBuilder parameters are guaranteed to
// have names.
paramOS << param.getCppType() << " "
<< *param.getName();
});
paramOS.flush();
// Emit the `get` variant of the builder.
if (body) {
os << llvm::formatv("{0} {0}::get(", typeClass);
if (!builder.hasInferredContextParameter()) {
os << "::mlir::MLIRContext *context";
if (!paramStr.empty())
os << ", ";
os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr,
tgfmt(*body, &builderFmtCtx).str());
} else {
os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr, *body);
}
}
// Emit the `getChecked` variant of the builder.
if (checkedBody) {
os << llvm::formatv("{0} {0}::getChecked(::mlir::Location loc",
typeClass);
if (!paramStr.empty())
os << ", " << paramStr;
os << llvm::formatv(") {{\n {0};\n}\n",
tgfmt(*checkedBody, &checkedBuilderFmtCtx));
}
}
}
/// Print all the typedef-specific definition code.
static void emitTypeDefDef(const TypeDef &typeDef, raw_ostream &os) {
NamespaceEmitter ns(os, typeDef.getDialect());
SmallVector<TypeParameter, 4> parameters;
typeDef.getParameters(parameters);
if (!parameters.empty()) {
// Emit the storage class, if requested and necessary.
if (typeDef.genStorageClass())
emitStorageClass(typeDef, os);
// Emit the builders for this type.
emitTypeBuilderDefs(typeDef, os, parameters);
// Generate accessor definitions only if we also generate the storage class.
// Otherwise, let the user define the exact accessor definition.
if (typeDef.genAccessors() && typeDef.genStorageClass()) {
// Emit the parameter accessors.
if (typeDef.genAccessors())
for (const TypeParameter &parameter : parameters) {
SmallString<16> name = parameter.getName();
name[0] = llvm::toUpper(name[0]);
@ -505,16 +615,7 @@ static void emitTypeDefDef(TypeDef typeDef, raw_ostream &os) {
parameter.getCppType(), name, parameter.getName(),
typeDef.getCppClassName());
}
// Generate getChecked() method.
if (typeDef.genVerifyInvariantsDecl()) {
os << llvm::formatv(
typeDefDefGetCheckeStr,
TypeParamCommaFormatter(
TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters),
typeDef.getCppClassName(),
TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams,
parameters));
}
}
// If mnemonic is specified maybe print definitions for the parser and printer
@ -535,8 +636,8 @@ static void emitParsePrintDispatch(ArrayRef<TypeDef> types, raw_ostream &os) {
// The parser dispatch is just a list of if-elses, matching on the
// mnemonic and calling the class's parse function.
os << "static ::mlir::Type generatedTypeParser(::mlir::MLIRContext *"
"ctxt, "
"::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic) {\n";
"context, ::mlir::DialectAsmParser &parser, "
"::llvm::StringRef mnemonic) {\n";
for (const TypeDef &type : types) {
if (type.getMnemonic()) {
os << formatv(" if (mnemonic == {0}::{1}::getMnemonic()) return "
@ -547,9 +648,9 @@ static void emitParsePrintDispatch(ArrayRef<TypeDef> types, raw_ostream &os) {
// If the type has no parameters and no parser code, just invoke a normal
// `get`.
if (type.getNumParameters() == 0 && !type.getParserCode())
os << "get(ctxt);\n";
os << "get(context);\n";
else
os << "parse(ctxt, parser);\n";
os << "parse(context, parser);\n";
}
}
os << " return ::mlir::Type();\n";
@ -594,7 +695,7 @@ static bool emitTypeDefDefs(const llvm::RecordKeeper &recordKeeper,
IfDefScope scope("GET_TYPEDEF_CLASSES", os);
emitParsePrintDispatch(typeDefs, os);
for (auto typeDef : typeDefs)
for (const TypeDef &typeDef : typeDefs)
emitTypeDefDef(typeDef, os);
return false;