[mlir] Add support for generating Attribute classes for ODS

The support for attributes closely maps that of Types (basically 1-1) given that Attributes are defined in exactly the same way as Types. All of the current ODS TypeDef classes get an Attr equivalent. The generation of the attribute classes themselves share the same generator as types.

Differential Revision: https://reviews.llvm.org/D97589
This commit is contained in:
River Riddle 2021-03-03 16:37:32 -08:00
parent 201ebf211f
commit 83ef862fad
19 changed files with 1507 additions and 1051 deletions

View File

@ -2465,20 +2465,20 @@ def replaceWithValue;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Data type generation // Attribute and Type generation
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Class for defining a custom type getter. // Class for defining a custom getter.
// //
// TableGen generates several generic getter methods for each type by default, // TableGen generates several generic getter methods for each attribute and type
// corresponding to the specified dag parameters. If the default generated ones // by default, corresponding to the specified dag parameters. If the default
// cannot cover some use case, custom getters can be defined using instances of // generated ones cannot cover some use case, custom getters can be defined
// this class. // using instances of this class.
// //
// The signature of the `get` is always either: // The signature of the `get` is always either:
// //
// ```c++ // ```c++
// static <Type-Name> get(MLIRContext *context, <other-parameters>...) { // static <ClassName> get(MLIRContext *context, <other-parameters>...) {
// <body>... // <body>...
// } // }
// ``` // ```
@ -2486,7 +2486,7 @@ def replaceWithValue;
// or: // or:
// //
// ```c++ // ```c++
// static <TypeName> get(MLIRContext *context, <parameters>...); // static <ClassName> get(MLIRContext *context, <parameters>...);
// ``` // ```
// //
// To define a custom getter, the parameter list and body should be passed // To define a custom getter, the parameter list and body should be passed
@ -2503,7 +2503,7 @@ def replaceWithValue;
// type. For example, the following signature specification // type. For example, the following signature specification
// //
// ``` // ```
// TypeBuilder<(ins "int":$integerArg, CArg<"float", "3.0f">:$floatArg)> // AttrOrTypeBuilder<(ins "int":$integerArg, CArg<"float", "3.0f">:$floatArg)>
// ``` // ```
// //
// has an integer parameter and a float parameter with a default value. // has an integer parameter and a float parameter with a default value.
@ -2514,7 +2514,7 @@ def replaceWithValue;
// method should be invoked using `$_get`, e.g.: // method should be invoked using `$_get`, e.g.:
// //
// ``` // ```
// TypeBuilder<(ins "int":$integerArg, CArg<"float", "3.0f">:$floatArg), [{ // AttrOrTypeBuilder<(ins "int":$integerArg, CArg<"float", "3.0f">:$floatArg), [{
// return $_get($_ctxt, integerArg, floatArg); // return $_get($_ctxt, integerArg, floatArg);
// }]> // }]>
// ``` // ```
@ -2522,7 +2522,7 @@ def replaceWithValue;
// This is necessary because the `body` is also used to generate `getChecked` // This is necessary because the `body` is also used to generate `getChecked`
// methods, which have a different underlying `Base::get*` call. // methods, which have a different underlying `Base::get*` call.
// //
class TypeBuilder<dag parameters, code bodyCode = ""> { class AttrOrTypeBuilder<dag parameters, code bodyCode = ""> {
dag dagParams = parameters; dag dagParams = parameters;
code body = bodyCode; code body = bodyCode;
@ -2530,33 +2530,42 @@ class TypeBuilder<dag parameters, code bodyCode = ""> {
// is not implicitly added to the parameter list. // is not implicitly added to the parameter list.
bit hasInferredContextParam = 0; bit hasInferredContextParam = 0;
} }
class AttrBuilder<dag parameters, code bodyCode = "">
: AttrOrTypeBuilder<parameters, bodyCode>;
class TypeBuilder<dag parameters, code bodyCode = "">
: AttrOrTypeBuilder<parameters, bodyCode>;
// A class of TypeBuilder that is able to infer the MLIRContext parameter from // A class of AttrOrTypeBuilder that is able to infer the MLIRContext parameter
// one of the other builder parameters. Instances of this builder do not have // from one of the other builder parameters. Instances of this builder do not
// `MLIRContext *` implicitly added to the parameter list. // have `MLIRContext *` implicitly added to the parameter list.
class TypeBuilderWithInferredContext<dag parameters, code bodyCode = ""> class AttrOrTypeBuilderWithInferredContext<dag parameters, code bodyCode = "">
: TypeBuilder<parameters, bodyCode> { : TypeBuilder<parameters, bodyCode> {
let hasInferredContextParam = 1; let hasInferredContextParam = 1;
} }
class AttrBuilderWithInferredContext<dag parameters, code bodyCode = "">
: AttrOrTypeBuilderWithInferredContext<parameters, bodyCode>;
class TypeBuilderWithInferredContext<dag parameters, code bodyCode = "">
: AttrOrTypeBuilderWithInferredContext<parameters, bodyCode>;
// Define a new type, named `name`, belonging to `dialect` that inherits from // Define a new attribute or type, named `name`, that inherits from the given
// the given C++ base class. // C++ base class.
class TypeDef<Dialect dialect, string name, class AttrOrTypeDef<string valueType, string name, string baseCppClass> {
string baseCppClass = "::mlir::Type"> // The name of the C++ base class to use for this def.
: DialectType<dialect, CPred<"">, /*descr*/"", name # "Type"> {
// The name of the C++ base class to use for this Type.
string cppBaseClassName = baseCppClass; string cppBaseClassName = baseCppClass;
// Additional, longer human-readable description of what the op does. // Additional, longer human-readable description of what the def does.
string description = ""; string description = "";
// Name of storage class to generate or use. // Name of storage class to generate or use.
string storageClass = name # "TypeStorage"; string storageClass = name # valueType # "Storage";
// Namespace (withing dialect c++ namespace) in which the storage class // Namespace (withing dialect c++ namespace) in which the storage class
// resides. // resides.
string storageNamespace = "detail"; string storageNamespace = "detail";
// Specify if the storage class is to be generated. // Specify if the storage class is to be generated.
bit genStorageClass = 1; bit genStorageClass = 1;
// Specify that the generated storage class has a constructor which is written // Specify that the generated storage class has a constructor which is written
// in C++. // in C++.
bit hasStorageCustomConstructor = 0; bit hasStorageCustomConstructor = 0;
@ -2568,38 +2577,38 @@ class TypeDef<Dialect dialect, string name,
// (ins // (ins
// "<c++ type>":$param1Name, // "<c++ type>":$param1Name,
// "<c++ type>":$param2Name, // "<c++ type>":$param2Name,
// TypeParameter<"c++ type", "param description">:$param3Name) // AttrOrTypeParameter<"c++ type", "param description">:$param3Name)
// TypeParameters (or more likely one of their subclasses) are required to add // AttrOrTypeParameters (or more likely one of their subclasses) are required
// more information about the parameter, specifically: // to add more information about the parameter, specifically:
// - Documentation // - Documentation
// - Code to allocate the parameter (if allocation is needed in the storage // - Code to allocate the parameter (if allocation is needed in the storage
// class constructor) // class constructor)
// //
// For example: // For example:
// (ins // (ins "int":$width,
// "int":$width,
// ArrayRefParameter<"bool", "list of bools">:$yesNoArray) // ArrayRefParameter<"bool", "list of bools">:$yesNoArray)
// //
// (ArrayRefParameter is a subclass of TypeParameter which has allocation code // (ArrayRefParameter is a subclass of AttrOrTypeParameter which has
// for re-allocating ArrayRefs. It is defined below.) // allocation code for re-allocating ArrayRefs. It is defined below.)
dag parameters = (ins); dag parameters = (ins);
// Custom type builder methods. // Custom builder methods.
// In addition to the custom builders provided here, and unless // In addition to the custom builders provided here, and unless
// skipDefaultBuilders is set, a default builder is generated with the // skipDefaultBuilders is set, a default builder is generated with the
// following signature: // following signature:
// //
// ```c++ // ```c++
// static <TypeName> get(MLIRContext *, <parameters>); // static <ClassName> get(MLIRContext *, <parameters>);
// ``` // ```
// //
// Note that builders should only be provided when a type has parameters. // Note that builders should only be provided when a def has parameters.
list<TypeBuilder> builders = ?; list<AttrOrTypeBuilder> builders = ?;
// Use the lowercased name as the keyword for parsing/printing. Specify only // Use the lowercased name as the keyword for parsing/printing. Specify only
// if you want tblgen to generate declarations and/or definitions of // if you want tblgen to generate declarations and/or definitions of
// printer/parser for this type. // the printer/parser.
string mnemonic = ?; string mnemonic = ?;
// If 'mnemonic' specified, // If 'mnemonic' specified,
// If null, generate just the declarations. // If null, generate just the declarations.
// If a non-empty code block, just use that code as the definition code. // If a non-empty code block, just use that code as the definition code.
@ -2607,29 +2616,53 @@ class TypeDef<Dialect dialect, string name,
code printer = ?; code printer = ?;
code parser = ?; code parser = ?;
// If set, generate accessors for each Type parameter. // If set, generate accessors for each parameter.
bit genAccessors = 1; bit genAccessors = 1;
// Avoid generating default get/getChecked functions. Custom get methods must // Avoid generating default get/getChecked functions. Custom get methods must
// be provided. // be provided.
bit skipDefaultBuilders = 0; bit skipDefaultBuilders = 0;
// Generate the verify and getChecked methods. // Generate the verify and getChecked methods.
bit genVerifyDecl = 0; bit genVerifyDecl = 0;
// Extra code to include in the class declaration. // Extra code to include in the class declaration.
code extraClassDeclaration = [{}]; code extraClassDeclaration = [{}];
}
// The predicate for when this type is used as a type constraint. // Define a new attribute, named `name`, belonging to `dialect` that inherits
// from the given C++ base class.
class AttrDef<Dialect dialect, string name,
string baseCppClass = "::mlir::Attribute">
: DialectAttr<dialect, CPred<"">, /*descr*/"">,
AttrOrTypeDef<"Attr", name, baseCppClass> {
// The name of the C++ Attribute class.
string cppClassName = name # "Attr";
// The predicate for when this def is used as a constraint.
let predicate = CPred<"$_self.isa<" # dialect.cppNamespace # let predicate = CPred<"$_self.isa<" # dialect.cppNamespace #
"::" # cppClassName # ">()">; "::" # cppClassName # ">()">;
}
// Define a new type, named `name`, belonging to `dialect` that inherits from
// the given C++ base class.
class TypeDef<Dialect dialect, string name,
string baseCppClass = "::mlir::Type">
: DialectType<dialect, CPred<"">, /*descr*/"", name # "Type">,
AttrOrTypeDef<"Type", name, baseCppClass> {
// A constant builder provided when the type has no parameters. // A constant builder provided when the type has no parameters.
let builderCall = !if(!empty(parameters), let builderCall = !if(!empty(parameters),
"$_builder.getType<" # dialect.cppNamespace # "$_builder.getType<" # dialect.cppNamespace #
"::" # cppClassName # ">()", "::" # cppClassName # ">()",
""); "");
// The predicate for when this def is used as a constraint.
let predicate = CPred<"$_self.isa<" # dialect.cppNamespace #
"::" # cppClassName # ">()">;
} }
// 'Parameters' should be subclasses of this or simple strings (which is a // 'Parameters' should be subclasses of this or simple strings (which is a
// shorthand for TypeParameter<"C++Type">). // shorthand for AttrOrTypeParameter<"C++Type">).
class TypeParameter<string type, string desc> { class AttrOrTypeParameter<string type, string desc> {
// Custom memory allocation code for storage constructor. // Custom memory allocation code for storage constructor.
code allocator = ?; code allocator = ?;
// The C++ type of this parameter. // The C++ type of this parameter.
@ -2639,28 +2672,30 @@ class TypeParameter<string type, string desc> {
// The format string for the asm syntax (documentation only). // The format string for the asm syntax (documentation only).
string syntax = ?; string syntax = ?;
} }
class AttrParameter<string type, string desc> : AttrOrTypeParameter<type, desc>;
class TypeParameter<string type, string desc> : AttrOrTypeParameter<type, desc>;
// For StringRefs, which require allocation. // For StringRefs, which require allocation.
class StringRefParameter<string desc> : class StringRefParameter<string desc> :
TypeParameter<"::llvm::StringRef", desc> { AttrOrTypeParameter<"::llvm::StringRef", desc> {
let allocator = [{$_dst = $_allocator.copyInto($_self);}]; let allocator = [{$_dst = $_allocator.copyInto($_self);}];
} }
// For standard ArrayRefs, which require allocation. // For standard ArrayRefs, which require allocation.
class ArrayRefParameter<string arrayOf, string desc> : class ArrayRefParameter<string arrayOf, string desc> :
TypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> { AttrOrTypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> {
let allocator = [{$_dst = $_allocator.copyInto($_self);}]; let allocator = [{$_dst = $_allocator.copyInto($_self);}];
} }
// For classes which require allocation and have their own allocateInto method. // For classes which require allocation and have their own allocateInto method.
class SelfAllocationParameter<string type, string desc> : class SelfAllocationParameter<string type, string desc> :
TypeParameter<type, desc> { AttrOrTypeParameter<type, desc> {
let allocator = [{$_dst = $_self.allocateInto($_allocator);}]; let allocator = [{$_dst = $_self.allocateInto($_allocator);}];
} }
// For ArrayRefs which contain things which allocate themselves. // For ArrayRefs which contain things which allocate themselves.
class ArrayRefOfSelfAllocationParameter<string arrayOf, string desc> : class ArrayRefOfSelfAllocationParameter<string arrayOf, string desc> :
TypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> { AttrOrTypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> {
let allocator = [{ let allocator = [{
llvm::SmallVector<}] # arrayOf # [{, 4> tmpFields; llvm::SmallVector<}] # arrayOf # [{, 4> tmpFields;
for (size_t i = 0, e = $_self.size(); i < e; ++i) for (size_t i = 0, e = $_self.size(); i < e; ++i)
@ -2669,5 +2704,4 @@ class ArrayRefOfSelfAllocationParameter<string arrayOf, string desc> :
}]; }];
} }
#endif // OP_BASE #endif // OP_BASE

View File

@ -1,4 +1,4 @@
//===-- TypeDef.h - Record wrapper for type definitions ---------*- C++ -*-===// //===-- AttrOrTypeDef.h - Wrapper for attr and type definitions -*- C++ -*-===//
// //
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information. // See https://llvm.org/LICENSE.txt for license information.
@ -6,12 +6,13 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// //
// TypeDef wrapper to simplify using TableGen Record defining a MLIR type. // AttrOrTypeDef, AttrDef, and TypeDef wrappers to simplify using TableGen
// Record defining a MLIR attributes and types.
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_TYPEDEF_H #ifndef MLIR_TABLEGEN_ATTRORTYPEDEF_H
#define MLIR_TABLEGEN_TYPEDEF_H #define MLIR_TABLEGEN_ATTRORTYPEDEF_H
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Builder.h" #include "mlir/TableGen/Builder.h"
@ -25,14 +26,14 @@ class SMLoc;
namespace mlir { namespace mlir {
namespace tblgen { namespace tblgen {
class Dialect; class Dialect;
class TypeParameter; class AttrOrTypeParameter;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TypeBuilder // AttrOrTypeBuilder
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Wrapper class that represents a Tablegen TypeBuilder. /// Wrapper class that represents a Tablegen AttrOrTypeBuilder.
class TypeBuilder : public Builder { class AttrOrTypeBuilder : public Builder {
public: public:
using Builder::Builder; using Builder::Builder;
@ -41,22 +42,22 @@ public:
}; };
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TypeDef // AttrOrTypeDef
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Wrapper class that contains a TableGen TypeDef's record and provides helper /// Wrapper class that contains a TableGen AttrOrTypeDef's record and provides
/// methods for accessing them. /// helper methods for accessing them.
class TypeDef { class AttrOrTypeDef {
public: public:
explicit TypeDef(const llvm::Record *def); explicit AttrOrTypeDef(const llvm::Record *def);
// Get the dialect for which this type belongs. // Get the dialect for which this def belongs.
Dialect getDialect() const; Dialect getDialect() const;
// Returns the name of this TypeDef record. // Returns the name of this AttrOrTypeDef record.
StringRef getName() const; StringRef getName() const;
// Query functions for the documentation of the operator. // Query functions for the documentation of the def.
bool hasDescription() const; bool hasDescription() const;
StringRef getDescription() const; StringRef getDescription() const;
bool hasSummary() const; bool hasSummary() const;
@ -65,13 +66,13 @@ public:
// Returns the name of the C++ class to generate. // Returns the name of the C++ class to generate.
StringRef getCppClassName() const; StringRef getCppClassName() const;
// Returns the name of the C++ base class to use when generating this type. // Returns the name of the C++ base class to use when generating this def.
StringRef getCppBaseClassName() const; StringRef getCppBaseClassName() const;
// Returns the name of the storage class for this type. // Returns the name of the storage class for this def.
StringRef getStorageClassName() const; StringRef getStorageClassName() const;
// Returns the C++ namespace for this types storage class. // Returns the C++ namespace for this def's storage class.
StringRef getStorageNamespace() const; StringRef getStorageNamespace() const;
// Returns true if we should generate the storage class. // Returns true if we should generate the storage class.
@ -80,10 +81,11 @@ public:
// Indicates whether or not to generate the storage class constructor. // Indicates whether or not to generate the storage class constructor.
bool hasStorageCustomConstructor() const; bool hasStorageCustomConstructor() const;
// Fill a list with this types parameters. See TypeDef in OpBase.td for // Fill a list with this def's parameters. See AttrOrTypeDef in OpBase.td for
// documentation of parameter usage. // documentation of parameter usage.
void getParameters(SmallVectorImpl<TypeParameter> &) const; void getParameters(SmallVectorImpl<AttrOrTypeParameter> &) const;
// Return the number of type parameters
// Return the number of parameters
unsigned getNumParameters() const; unsigned getNumParameters() const;
// Return the keyword/mnemonic to use in the printer/parser methods if we are // Return the keyword/mnemonic to use in the printer/parser methods if we are
@ -94,19 +96,18 @@ public:
// return a non-value. Otherwise, return the contents of that code block. // return a non-value. Otherwise, return the contents of that code block.
Optional<StringRef> getPrinterCode() const; Optional<StringRef> getPrinterCode() const;
// Returns the code to use as the types parser method. If not specified, // Returns the code to use as the parser method. If not specified, returns
// return a non-value. Otherwise, return the contents of that code block. // None. Otherwise, returns the contents of that code block.
Optional<StringRef> getParserCode() const; Optional<StringRef> getParserCode() const;
// Returns true if the accessors based on the types parameters should be // Returns true if the accessors based on the parameters should be generated.
// generated.
bool genAccessors() const; bool genAccessors() const;
// Return true if we need to generate the verify declaration and getChecked // Return true if we need to generate the verify declaration and getChecked
// method. // method.
bool genVerifyDecl() const; bool genVerifyDecl() const;
// Returns the dialects extra class declaration code. // Returns the def's extra class declaration code.
Optional<StringRef> getExtraDecls() const; Optional<StringRef> getExtraDecls() const;
// Get the code location (for error printing). // Get the code location (for error printing).
@ -116,54 +117,80 @@ public:
// generation. // generation.
bool skipDefaultBuilders() const; bool skipDefaultBuilders() const;
// Returns the builders of this type. // Returns the builders of this def.
ArrayRef<TypeBuilder> getBuilders() const { return builders; } ArrayRef<AttrOrTypeBuilder> getBuilders() const { return builders; }
// Returns whether two TypeDefs are equal by checking the equality of the // Returns whether two AttrOrTypeDefs are equal by checking the equality of
// underlying record. // the underlying record.
bool operator==(const TypeDef &other) const; bool operator==(const AttrOrTypeDef &other) const;
// Compares two TypeDefs by comparing the names of the dialects. // Compares two AttrOrTypeDefs by comparing the names of the dialects.
bool operator<(const TypeDef &other) const; bool operator<(const AttrOrTypeDef &other) const;
// Returns whether the TypeDef is defined. // Returns whether the AttrOrTypeDef is defined.
operator bool() const { return def != nullptr; } operator bool() const { return def != nullptr; }
private: private:
const llvm::Record *def; const llvm::Record *def;
// The builders of this type definition. // The builders of this type definition.
SmallVector<TypeBuilder> builders; SmallVector<AttrOrTypeBuilder> builders;
}; };
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TypeParameter // AttrDef
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// A wrapper class for tblgen TypeParameter, arrays of which belong to TypeDefs /// This class represents a wrapper around a tablegen AttrDef record.
// to parameterize them. class AttrDef : public AttrOrTypeDef {
class TypeParameter {
public: public:
explicit TypeParameter(const llvm::DagInit *def, unsigned num) using AttrOrTypeDef::AttrOrTypeDef;
: def(def), num(num) {} };
//===----------------------------------------------------------------------===//
// TypeDef
//===----------------------------------------------------------------------===//
/// This class represents a wrapper around a tablegen TypeDef record.
class TypeDef : public AttrOrTypeDef {
public:
using AttrOrTypeDef::AttrOrTypeDef;
};
//===----------------------------------------------------------------------===//
// AttrOrTypeParameter
//===----------------------------------------------------------------------===//
// A wrapper class for tblgen AttrOrTypeParameter, arrays of which belong to
// AttrOrTypeDefs to parameterize them.
class AttrOrTypeParameter {
public:
explicit AttrOrTypeParameter(const llvm::DagInit *def, unsigned index)
: def(def), index(index) {}
// Get the parameter name. // Get the parameter name.
StringRef getName() const; StringRef getName() const;
// If specified, get the custom allocator code for this parameter. // If specified, get the custom allocator code for this parameter.
Optional<StringRef> getAllocator() const; Optional<StringRef> getAllocator() const;
// Get the C++ type of this parameter. // Get the C++ type of this parameter.
StringRef getCppType() const; StringRef getCppType() const;
// Get a description of this parameter for documentation purposes. // Get a description of this parameter for documentation purposes.
Optional<StringRef> getSummary() const; Optional<StringRef> getSummary() const;
// Get the assembly syntax documentation. // Get the assembly syntax documentation.
StringRef getSyntax() const; StringRef getSyntax() const;
private: private:
/// The underlying tablegen parameter list this parameter is a part of.
const llvm::DagInit *def; const llvm::DagInit *def;
const unsigned num; /// The index of the parameter within the parameter list (`def`).
unsigned index;
}; };
} // end namespace tblgen } // end namespace tblgen
} // end namespace mlir } // end namespace mlir
#endif // MLIR_TABLEGEN_TYPEDEF_H #endif // MLIR_TABLEGEN_ATTRORTYPEDEF_H

View File

@ -23,14 +23,15 @@ namespace tblgen {
// Simple RAII helper for defining ifdef-undef-endif scopes. // Simple RAII helper for defining ifdef-undef-endif scopes.
class IfDefScope { class IfDefScope {
public: public:
IfDefScope(llvm::StringRef name, llvm::raw_ostream &os) : name(name), os(os) { IfDefScope(llvm::StringRef name, llvm::raw_ostream &os)
: name(name.str()), os(os) {
os << "#ifdef " << name << "\n" os << "#ifdef " << name << "\n"
<< "#undef " << name << "\n\n"; << "#undef " << name << "\n\n";
} }
~IfDefScope() { os << "\n#endif // " << name << "\n\n"; } ~IfDefScope() { os << "\n#endif // " << name << "\n\n"; }
private: private:
llvm::StringRef name; std::string name;
llvm::raw_ostream &os; llvm::raw_ostream &os;
}; };

View File

@ -0,0 +1,221 @@
//===- AttrOrTypeDef.cpp - AttrOrTypeDef wrapper classes ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/AttrOrTypeDef.h"
#include "mlir/TableGen/Dialect.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace mlir::tblgen;
//===----------------------------------------------------------------------===//
// AttrOrTypeBuilder
//===----------------------------------------------------------------------===//
/// Returns true if this builder is able to infer the MLIRContext parameter.
bool AttrOrTypeBuilder::hasInferredContextParameter() const {
return def->getValueAsBit("hasInferredContextParam");
}
//===----------------------------------------------------------------------===//
// AttrOrTypeDef
//===----------------------------------------------------------------------===//
AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
// Populate the builders.
auto *builderList =
dyn_cast_or_null<llvm::ListInit>(def->getValueInit("builders"));
if (builderList && !builderList->empty()) {
for (llvm::Init *init : builderList->getValues()) {
AttrOrTypeBuilder builder(cast<llvm::DefInit>(init)->getDef(),
def->getLoc());
// Ensure that all parameters have names.
for (const AttrOrTypeBuilder::Parameter &param :
builder.getParameters()) {
if (!param.getName())
PrintFatalError(def->getLoc(), "builder parameters must have a name");
}
builders.emplace_back(builder);
}
} else if (skipDefaultBuilders()) {
PrintFatalError(
def->getLoc(),
"default builders are skipped and no custom builders provided");
}
}
Dialect AttrOrTypeDef::getDialect() const {
auto *dialect = dyn_cast<llvm::DefInit>(def->getValue("dialect")->getValue());
return Dialect(dialect ? dialect->getDef() : nullptr);
}
StringRef AttrOrTypeDef::getName() const { return def->getName(); }
StringRef AttrOrTypeDef::getCppClassName() const {
return def->getValueAsString("cppClassName");
}
StringRef AttrOrTypeDef::getCppBaseClassName() const {
return def->getValueAsString("cppBaseClassName");
}
bool AttrOrTypeDef::hasDescription() const {
const llvm::RecordVal *desc = def->getValue("description");
return desc && isa<llvm::StringInit>(desc->getValue());
}
StringRef AttrOrTypeDef::getDescription() const {
return def->getValueAsString("description");
}
bool AttrOrTypeDef::hasSummary() const {
const llvm::RecordVal *summary = def->getValue("summary");
return summary && isa<llvm::StringInit>(summary->getValue());
}
StringRef AttrOrTypeDef::getSummary() const {
return def->getValueAsString("summary");
}
StringRef AttrOrTypeDef::getStorageClassName() const {
return def->getValueAsString("storageClass");
}
StringRef AttrOrTypeDef::getStorageNamespace() const {
return def->getValueAsString("storageNamespace");
}
bool AttrOrTypeDef::genStorageClass() const {
return def->getValueAsBit("genStorageClass");
}
bool AttrOrTypeDef::hasStorageCustomConstructor() const {
return def->getValueAsBit("hasStorageCustomConstructor");
}
void AttrOrTypeDef::getParameters(
SmallVectorImpl<AttrOrTypeParameter> &parameters) const {
if (auto *parametersDag = def->getValueAsDag("parameters")) {
for (unsigned i = 0, e = parametersDag->getNumArgs(); i < e; ++i)
parameters.push_back(AttrOrTypeParameter(parametersDag, i));
}
}
unsigned AttrOrTypeDef::getNumParameters() const {
auto *parametersDag = def->getValueAsDag("parameters");
return parametersDag ? parametersDag->getNumArgs() : 0;
}
Optional<StringRef> AttrOrTypeDef::getMnemonic() const {
return def->getValueAsOptionalString("mnemonic");
}
Optional<StringRef> AttrOrTypeDef::getPrinterCode() const {
return def->getValueAsOptionalString("printer");
}
Optional<StringRef> AttrOrTypeDef::getParserCode() const {
return def->getValueAsOptionalString("parser");
}
bool AttrOrTypeDef::genAccessors() const {
return def->getValueAsBit("genAccessors");
}
bool AttrOrTypeDef::genVerifyDecl() const {
return def->getValueAsBit("genVerifyDecl");
}
Optional<StringRef> AttrOrTypeDef::getExtraDecls() const {
auto value = def->getValueAsString("extraClassDeclaration");
return value.empty() ? Optional<StringRef>() : value;
}
ArrayRef<llvm::SMLoc> AttrOrTypeDef::getLoc() const { return def->getLoc(); }
bool AttrOrTypeDef::skipDefaultBuilders() const {
return def->getValueAsBit("skipDefaultBuilders");
}
bool AttrOrTypeDef::operator==(const AttrOrTypeDef &other) const {
return def == other.def;
}
bool AttrOrTypeDef::operator<(const AttrOrTypeDef &other) const {
return getName() < other.getName();
}
//===----------------------------------------------------------------------===//
// AttrOrTypeParameter
//===----------------------------------------------------------------------===//
StringRef AttrOrTypeParameter::getName() const {
return def->getArgName(index)->getValue();
}
Optional<StringRef> AttrOrTypeParameter::getAllocator() const {
llvm::Init *parameterType = def->getArg(index);
if (isa<llvm::StringInit>(parameterType))
return Optional<StringRef>();
if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
llvm::RecordVal *code = param->getDef()->getValue("allocator");
if (!code)
return Optional<StringRef>();
if (llvm::StringInit *ci = dyn_cast<llvm::StringInit>(code->getValue()))
return ci->getValue();
if (isa<llvm::UnsetInit>(code->getValue()))
return Optional<StringRef>();
llvm::PrintFatalError(
param->getDef()->getLoc(),
"Record `" + def->getArgName(index)->getValue() +
"', field `printer' does not have a code initializer!");
}
llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
"defs which inherit from AttrOrTypeParameter\n");
}
StringRef AttrOrTypeParameter::getCppType() const {
auto *parameterType = def->getArg(index);
if (auto *stringType = dyn_cast<llvm::StringInit>(parameterType))
return stringType->getValue();
if (auto *param = dyn_cast<llvm::DefInit>(parameterType))
return param->getDef()->getValueAsString("cppType");
llvm::PrintFatalError(
"Parameters DAG arguments must be either strings or defs "
"which inherit from AttrOrTypeParameter\n");
}
Optional<StringRef> AttrOrTypeParameter::getSummary() const {
auto *parameterType = def->getArg(index);
if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
const auto *desc = param->getDef()->getValue("summary");
if (llvm::StringInit *ci = dyn_cast<llvm::StringInit>(desc->getValue()))
return ci->getValue();
}
return Optional<StringRef>();
}
StringRef AttrOrTypeParameter::getSyntax() const {
auto *parameterType = def->getArg(index);
if (auto *stringType = dyn_cast<llvm::StringInit>(parameterType))
return stringType->getValue();
if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
const auto *syntax = param->getDef()->getValue("syntax");
if (syntax && isa<llvm::StringInit>(syntax->getValue()))
return cast<llvm::StringInit>(syntax->getValue())->getValue();
return getCppType();
}
llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
"defs which inherit from AttrOrTypeParameter");
}

View File

@ -11,6 +11,7 @@
llvm_add_library(MLIRTableGen STATIC llvm_add_library(MLIRTableGen STATIC
Argument.cpp Argument.cpp
Attribute.cpp Attribute.cpp
AttrOrTypeDef.cpp
Builder.cpp Builder.cpp
Constraint.cpp Constraint.cpp
Dialect.cpp Dialect.cpp
@ -26,7 +27,6 @@ llvm_add_library(MLIRTableGen STATIC
SideEffects.cpp SideEffects.cpp
Successor.cpp Successor.cpp
Type.cpp Type.cpp
TypeDef.cpp
DISABLE_LLVM_LINK_LLVM_DYLIB DISABLE_LLVM_LINK_LLVM_DYLIB

View File

@ -1,212 +0,0 @@
//===- TypeDef.cpp - TypeDef wrapper class --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// TypeDef wrapper to simplify using TableGen Record defining a MLIR dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/TypeDef.h"
#include "mlir/TableGen/Dialect.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace mlir::tblgen;
//===----------------------------------------------------------------------===//
// TypeBuilder
//===----------------------------------------------------------------------===//
/// Returns true if this builder is able to infer the MLIRContext parameter.
bool TypeBuilder::hasInferredContextParameter() const {
return def->getValueAsBit("hasInferredContextParam");
}
//===----------------------------------------------------------------------===//
// TypeDef
//===----------------------------------------------------------------------===//
Dialect TypeDef::getDialect() const {
auto *dialectDef =
dyn_cast<llvm::DefInit>(def->getValue("dialect")->getValue());
if (dialectDef == nullptr)
return Dialect(nullptr);
return Dialect(dialectDef->getDef());
}
StringRef TypeDef::getName() const { return def->getName(); }
StringRef TypeDef::getCppClassName() const {
return def->getValueAsString("cppClassName");
}
StringRef TypeDef::getCppBaseClassName() const {
return def->getValueAsString("cppBaseClassName");
}
bool TypeDef::hasDescription() const {
const llvm::RecordVal *s = def->getValue("description");
return s != nullptr && isa<llvm::StringInit>(s->getValue());
}
StringRef TypeDef::getDescription() const {
return def->getValueAsString("description");
}
bool TypeDef::hasSummary() const {
const llvm::RecordVal *s = def->getValue("summary");
return s != nullptr && isa<llvm::StringInit>(s->getValue());
}
StringRef TypeDef::getSummary() const {
return def->getValueAsString("summary");
}
StringRef TypeDef::getStorageClassName() const {
return def->getValueAsString("storageClass");
}
StringRef TypeDef::getStorageNamespace() const {
return def->getValueAsString("storageNamespace");
}
bool TypeDef::genStorageClass() const {
return def->getValueAsBit("genStorageClass");
}
bool TypeDef::hasStorageCustomConstructor() const {
return def->getValueAsBit("hasStorageCustomConstructor");
}
void TypeDef::getParameters(SmallVectorImpl<TypeParameter> &parameters) const {
auto *parametersDag = def->getValueAsDag("parameters");
if (parametersDag != nullptr) {
size_t numParams = parametersDag->getNumArgs();
for (unsigned i = 0; i < numParams; i++)
parameters.push_back(TypeParameter(parametersDag, i));
}
}
unsigned TypeDef::getNumParameters() const {
auto *parametersDag = def->getValueAsDag("parameters");
return parametersDag ? parametersDag->getNumArgs() : 0;
}
llvm::Optional<StringRef> TypeDef::getMnemonic() const {
return def->getValueAsOptionalString("mnemonic");
}
llvm::Optional<StringRef> TypeDef::getPrinterCode() const {
return def->getValueAsOptionalString("printer");
}
llvm::Optional<StringRef> TypeDef::getParserCode() const {
return def->getValueAsOptionalString("parser");
}
bool TypeDef::genAccessors() const {
return def->getValueAsBit("genAccessors");
}
bool TypeDef::genVerifyDecl() const {
return def->getValueAsBit("genVerifyDecl");
}
llvm::Optional<StringRef> TypeDef::getExtraDecls() const {
auto value = def->getValueAsString("extraClassDeclaration");
return value.empty() ? llvm::Optional<StringRef>() : value;
}
llvm::ArrayRef<llvm::SMLoc> TypeDef::getLoc() const { return def->getLoc(); }
bool TypeDef::skipDefaultBuilders() const {
return def->getValueAsBit("skipDefaultBuilders");
}
bool TypeDef::operator==(const TypeDef &other) const {
return def == other.def;
}
bool TypeDef::operator<(const TypeDef &other) const {
return getName() < other.getName();
}
//===----------------------------------------------------------------------===//
// TypeParameter
//===----------------------------------------------------------------------===//
TypeDef::TypeDef(const llvm::Record *def) : def(def) {
// Populate the builders.
auto *builderList =
dyn_cast_or_null<llvm::ListInit>(def->getValueInit("builders"));
if (builderList && !builderList->empty()) {
for (llvm::Init *init : builderList->getValues()) {
TypeBuilder builder(cast<llvm::DefInit>(init)->getDef(), def->getLoc());
// Ensure that all parameters have names.
for (const TypeBuilder::Parameter &param : builder.getParameters()) {
if (!param.getName())
PrintFatalError(def->getLoc(),
"type builder parameters must have a name");
}
builders.emplace_back(builder);
}
} else if (skipDefaultBuilders()) {
PrintFatalError(
def->getLoc(),
"default builders are skipped and no custom builders provided");
}
}
StringRef TypeParameter::getName() const {
return def->getArgName(num)->getValue();
}
Optional<StringRef> TypeParameter::getAllocator() const {
llvm::Init *parameterType = def->getArg(num);
if (isa<llvm::StringInit>(parameterType))
return llvm::Optional<StringRef>();
if (auto *typeParameter = dyn_cast<llvm::DefInit>(parameterType)) {
llvm::RecordVal *code = typeParameter->getDef()->getValue("allocator");
if (!code)
return llvm::Optional<StringRef>();
if (llvm::StringInit *ci = dyn_cast<llvm::StringInit>(code->getValue()))
return ci->getValue();
if (isa<llvm::UnsetInit>(code->getValue()))
return llvm::Optional<StringRef>();
llvm::PrintFatalError(
typeParameter->getDef()->getLoc(),
"Record `" + def->getArgName(num)->getValue() +
"', field `printer' does not have a code initializer!");
}
llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
"defs which inherit from TypeParameter\n");
}
StringRef TypeParameter::getCppType() const {
auto *parameterType = def->getArg(num);
if (auto *stringType = dyn_cast<llvm::StringInit>(parameterType))
return stringType->getValue();
if (auto *typeParameter = dyn_cast<llvm::DefInit>(parameterType))
return typeParameter->getDef()->getValueAsString("cppType");
llvm::PrintFatalError(
"Parameters DAG arguments must be either strings or defs "
"which inherit from TypeParameter\n");
}
Optional<StringRef> TypeParameter::getSummary() const {
auto *parameterType = def->getArg(num);
if (auto *typeParameter = dyn_cast<llvm::DefInit>(parameterType)) {
const auto *desc = typeParameter->getDef()->getValue("summary");
if (llvm::StringInit *ci = dyn_cast<llvm::StringInit>(desc->getValue()))
return ci->getValue();
}
return Optional<StringRef>();
}
StringRef TypeParameter::getSyntax() const {
auto *parameterType = def->getArg(num);
if (auto *stringType = dyn_cast<llvm::StringInit>(parameterType))
return stringType->getValue();
if (auto *typeParameter = dyn_cast<llvm::DefInit>(parameterType)) {
const auto *syntax = typeParameter->getDef()->getValue("syntax");
if (syntax && isa<llvm::StringInit>(syntax->getValue()))
return dyn_cast<llvm::StringInit>(syntax->getValue())->getValue();
return getCppType();
}
llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
"defs which inherit from TypeParameter");
}

View File

@ -11,10 +11,15 @@ mlir_tablegen(TestOpInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(TestOpInterfaces.cpp.inc -gen-op-interface-defs) mlir_tablegen(TestOpInterfaces.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(MLIRTestInterfaceIncGen) add_public_tablegen_target(MLIRTestInterfaceIncGen)
set(LLVM_TARGET_DEFINITIONS TestAttrDefs.td)
mlir_tablegen(TestAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(TestAttrDefs.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(MLIRTestAttrDefIncGen)
set(LLVM_TARGET_DEFINITIONS TestTypeDefs.td) set(LLVM_TARGET_DEFINITIONS TestTypeDefs.td)
mlir_tablegen(TestTypeDefs.h.inc -gen-typedef-decls) mlir_tablegen(TestTypeDefs.h.inc -gen-typedef-decls)
mlir_tablegen(TestTypeDefs.cpp.inc -gen-typedef-defs) mlir_tablegen(TestTypeDefs.cpp.inc -gen-typedef-defs)
add_public_tablegen_target(MLIRTestDefIncGen) add_public_tablegen_target(MLIRTestTypeDefIncGen)
set(LLVM_TARGET_DEFINITIONS TestOps.td) set(LLVM_TARGET_DEFINITIONS TestOps.td)
@ -30,6 +35,7 @@ add_public_tablegen_target(MLIRTestOpsIncGen)
# Exclude tests from libMLIR.so # Exclude tests from libMLIR.so
add_mlir_library(MLIRTestDialect add_mlir_library(MLIRTestDialect
TestAttributes.cpp
TestDialect.cpp TestDialect.cpp
TestInterfaces.cpp TestInterfaces.cpp
TestPatterns.cpp TestPatterns.cpp
@ -39,8 +45,9 @@ add_mlir_library(MLIRTestDialect
EXCLUDE_FROM_LIBMLIR EXCLUDE_FROM_LIBMLIR
DEPENDS DEPENDS
MLIRTestAttrDefIncGen
MLIRTestInterfaceIncGen MLIRTestInterfaceIncGen
MLIRTestDefIncGen MLIRTestTypeDefIncGen
MLIRTestOpsIncGen MLIRTestOpsIncGen
LINK_LIBS PUBLIC LINK_LIBS PUBLIC

View File

@ -0,0 +1,44 @@
//===-- TestAttrDefs.td - Test dialect attr definitions ----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// TableGen data attribute definitions for Test dialect.
//
//===----------------------------------------------------------------------===//
#ifndef TEST_ATTRDEFS
#define TEST_ATTRDEFS
// To get the test dialect definition.
include "TestOps.td"
// All of the attributes will extend this class.
class Test_Attr<string name> : AttrDef<Test_Dialect, name>;
def SimpleAttrA : Test_Attr<"SimpleA"> {
let mnemonic = "smpla";
}
// A more complex parameterized attribute.
def CompoundAttrA : Test_Attr<"CompoundA"> {
let mnemonic = "cmpnd_a";
// List of type parameters.
let parameters = (
ins
"int":$widthOfSomething,
"::mlir::Type":$oneType,
// This is special syntax since ArrayRefs require allocation in the
// constructor.
ArrayRefParameter<
"int", // The parameter C++ type.
"An example of an array of ints" // Parameter description.
>: $arrayOfInts
);
}
#endif // TEST_ATTRDEFS

View File

@ -0,0 +1,82 @@
//===- TestAttributes.cpp - MLIR Test Dialect Attributes --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains attributes defined by the TestDialect for testing various
// features of MLIR.
//
//===----------------------------------------------------------------------===//
#include "TestAttributes.h"
#include "TestDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::test;
Attribute CompoundAAttr::parse(MLIRContext *context, DialectAsmParser &parser,
Type type) {
int widthOfSomething;
Type oneType;
SmallVector<int, 4> arrayOfInts;
if (parser.parseLess() || parser.parseInteger(widthOfSomething) ||
parser.parseComma() || parser.parseType(oneType) || parser.parseComma() ||
parser.parseLSquare())
return Attribute();
int intVal;
while (!*parser.parseOptionalInteger(intVal)) {
arrayOfInts.push_back(intVal);
if (parser.parseOptionalComma())
break;
}
if (parser.parseRSquare() || parser.parseGreater())
return Attribute();
return get(context, widthOfSomething, oneType, arrayOfInts);
}
void CompoundAAttr::print(DialectAsmPrinter &printer) const {
printer << "cmpnd_a<" << getWidthOfSomething() << ", " << getOneType()
<< ", [";
llvm::interleaveComma(getArrayOfInts(), printer);
printer << "]>";
}
//===----------------------------------------------------------------------===//
// Tablegen Generated Definitions
//===----------------------------------------------------------------------===//
#define GET_ATTRDEF_CLASSES
#include "TestAttrDefs.cpp.inc"
//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//
Attribute TestDialect::parseAttribute(DialectAsmParser &parser,
Type type) const {
StringRef attrTag;
if (failed(parser.parseKeyword(&attrTag)))
return Attribute();
if (auto attr = generatedAttributeParser(getContext(), parser, attrTag, type))
return attr;
parser.emitError(parser.getNameLoc(), "unknown test attribute");
return Attribute();
}
void TestDialect::printAttribute(Attribute attr,
DialectAsmPrinter &printer) const {
if (succeeded(generatedAttributePrinter(attr, printer)))
return;
}

View File

@ -0,0 +1,27 @@
//===- TestTypes.h - MLIR Test Dialect Types --------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains types defined by the TestDialect for testing various
// features of MLIR.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TESTATTRIBUTES_H
#define MLIR_TESTATTRIBUTES_H
#include <tuple>
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#define GET_ATTRDEF_CLASSES
#include "TestAttrDefs.h.inc"
#endif // MLIR_TESTATTRIBUTES_H

View File

@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "TestDialect.h" #include "TestDialect.h"
#include "TestAttributes.h"
#include "TestTypes.h" #include "TestTypes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
@ -168,6 +169,10 @@ void TestDialect::initialize() {
#define GET_OP_LIST #define GET_OP_LIST
#include "TestOps.cpp.inc" #include "TestOps.cpp.inc"
>(); >();
addAttributes<
#define GET_ATTRDEF_LIST
#include "TestAttrDefs.cpp.inc"
>();
addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
TestInlinerInterface>(); TestInlinerInterface>();
addTypes<TestType, TestRecursiveType, addTypes<TestType, TestRecursiveType,

View File

@ -27,6 +27,13 @@ def Test_Dialect : Dialect {
let hasOperationAttrVerify = 1; let hasOperationAttrVerify = 1;
let hasRegionArgAttrVerify = 1; let hasRegionArgAttrVerify = 1;
let hasRegionResultAttrVerify = 1; let hasRegionResultAttrVerify = 1;
let extraClassDeclaration = [{
Attribute parseAttribute(DialectAsmParser &parser,
Type type) const override;
void printAttribute(Attribute attr,
DialectAsmPrinter &printer) const override;
}];
} }
class TEST_Op<string mnemonic, list<OpTrait> traits = []> : class TEST_Op<string mnemonic, list<OpTrait> traits = []> :

View File

@ -0,0 +1,96 @@
// RUN: mlir-tblgen -gen-attrdef-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
// RUN: mlir-tblgen -gen-attrdef-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF
include "mlir/IR/OpBase.td"
// DECL: #ifdef GET_ATTRDEF_CLASSES
// DECL: #undef GET_ATTRDEF_CLASSES
// DECL: namespace mlir {
// DECL: class DialectAsmParser;
// DECL: class DialectAsmPrinter;
// DECL: } // namespace mlir
// DEF: #ifdef GET_ATTRDEF_LIST
// DEF: #undef GET_ATTRDEF_LIST
// DEF: ::mlir::test::SimpleAAttr,
// DEF: ::mlir::test::CompoundAAttr,
// DEF: ::mlir::test::IndexAttr,
// DEF: ::mlir::test::SingleParameterAttr
// DEF-LABEL: ::mlir::Attribute generatedAttributeParser(::mlir::MLIRContext *context,
// DEF-NEXT: ::mlir::DialectAsmParser &parser,
// DEF-NEXT: ::llvm::StringRef mnemonic, ::mlir::Type type) {
// DEF: if (mnemonic == ::mlir::test::CompoundAAttr::getMnemonic()) return ::mlir::test::CompoundAAttr::parse(context, parser, type);
// DEF-NEXT: if (mnemonic == ::mlir::test::IndexAttr::getMnemonic()) return ::mlir::test::IndexAttr::parse(context, parser, type);
// DEF-NEXT: return ::mlir::Attribute();
def Test_Dialect: Dialect {
// DECL-NOT: TestDialect
// DEF-NOT: TestDialect
let name = "TestDialect";
let cppNamespace = "::mlir::test";
}
class TestAttr<string name> : AttrDef<Test_Dialect, name> { }
def A_SimpleAttrA : TestAttr<"SimpleA"> {
// DECL: class SimpleAAttr : public ::mlir::Attribute
}
// A more complex parameterized type
def B_CompoundAttrA : TestAttr<"CompoundA"> {
let summary = "A more complex parameterized attribute";
let description = "This attribute is to test a reasonably complex attribute";
let mnemonic = "cmpnd_a";
let parameters = (
ins
"int":$widthOfSomething,
"::mlir::test::SimpleTypeA": $exampleTdType,
"SomeCppStruct": $exampleCppType,
ArrayRefParameter<"int", "Matrix dimensions">:$dims,
"::mlir::Type":$inner
);
let genVerifyDecl = 1;
// DECL-LABEL: class CompoundAAttr : public ::mlir::Attribute
// DECL: static CompoundAAttr getChecked(llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
// DECL: static ::mlir::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
// DECL: static constexpr ::llvm::StringLiteral getMnemonic() {
// DECL: return ::llvm::StringLiteral("cmpnd_a");
// DECL: }
// DECL: static ::mlir::Attribute parse(::mlir::MLIRContext *context,
// DECL-NEXT: ::mlir::DialectAsmParser &parser, ::mlir::Type type);
// DECL: void print(::mlir::DialectAsmPrinter &printer) const;
// DECL: int getWidthOfSomething() const;
// DECL: ::mlir::test::SimpleTypeA getExampleTdType() const;
// DECL: SomeCppStruct getExampleCppType() const;
}
def C_IndexAttr : TestAttr<"Index"> {
let mnemonic = "index";
let parameters = (
ins
StringRefParameter<"Label for index">:$label
);
// DECL-LABEL: class IndexAttr : public ::mlir::Attribute
// DECL: static constexpr ::llvm::StringLiteral getMnemonic() {
// DECL: return ::llvm::StringLiteral("index");
// DECL: }
// DECL: static ::mlir::Attribute parse(::mlir::MLIRContext *context,
// DECL-NEXT: ::mlir::DialectAsmParser &parser, ::mlir::Type type);
// DECL: void print(::mlir::DialectAsmPrinter &printer) const;
}
def D_SingleParameterAttr : TestAttr<"SingleParameter"> {
let parameters = (
ins
"int": $num
);
// DECL-LABEL: struct SingleParameterAttrStorage;
// DECL-LABEL: class SingleParameterAttr
// DECL-NEXT: detail::SingleParameterAttrStorage
}

View File

@ -0,0 +1,5 @@
// RUN: mlir-opt %s | mlir-opt -verify-diagnostics | FileCheck %s
// CHECK-LABEL: func private @compoundA()
// CHECK-SAME: #test.cmpnd_a<1, !test.smpla, [5, 6]>
func private @compoundA() attributes {foo = #test.cmpnd_a<1, !test.smpla, [5, 6]>}

View File

@ -19,9 +19,11 @@ include "mlir/IR/OpBase.td"
// DEF: ::mlir::test::SingleParameterType, // DEF: ::mlir::test::SingleParameterType,
// DEF: ::mlir::test::IntegerType // DEF: ::mlir::test::IntegerType
// DEF-LABEL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext *context, ::mlir::DialectAsmParser &parser, ::llvm::StringRef mnemonic) // DEF-LABEL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext *context,
// DEF-NEXT: ::mlir::DialectAsmParser &parser,
// DEF-NEXT: ::llvm::StringRef mnemonic) {
// DEF: if (mnemonic == ::mlir::test::CompoundAType::getMnemonic()) return ::mlir::test::CompoundAType::parse(context, parser); // DEF: if (mnemonic == ::mlir::test::CompoundAType::getMnemonic()) return ::mlir::test::CompoundAType::parse(context, parser);
// DEF return ::mlir::Type(); // DEF: return ::mlir::Type();
def Test_Dialect: Dialect { def Test_Dialect: Dialect {
// DECL-NOT: TestDialect // DECL-NOT: TestDialect

View File

@ -0,0 +1,849 @@
//===- AttrOrTypeDefGen.cpp - MLIR AttrOrType definitions generator -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Support/LogicalResult.h"
#include "mlir/TableGen/AttrOrTypeDef.h"
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/TableGenBackend.h"
#define DEBUG_TYPE "mlir-tblgen-attrortypedefgen"
using namespace mlir;
using namespace mlir::tblgen;
/// Find all the AttrOrTypeDef for the specified dialect. If no dialect
/// specified and can only find one dialect's defs, use that.
static void collectAllDefs(StringRef selectedDialect,
std::vector<llvm::Record *> records,
SmallVectorImpl<AttrOrTypeDef> &resultDefs) {
auto defs = llvm::map_range(
records, [&](const llvm::Record *rec) { return AttrOrTypeDef(rec); });
if (defs.empty())
return;
StringRef dialectName;
if (selectedDialect.empty()) {
if (defs.empty())
return;
Dialect dialect(nullptr);
for (const AttrOrTypeDef &typeDef : defs) {
if (!dialect) {
dialect = typeDef.getDialect();
} else if (dialect != typeDef.getDialect()) {
llvm::PrintFatalError("defs belonging to more than one dialect. Must "
"select one via '--(attr|type)defs-dialect'");
}
}
dialectName = dialect.getName();
} else {
dialectName = selectedDialect;
}
for (const AttrOrTypeDef &def : defs)
if (def.getDialect().getName().equals(dialectName))
resultDefs.push_back(def);
}
//===----------------------------------------------------------------------===//
// ParamCommaFormatter
//===----------------------------------------------------------------------===//
namespace {
/// Pass an instance of this class to llvm::formatv() to emit a comma separated
/// list of parameters in the format by 'EmitFormat'.
class ParamCommaFormatter : public llvm::detail::format_adapter {
public:
/// Choose the output format
enum EmitFormat {
/// Emit "parameter1Type parameter1Name, parameter2Type parameter2Name,
/// [...]".
TypeNamePairs,
/// Emit "parameter1(parameter1), parameter2(parameter2), [...]".
TypeNameInitializer,
/// Emit "param1Name, param2Name, [...]".
JustParams,
};
ParamCommaFormatter(EmitFormat emitFormat,
ArrayRef<AttrOrTypeParameter> params,
bool prependComma = true)
: emitFormat(emitFormat), params(params), prependComma(prependComma) {}
/// llvm::formatv will call this function when using an instance as a
/// replacement value.
void format(raw_ostream &os, StringRef options) override {
if (!params.empty() && prependComma)
os << ", ";
switch (emitFormat) {
case EmitFormat::TypeNamePairs:
interleaveComma(params, os, [&](const AttrOrTypeParameter &p) {
emitTypeNamePair(p, os);
});
break;
case EmitFormat::TypeNameInitializer:
interleaveComma(params, os, [&](const AttrOrTypeParameter &p) {
emitTypeNameInitializer(p, os);
});
break;
case EmitFormat::JustParams:
interleaveComma(params, os,
[&](const AttrOrTypeParameter &p) { os << p.getName(); });
break;
}
}
private:
// Emit "paramType paramName".
static void emitTypeNamePair(const AttrOrTypeParameter &param,
raw_ostream &os) {
os << param.getCppType() << " " << param.getName();
}
// Emit "paramName(paramName)"
void emitTypeNameInitializer(const AttrOrTypeParameter &param,
raw_ostream &os) {
os << param.getName() << "(" << param.getName() << ")";
}
EmitFormat emitFormat;
ArrayRef<AttrOrTypeParameter> params;
bool prependComma;
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// DefGenerator
//===----------------------------------------------------------------------===//
namespace {
/// This struct is the base generator used when processing tablegen interfaces.
class DefGenerator {
public:
bool emitDecls(StringRef selectedDialect);
bool emitDefs(StringRef selectedDialect);
protected:
DefGenerator(std::vector<llvm::Record *> &&defs, raw_ostream &os)
: defRecords(std::move(defs)), os(os), isAttrGenerator(false) {}
/// Emit the declaration of a single def.
void emitDefDecl(const AttrOrTypeDef &def);
/// Emit the list of def type names.
void emitTypeDefList(ArrayRef<AttrOrTypeDef> defs);
/// Emit the code to dispatch between different defs during parsing/printing.
void emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs);
/// Emit the definition of a single def.
void emitDefDef(const AttrOrTypeDef &def);
/// Emit the storage class for the given def.
void emitStorageClass(const AttrOrTypeDef &def);
/// Emit the parser/printer for the given def.
void emitParsePrint(const AttrOrTypeDef &def);
/// The set of def records to emit.
std::vector<llvm::Record *> defRecords;
/// The stream to emit to.
raw_ostream &os;
/// The prefix of the tablegen def name, e.g. Attr or Type.
StringRef defTypePrefix;
/// The C++ base value type of the def, e.g. Attribute or Type.
StringRef valueType;
/// Flag indicating if this generator is for Attributes. False if the
/// generator is for types.
bool isAttrGenerator;
};
/// A specialized generator for AttrDefs.
struct AttrDefGenerator : public DefGenerator {
AttrDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
: DefGenerator(records.getAllDerivedDefinitions("AttrDef"), os) {
isAttrGenerator = true;
defTypePrefix = "Attr";
valueType = "Attribute";
}
};
/// A specialized generator for TypeDefs.
struct TypeDefGenerator : public DefGenerator {
TypeDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
: DefGenerator(records.getAllDerivedDefinitions("TypeDef"), os) {
defTypePrefix = "Type";
valueType = "Type";
}
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// GEN: Declarations
//===----------------------------------------------------------------------===//
/// Print this above all the other declarations. Contains type declarations used
/// later on.
static const char *const typeDefDeclHeader = R"(
namespace mlir {
class DialectAsmParser;
class DialectAsmPrinter;
} // namespace mlir
)";
/// The code block for the start of a typeDef class declaration -- singleton
/// case.
///
/// {0}: The name of the def class.
/// {1}: The name of the type base class.
/// {2}: The name of the base value type, e.g. Attribute or Type.
/// {3}: The tablegen record type prefix, e.g. Attr or Type.
static const char *const defDeclSingletonBeginStr = R"(
class {0} : public ::mlir::{2}::{3}Base<{0}, {1}, ::mlir::{2}Storage> {{
public:
/// Inherit some necessary constructors from '{3}Base'.
using Base::Base;
)";
/// The code block for the start of a typeDef class declaration -- parametric
/// case.
///
/// {0}: The name of the typeDef class.
/// {1}: The name of the type base class.
/// {2}: The typeDef storage class namespace.
/// {3}: The storage class name.
/// {4}: The name of the base value type, e.g. Attribute or Type.
/// {5}: The tablegen record type prefix, e.g. Attr or Type.
static const char *const defDeclParametricBeginStr = R"(
namespace {2} {
struct {3};
} // end namespace {2}
class {0} : public ::mlir::{4}::{5}Base<{0}, {1},
{2}::{3}> {{
public:
/// Inherit some necessary constructors from '{5}Base'.
using Base::Base;
)";
/// The code snippet for print/parse of an Attribute/Type.
///
/// {0}: The name of the base value type, e.g. Attribute or Type.
/// {1}: Extra parser parameters.
static const char *const defDeclParsePrintStr = R"(
static ::mlir::{0} parse(::mlir::MLIRContext *context,
::mlir::DialectAsmParser &parser{1});
void print(::mlir::DialectAsmPrinter &printer) const;
)";
/// The code block for the verify method declaration.
///
/// {0}: List of parameters, parameters style.
static const char *const defDeclVerifyStr = R"(
using Base::getChecked;
static ::mlir::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError{0});
)";
/// Emit the builders for the given def.
static void emitBuilderDecls(const AttrOrTypeDef &def, raw_ostream &os,
ParamCommaFormatter &paramTypes) {
StringRef typeClass = def.getCppClassName();
bool genCheckedMethods = def.genVerifyDecl();
if (!def.skipDefaultBuilders()) {
os << llvm::formatv(
" static {0} get(::mlir::MLIRContext *context{1});\n", typeClass,
paramTypes);
if (genCheckedMethods) {
os << llvm::formatv(" static {0} "
"getChecked(llvm::function_ref<::mlir::"
"InFlightDiagnostic()> emitError, "
"::mlir::MLIRContext *context{1});\n",
typeClass, paramTypes);
}
}
// Generate the builders specified by the user.
for (const AttrOrTypeBuilder &builder : def.getBuilders()) {
std::string paramStr;
llvm::raw_string_ostream paramOS(paramStr);
llvm::interleaveComma(
builder.getParameters(), paramOS,
[&](const AttrOrTypeBuilder::Parameter &param) {
// Note: AttrOrTypeBuilder parameters are guaranteed to have names.
paramOS << param.getCppType() << " " << *param.getName();
if (Optional<StringRef> defaultParamValue = param.getDefaultValue())
paramOS << " = " << *defaultParamValue;
});
paramOS.flush();
// Generate the `get` variant of the builder.
os << " static " << typeClass << " get(";
if (!builder.hasInferredContextParameter()) {
os << "::mlir::MLIRContext *context";
if (!paramStr.empty())
os << ", ";
}
os << paramStr << ");\n";
// Generate the `getChecked` variant of the builder.
if (genCheckedMethods) {
os << " static " << typeClass
<< " getChecked(llvm::function_ref<mlir::InFlightDiagnostic()> "
"emitError";
if (!builder.hasInferredContextParameter())
os << ", ::mlir::MLIRContext *context";
if (!paramStr.empty())
os << ", ";
os << paramStr << ");\n";
}
}
}
void DefGenerator::emitDefDecl(const AttrOrTypeDef &def) {
SmallVector<AttrOrTypeParameter, 4> params;
def.getParameters(params);
// Emit the beginning string template: either the singleton or parametric
// template.
if (def.getNumParameters() == 0) {
os << formatv(defDeclSingletonBeginStr, def.getCppClassName(),
def.getCppBaseClassName(), valueType, defTypePrefix);
} else {
os << formatv(defDeclParametricBeginStr, def.getCppClassName(),
def.getCppBaseClassName(), def.getStorageNamespace(),
def.getStorageClassName(), valueType, defTypePrefix);
}
// Emit the extra declarations first in case there's a definition in there.
if (Optional<StringRef> extraDecl = def.getExtraDecls())
os << *extraDecl << "\n";
ParamCommaFormatter emitTypeNamePairsAfterComma(
ParamCommaFormatter::EmitFormat::TypeNamePairs, params);
if (!params.empty()) {
emitBuilderDecls(def, os, emitTypeNamePairsAfterComma);
// Emit the verify invariants declaration.
if (def.genVerifyDecl())
os << llvm::formatv(defDeclVerifyStr, emitTypeNamePairsAfterComma);
}
// Emit the mnenomic, if specified.
if (auto mnenomic = def.getMnemonic()) {
os << " static constexpr ::llvm::StringLiteral getMnemonic() {\n"
<< " return ::llvm::StringLiteral(\"" << mnenomic << "\");\n"
<< " }\n";
// If mnemonic specified, emit print/parse declarations.
if (def.getParserCode() || def.getPrinterCode() || !params.empty()) {
os << llvm::formatv(defDeclParsePrintStr, valueType,
isAttrGenerator ? ", ::mlir::Type type" : "");
}
}
if (def.genAccessors()) {
SmallVector<AttrOrTypeParameter, 4> parameters;
def.getParameters(parameters);
for (AttrOrTypeParameter &parameter : parameters) {
SmallString<16> name = parameter.getName();
name[0] = llvm::toUpper(name[0]);
os << formatv(" {0} get{1}() const;\n", parameter.getCppType(), name);
}
}
// End the decl.
os << " };\n";
}
bool DefGenerator::emitDecls(StringRef selectedDialect) {
emitSourceFileHeader((defTypePrefix + "Def Declarations").str(), os);
IfDefScope scope("GET_" + defTypePrefix.upper() + "DEF_CLASSES", os);
// Output the common "header".
os << typeDefDeclHeader;
SmallVector<AttrOrTypeDef, 16> defs;
collectAllDefs(selectedDialect, defRecords, defs);
if (defs.empty())
return false;
NamespaceEmitter nsEmitter(os, defs.front().getDialect());
// Declare all the def classes first (in case they reference each other).
for (const AttrOrTypeDef &def : defs)
os << " class " << def.getCppClassName() << ";\n";
// Emit the declarations.
for (const AttrOrTypeDef &def : defs)
emitDefDecl(def);
return false;
}
//===----------------------------------------------------------------------===//
// GEN: Def List
//===----------------------------------------------------------------------===//
void DefGenerator::emitTypeDefList(ArrayRef<AttrOrTypeDef> defs) {
IfDefScope scope("GET_" + defTypePrefix.upper() + "DEF_LIST", os);
auto interleaveFn = [&](const AttrOrTypeDef &def) {
os << def.getDialect().getCppNamespace() << "::" << def.getCppClassName();
};
llvm::interleave(defs, os, interleaveFn, ",\n");
os << "\n";
}
//===----------------------------------------------------------------------===//
// GEN: Definitions
//===----------------------------------------------------------------------===//
/// The code block used to start the auto-generated parser function.
///
/// {0}: The name of the base value type, e.g. Attribute or Type.
/// {1}: Additional parser parameters.
static const char *const defParserDispatchStartStr = R"(
static ::mlir::{0} generated{0}Parser(::mlir::MLIRContext *context,
::mlir::DialectAsmParser &parser,
::llvm::StringRef mnemonic{1}) {{
)";
/// The code block used to start the auto-generated printer function.
///
/// {0}: The name of the base value type, e.g. Attribute or Type.
static const char *const defPrinterDispatchStartStr = R"(
static ::mlir::LogicalResult generated{0}Printer(
::mlir::{0} def, ::mlir::DialectAsmPrinter &printer) {{
return ::llvm::TypeSwitch<::mlir::{0}, ::mlir::LogicalResult>(def)
)";
/// Beginning of storage class.
/// {0}: Storage class namespace.
/// {1}: Storage class c++ name.
/// {2}: Parameters parameters.
/// {3}: Parameter initializer string.
/// {4}: Parameter name list.
/// {5}: Parameter types.
/// {6}: The name of the base value type, e.g. Attribute or Type.
static const char *const defStorageClassBeginStr = R"(
namespace {0} {{
struct {1} : public ::mlir::{6}Storage {{
{1} ({2})
: {3} {{ }
/// The hash key is a tuple of the parameter types.
using KeyTy = std::tuple<{5}>;
/// Define the comparison function for the key type.
bool operator==(const KeyTy &key) const {{
return key == KeyTy({4});
}
)";
/// The storage class' constructor template.
///
/// {0}: storage class name.
/// {1}: The name of the base value type, e.g. Attribute or Type.
static const char *const defStorageClassConstructorBeginStr = R"(
/// Define a construction method for creating a new instance of this
/// storage.
static {0} *construct(::mlir::{1}StorageAllocator &allocator,
const KeyTy &key) {{
)";
/// The storage class' constructor return template.
///
/// {0}: storage class name.
/// {1}: list of parameters.
static const char *const defStorageClassConstructorEndStr = R"(
return new (allocator.allocate<{0}>())
{0}({1});
}
)";
/// Use tgfmt to emit custom allocation code for each parameter, if necessary.
static void emitStorageParameterAllocation(const AttrOrTypeDef &def,
raw_ostream &os) {
SmallVector<AttrOrTypeParameter> parameters;
def.getParameters(parameters);
FmtContext fmtCtxt = FmtContext().addSubst("_allocator", "allocator");
for (AttrOrTypeParameter &parameter : parameters) {
if (Optional<StringRef> allocCode = parameter.getAllocator()) {
fmtCtxt.withSelf(parameter.getName());
fmtCtxt.addSubst("_dst", parameter.getName());
os << " " << tgfmt(*allocCode, &fmtCtxt) << "\n";
}
}
}
void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
SmallVector<AttrOrTypeParameter, 4> parameters;
def.getParameters(parameters);
// Collect the parameter names and types.
auto parameterNames =
map_range(parameters, [](AttrOrTypeParameter parameter) {
return parameter.getName();
});
auto parameterTypes =
map_range(parameters, [](AttrOrTypeParameter parameter) {
return parameter.getCppType();
});
auto parameterList = join(parameterNames, ", ");
auto parameterTypeList = join(parameterTypes, ", ");
// 1) Emit most of the storage class up until the hashKey body.
os << formatv(
defStorageClassBeginStr, def.getStorageNamespace(),
def.getStorageClassName(),
ParamCommaFormatter(ParamCommaFormatter::EmitFormat::TypeNamePairs,
parameters, /*prependComma=*/false),
ParamCommaFormatter(ParamCommaFormatter::EmitFormat::TypeNameInitializer,
parameters, /*prependComma=*/false),
parameterList, parameterTypeList, valueType);
// 2) Emit the haskKey method.
os << " static ::llvm::hash_code hashKey(const KeyTy &key) {\n";
// Extract each parameter from the key.
os << " return ::llvm::hash_combine(";
llvm::interleaveComma(
llvm::seq<unsigned>(0, parameters.size()), os,
[&](unsigned it) { os << "std::get<" << it << ">(key)"; });
os << ");\n }\n";
// 3) Emit the construct method.
// If user wants to build the storage constructor themselves, declare it
// here and then they can write the definition elsewhere.
if (def.hasStorageCustomConstructor()) {
os << llvm::formatv(" static {0} *construct(::mlir::{1}StorageAllocator "
"&allocator, const KeyTy &key);\n",
def.getStorageClassName(), valueType);
// Otherwise, generate one.
} else {
// First, unbox the parameters.
os << formatv(defStorageClassConstructorBeginStr, def.getStorageClassName(),
valueType);
for (unsigned i = 0, e = parameters.size(); i < e; ++i) {
os << formatv(" auto {0} = std::get<{1}>(key);\n",
parameters[i].getName(), i);
}
// Second, reassign the parameter variables with allocation code, if it's
// specified.
emitStorageParameterAllocation(def, os);
// Last, return an allocated copy.
os << formatv(defStorageClassConstructorEndStr, def.getStorageClassName(),
parameterList);
}
// 4) Emit the parameters as storage class members.
for (auto parameter : parameters) {
os << " " << parameter.getCppType() << " " << parameter.getName()
<< ";\n";
}
os << " };\n";
os << "} // namespace " << def.getStorageNamespace() << "\n";
}
void DefGenerator::emitParsePrint(const AttrOrTypeDef &def) {
// Emit the printer code, if specified.
if (Optional<StringRef> printerCode = def.getPrinterCode()) {
// Both the mnenomic and printerCode must be defined (for parity with
// parserCode).
os << "void " << def.getCppClassName()
<< "::print(::mlir::DialectAsmPrinter &printer) const {\n";
if (printerCode->empty()) {
// If no code specified, emit error.
PrintFatalError(def.getLoc(),
def.getName() +
": printer (if specified) must have non-empty code");
}
FmtContext fmtCtxt = FmtContext().addSubst("_printer", "printer");
os << tgfmt(*printerCode, &fmtCtxt) << "\n}\n";
}
// Emit the parser code, if specified.
if (Optional<StringRef> parserCode = def.getParserCode()) {
FmtContext fmtCtxt;
fmtCtxt.addSubst("_parser", "parser").addSubst("_ctxt", "context");
// The mnenomic must be defined so the dispatcher knows how to dispatch.
os << llvm::formatv("::mlir::{0} {1}::parse(::mlir::MLIRContext *context, "
"::mlir::DialectAsmParser &parser",
valueType, def.getCppClassName());
if (isAttrGenerator) {
// Attributes also accept a type parameter instead of a context.
os << ", ::mlir::Type type";
fmtCtxt.addSubst("_type", "type");
}
os << ") {\n";
if (parserCode->empty()) {
PrintFatalError(def.getLoc(),
def.getName() +
": parser (if specified) must have non-empty code");
}
os << tgfmt(*parserCode, &fmtCtxt) << "\n}\n";
}
}
/// Replace all instances of 'from' to 'to' in `str` and return the new string.
static std::string replaceInStr(std::string str, StringRef from, StringRef to) {
size_t pos = 0;
while ((pos = str.find(from.data(), pos, from.size())) != std::string::npos)
str.replace(pos, from.size(), to.data(), to.size());
return str;
}
/// Emit the builders for the given def.
static void emitBuilderDefs(const AttrOrTypeDef &def, raw_ostream &os,
ArrayRef<AttrOrTypeParameter> params) {
bool genCheckedMethods = def.genVerifyDecl();
StringRef className = def.getCppClassName();
if (!def.skipDefaultBuilders()) {
os << llvm::formatv(
"{0} {0}::get(::mlir::MLIRContext *context{1}) {{\n"
" return Base::get(context{2});\n}\n",
className,
ParamCommaFormatter(ParamCommaFormatter::EmitFormat::TypeNamePairs,
params),
ParamCommaFormatter(ParamCommaFormatter::EmitFormat::JustParams,
params));
if (genCheckedMethods) {
os << llvm::formatv(
"{0} {0}::getChecked("
"llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, "
"::mlir::MLIRContext *context{1}) {{\n"
" return Base::getChecked(emitError, context{2});\n}\n",
className,
ParamCommaFormatter(ParamCommaFormatter::EmitFormat::TypeNamePairs,
params),
ParamCommaFormatter(ParamCommaFormatter::EmitFormat::JustParams,
params));
}
}
auto builderFmtCtx =
FmtContext().addSubst("_ctxt", "context").addSubst("_get", "Base::get");
auto inferredCtxBuilderFmtCtx = FmtContext().addSubst("_get", "Base::get");
auto checkedBuilderFmtCtx = FmtContext().addSubst("_ctxt", "context");
// Generate the builders specified by the user.
for (const AttrOrTypeBuilder &builder : def.getBuilders()) {
Optional<StringRef> body = builder.getBody();
if (!body)
continue;
std::string paramStr;
llvm::raw_string_ostream paramOS(paramStr);
llvm::interleaveComma(builder.getParameters(), paramOS,
[&](const AttrOrTypeBuilder::Parameter &param) {
// Note: AttrOrTypeBuilder parameters are guaranteed
// to have names.
paramOS << param.getCppType() << " "
<< *param.getName();
});
paramOS.flush();
// Emit the `get` variant of the builder.
os << llvm::formatv("{0} {0}::get(", className);
if (!builder.hasInferredContextParameter()) {
os << "::mlir::MLIRContext *context";
if (!paramStr.empty())
os << ", ";
os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr,
tgfmt(*body, &builderFmtCtx).str());
} else {
os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr,
tgfmt(*body, &inferredCtxBuilderFmtCtx).str());
}
// Emit the `getChecked` variant of the builder.
if (genCheckedMethods) {
os << llvm::formatv("{0} "
"{0}::getChecked(llvm::function_ref<::mlir::"
"InFlightDiagnostic()> emitErrorFn",
className);
std::string checkedBody =
replaceInStr(body->str(), "$_get(", "Base::getChecked(emitErrorFn, ");
if (!builder.hasInferredContextParameter()) {
os << ", ::mlir::MLIRContext *context";
checkedBody = tgfmt(checkedBody, &checkedBuilderFmtCtx).str();
}
if (!paramStr.empty())
os << ", ";
os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr, checkedBody);
}
}
}
/// Print all the def-specific definition code.
void DefGenerator::emitDefDef(const AttrOrTypeDef &def) {
NamespaceEmitter ns(os, def.getDialect());
SmallVector<AttrOrTypeParameter, 4> parameters;
def.getParameters(parameters);
if (!parameters.empty()) {
// Emit the storage class, if requested and necessary.
if (def.genStorageClass())
emitStorageClass(def);
// Emit the builders for this def.
emitBuilderDefs(def, os, parameters);
// Generate accessor definitions only if we also generate the storage class.
// Otherwise, let the user define the exact accessor definition.
if (def.genAccessors() && def.genStorageClass()) {
for (const AttrOrTypeParameter &parameter : parameters) {
SmallString<16> name = parameter.getName();
name[0] = llvm::toUpper(name[0]);
os << formatv("{0} {3}::get{1}() const {{ return getImpl()->{2}; }\n",
parameter.getCppType(), name, parameter.getName(),
def.getCppClassName());
}
}
}
// If mnemonic is specified maybe print definitions for the parser and printer
// code, if they're specified.
if (def.getMnemonic())
emitParsePrint(def);
}
/// Emit the dialect printer/parser dispatcher. User's code should call these
/// functions from their dialect's print/parse methods.
void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) {
if (llvm::none_of(defs, [](const AttrOrTypeDef &def) {
return def.getMnemonic().hasValue();
})) {
return;
}
// The parser dispatch is just a list of if-elses, matching on the mnemonic
// and calling the def's parse function.
os << llvm::formatv(defParserDispatchStartStr, valueType,
isAttrGenerator ? ", ::mlir::Type type" : "");
for (const AttrOrTypeDef &def : defs) {
if (def.getMnemonic()) {
os << formatv(
" if (mnemonic == {0}::{1}::getMnemonic()) return {0}::{1}::",
def.getDialect().getCppNamespace(), def.getCppClassName());
// If the def has no parameters and no parser code, just invoke a normal
// `get`.
if (def.getNumParameters() == 0 && !def.getParserCode()) {
os << "get(context);\n";
continue;
}
os << "parse(context, parser" << (isAttrGenerator ? ", type" : "")
<< ");\n";
}
}
os << " return ::mlir::" << valueType << "();\n";
os << "}\n\n";
// The printer dispatch uses llvm::TypeSwitch to find and call the correct
// printer.
os << llvm::formatv(defPrinterDispatchStartStr, valueType);
for (const AttrOrTypeDef &def : defs) {
Optional<StringRef> mnemonic = def.getMnemonic();
if (!mnemonic)
continue;
StringRef cppNamespace = def.getDialect().getCppNamespace();
StringRef cppClassName = def.getCppClassName();
os << formatv(" .Case<{0}::{1}>([&]({0}::{1} t) {{\n ",
cppNamespace, cppClassName);
// If the def has no parameters and no printer, just print the mnemonic.
if (def.getNumParameters() == 0 && !def.getPrinterCode()) {
os << formatv("printer << {0}::{1}::getMnemonic();", cppNamespace,
cppClassName);
} else {
os << "t.print(printer);";
}
os << "\n return ::mlir::success();\n })\n";
}
os << llvm::formatv(
" .Default([](::mlir::{0}) {{ return ::mlir::failure(); });\n}\n\n",
valueType);
}
bool DefGenerator::emitDefs(StringRef selectedDialect) {
emitSourceFileHeader((defTypePrefix + "Def Definitions").str(), os);
SmallVector<AttrOrTypeDef, 16> defs;
collectAllDefs(selectedDialect, defRecords, defs);
if (defs.empty())
return false;
emitTypeDefList(defs);
IfDefScope scope("GET_" + defTypePrefix.upper() + "DEF_CLASSES", os);
emitParsePrintDispatch(defs);
for (const AttrOrTypeDef &def : defs)
emitDefDef(def);
return false;
}
//===----------------------------------------------------------------------===//
// GEN: Registration hooks
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// AttrDef
static llvm::cl::OptionCategory attrdefGenCat("Options for -gen-attrdef-*");
static llvm::cl::opt<std::string>
attrDialect("attrdefs-dialect",
llvm::cl::desc("Generate attributes for this dialect"),
llvm::cl::cat(attrdefGenCat), llvm::cl::CommaSeparated);
static mlir::GenRegistration
genAttrDefs("gen-attrdef-defs", "Generate AttrDef definitions",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
AttrDefGenerator generator(records, os);
return generator.emitDefs(attrDialect);
});
static mlir::GenRegistration
genAttrDecls("gen-attrdef-decls", "Generate AttrDef declarations",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
AttrDefGenerator generator(records, os);
return generator.emitDecls(attrDialect);
});
//===----------------------------------------------------------------------===//
// TypeDef
static llvm::cl::OptionCategory typedefGenCat("Options for -gen-typedef-*");
static llvm::cl::opt<std::string>
typeDialect("typedefs-dialect",
llvm::cl::desc("Generate types for this dialect"),
llvm::cl::cat(typedefGenCat), llvm::cl::CommaSeparated);
static mlir::GenRegistration
genTypeDefs("gen-typedef-defs", "Generate TypeDef definitions",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
TypeDefGenerator generator(records, os);
return generator.emitDefs(typeDialect);
});
static mlir::GenRegistration
genTypeDecls("gen-typedef-decls", "Generate TypeDef declarations",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
TypeDefGenerator generator(records, os);
return generator.emitDecls(typeDialect);
});

View File

@ -5,6 +5,7 @@ set(LLVM_LINK_COMPONENTS
) )
add_tablegen(mlir-tblgen MLIR add_tablegen(mlir-tblgen MLIR
AttrOrTypeDefGen.cpp
DialectGen.cpp DialectGen.cpp
DirectiveCommonGen.cpp DirectiveCommonGen.cpp
EnumsGen.cpp EnumsGen.cpp
@ -22,7 +23,6 @@ add_tablegen(mlir-tblgen MLIR
RewriterGen.cpp RewriterGen.cpp
SPIRVUtilsGen.cpp SPIRVUtilsGen.cpp
StructsGen.cpp StructsGen.cpp
TypeDefGen.cpp
) )
set_target_properties(mlir-tblgen PROPERTIES FOLDER "Tablegenning") set_target_properties(mlir-tblgen PROPERTIES FOLDER "Tablegenning")

View File

@ -13,9 +13,9 @@
#include "DocGenUtilities.h" #include "DocGenUtilities.h"
#include "mlir/Support/IndentedOstream.h" #include "mlir/Support/IndentedOstream.h"
#include "mlir/TableGen/AttrOrTypeDef.h"
#include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h" #include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/TypeDef.h"
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/FormatVariadic.h"
@ -164,7 +164,7 @@ static void emitTypeDoc(const Type &type, raw_ostream &os) {
/// Emit the assembly format of a type. /// Emit the assembly format of a type.
static void emitTypeAssemblyFormat(TypeDef td, raw_ostream &os) { static void emitTypeAssemblyFormat(TypeDef td, raw_ostream &os) {
SmallVector<TypeParameter, 4> parameters; SmallVector<AttrOrTypeParameter, 4> parameters;
td.getParameters(parameters); td.getParameters(parameters);
if (parameters.size() == 0) { if (parameters.size() == 0) {
os << "\nSyntax: `!" << td.getDialect().getName() << "." << td.getMnemonic() os << "\nSyntax: `!" << td.getDialect().getName() << "." << td.getMnemonic()
@ -198,7 +198,7 @@ static void emitTypeDefDoc(TypeDef td, raw_ostream &os) {
} }
// Emit attribute documentation. // Emit attribute documentation.
SmallVector<TypeParameter, 4> parameters; SmallVector<AttrOrTypeParameter, 4> parameters;
td.getParameters(parameters); td.getParameters(parameters);
if (!parameters.empty()) { if (!parameters.empty()) {
os << "\n#### Type parameters:\n\n"; os << "\n#### Type parameters:\n\n";

View File

@ -1,739 +0,0 @@
//===- TypeDefGen.cpp - MLIR typeDef definitions generator ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// TypeDefGen uses the description of typeDefs to generate C++ definitions.
//
//===----------------------------------------------------------------------===//
#include "mlir/Support/LogicalResult.h"
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/TypeDef.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/TableGenBackend.h"
#define DEBUG_TYPE "mlir-tblgen-typedefgen"
using namespace mlir;
using namespace mlir::tblgen;
static llvm::cl::OptionCategory typedefGenCat("Options for -gen-typedef-*");
static llvm::cl::opt<std::string>
selectedDialect("typedefs-dialect",
llvm::cl::desc("Gen types for this dialect"),
llvm::cl::cat(typedefGenCat), llvm::cl::CommaSeparated);
/// Find all the TypeDefs for the specified dialect. If no dialect specified and
/// can only find one dialect's types, use that.
static void findAllTypeDefs(const llvm::RecordKeeper &recordKeeper,
SmallVectorImpl<TypeDef> &typeDefs) {
auto recDefs = recordKeeper.getAllDerivedDefinitions("TypeDef");
auto defs = llvm::map_range(
recDefs, [&](const llvm::Record *rec) { return TypeDef(rec); });
if (defs.empty())
return;
StringRef dialectName;
if (selectedDialect.getNumOccurrences() == 0) {
if (defs.empty())
return;
llvm::SmallSet<Dialect, 4> dialects;
for (const TypeDef typeDef : defs)
dialects.insert(typeDef.getDialect());
if (dialects.size() != 1)
llvm::PrintFatalError("TypeDefs belonging to more than one dialect. Must "
"select one via '--typedefs-dialect'");
dialectName = (*dialects.begin()).getName();
} else if (selectedDialect.getNumOccurrences() == 1) {
dialectName = selectedDialect.getValue();
} else {
llvm::PrintFatalError("Cannot select multiple dialects for which to "
"generate types via '--typedefs-dialect'.");
}
for (const TypeDef typeDef : defs)
if (typeDef.getDialect().getName().equals(dialectName))
typeDefs.push_back(typeDef);
}
namespace {
/// Pass an instance of this class to llvm::formatv() to emit a comma separated
/// list of parameters in the format by 'EmitFormat'.
class TypeParamCommaFormatter : public llvm::detail::format_adapter {
public:
/// Choose the output format
enum EmitFormat {
/// Emit "parameter1Type parameter1Name, parameter2Type parameter2Name,
/// [...]".
TypeNamePairs,
/// Emit "parameter1(parameter1), parameter2(parameter2), [...]".
TypeNameInitializer,
/// Emit "param1Name, param2Name, [...]".
JustParams,
};
TypeParamCommaFormatter(EmitFormat emitFormat, ArrayRef<TypeParameter> params,
bool prependComma = true)
: emitFormat(emitFormat), params(params), prependComma(prependComma) {}
/// llvm::formatv will call this function when using an instance as a
/// replacement value.
void format(raw_ostream &os, StringRef options) override {
if (!params.empty() && prependComma)
os << ", ";
switch (emitFormat) {
case EmitFormat::TypeNamePairs:
interleaveComma(params, os,
[&](const TypeParameter &p) { emitTypeNamePair(p, os); });
break;
case EmitFormat::TypeNameInitializer:
interleaveComma(params, os, [&](const TypeParameter &p) {
emitTypeNameInitializer(p, os);
});
break;
case EmitFormat::JustParams:
interleaveComma(params, os,
[&](const TypeParameter &p) { os << p.getName(); });
break;
}
}
private:
// Emit "paramType paramName".
static void emitTypeNamePair(const TypeParameter &param, raw_ostream &os) {
os << param.getCppType() << " " << param.getName();
}
// Emit "paramName(paramName)"
void emitTypeNameInitializer(const TypeParameter &param, raw_ostream &os) {
os << param.getName() << "(" << param.getName() << ")";
}
EmitFormat emitFormat;
ArrayRef<TypeParameter> params;
bool prependComma;
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// GEN: TypeDef declarations
//===----------------------------------------------------------------------===//
/// Print this above all the other declarations. Contains type declarations used
/// later on.
static const char *const typeDefDeclHeader = R"(
namespace mlir {
class DialectAsmParser;
class DialectAsmPrinter;
} // namespace mlir
)";
/// The code block for the start of a typeDef class declaration -- singleton
/// case.
///
/// {0}: The name of the typeDef class.
/// {1}: The name of the type base class.
static const char *const typeDefDeclSingletonBeginStr = R"(
class {0} : public ::mlir::Type::TypeBase<{0}, {1}, ::mlir::TypeStorage> {{
public:
/// Inherit some necessary constructors from 'TypeBase'.
using Base::Base;
)";
/// The code block for the start of a typeDef class declaration -- parametric
/// case.
///
/// {0}: The name of the typeDef class.
/// {1}: The name of the type base class.
/// {2}: The typeDef storage class namespace.
/// {3}: The storage class name.
/// {4}: The list of parameters with types.
static const char *const typeDefDeclParametricBeginStr = R"(
namespace {2} {
struct {3};
} // end namespace {2}
class {0} : public ::mlir::Type::TypeBase<{0}, {1},
{2}::{3}> {{
public:
/// Inherit some necessary constructors from 'TypeBase'.
using Base::Base;
)";
/// The snippet for print/parse.
static const char *const typeDefParsePrint = R"(
static ::mlir::Type parse(::mlir::MLIRContext *context,
::mlir::DialectAsmParser &parser);
void print(::mlir::DialectAsmPrinter &printer) const;
)";
/// The code block for the verify method declaration.
///
/// {0}: List of parameters, parameters style.
static const char *const typeDefDeclVerifyStr = R"(
using Base::getChecked;
static ::mlir::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError{0});
)";
/// Emit the builders for the given type.
static void emitTypeBuilderDecls(const TypeDef &typeDef, raw_ostream &os,
TypeParamCommaFormatter &paramTypes) {
StringRef typeClass = typeDef.getCppClassName();
bool genCheckedMethods = typeDef.genVerifyDecl();
if (!typeDef.skipDefaultBuilders()) {
os << llvm::formatv(
" static {0} get(::mlir::MLIRContext *context{1});\n", typeClass,
paramTypes);
if (genCheckedMethods) {
os << llvm::formatv(" static {0} "
"getChecked(llvm::function_ref<::mlir::"
"InFlightDiagnostic()> emitError, "
"::mlir::MLIRContext *context{1});\n",
typeClass, paramTypes);
}
}
// Generate the builders specified by the user.
for (const TypeBuilder &builder : typeDef.getBuilders()) {
std::string paramStr;
llvm::raw_string_ostream paramOS(paramStr);
llvm::interleaveComma(
builder.getParameters(), paramOS,
[&](const TypeBuilder::Parameter &param) {
// Note: TypeBuilder parameters are guaranteed to have names.
paramOS << param.getCppType() << " " << *param.getName();
if (Optional<StringRef> defaultParamValue = param.getDefaultValue())
paramOS << " = " << *defaultParamValue;
});
paramOS.flush();
// Generate the `get` variant of the builder.
os << " static " << typeClass << " get(";
if (!builder.hasInferredContextParameter()) {
os << "::mlir::MLIRContext *context";
if (!paramStr.empty())
os << ", ";
}
os << paramStr << ");\n";
// Generate the `getChecked` variant of the builder.
if (genCheckedMethods) {
os << " static " << typeClass
<< " getChecked(llvm::function_ref<mlir::InFlightDiagnostic()> "
"emitError";
if (!builder.hasInferredContextParameter())
os << ", ::mlir::MLIRContext *context";
if (!paramStr.empty())
os << ", ";
os << paramStr << ");\n";
}
}
}
/// Generate the declaration for the given typeDef class.
static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) {
SmallVector<TypeParameter, 4> params;
typeDef.getParameters(params);
// Emit the beginning string template: either the singleton or parametric
// template.
if (typeDef.getNumParameters() == 0)
os << formatv(typeDefDeclSingletonBeginStr, typeDef.getCppClassName(),
typeDef.getCppBaseClassName());
else
os << formatv(typeDefDeclParametricBeginStr, typeDef.getCppClassName(),
typeDef.getCppBaseClassName(), typeDef.getStorageNamespace(),
typeDef.getStorageClassName());
// Emit the extra declarations first in case there's a type definition in
// there.
if (Optional<StringRef> extraDecl = typeDef.getExtraDecls())
os << *extraDecl << "\n";
TypeParamCommaFormatter emitTypeNamePairsAfterComma(
TypeParamCommaFormatter::EmitFormat::TypeNamePairs, params);
if (!params.empty()) {
emitTypeBuilderDecls(typeDef, os, emitTypeNamePairsAfterComma);
// Emit the verify invariants declaration.
if (typeDef.genVerifyDecl())
os << llvm::formatv(typeDefDeclVerifyStr, emitTypeNamePairsAfterComma);
}
// Emit the mnenomic, if specified.
if (auto mnenomic = typeDef.getMnemonic()) {
os << " static constexpr ::llvm::StringLiteral getMnemonic() {\n"
<< " return ::llvm::StringLiteral(\"" << mnenomic << "\");\n"
<< " }\n";
// If mnemonic specified, emit print/parse declarations.
if (typeDef.getParserCode() || typeDef.getPrinterCode() || !params.empty())
os << typeDefParsePrint;
}
if (typeDef.genAccessors()) {
SmallVector<TypeParameter, 4> parameters;
typeDef.getParameters(parameters);
for (TypeParameter &parameter : parameters) {
SmallString<16> name = parameter.getName();
name[0] = llvm::toUpper(name[0]);
os << formatv(" {0} get{1}() const;\n", parameter.getCppType(), name);
}
}
// End the typeDef decl.
os << " };\n";
}
/// Main entry point for decls.
static bool emitTypeDefDecls(const llvm::RecordKeeper &recordKeeper,
raw_ostream &os) {
emitSourceFileHeader("TypeDef Declarations", os);
SmallVector<TypeDef, 16> typeDefs;
findAllTypeDefs(recordKeeper, typeDefs);
IfDefScope scope("GET_TYPEDEF_CLASSES", os);
// Output the common "header".
os << typeDefDeclHeader;
if (!typeDefs.empty()) {
NamespaceEmitter nsEmitter(os, typeDefs.begin()->getDialect());
// Declare all the type classes first (in case they reference each other).
for (const TypeDef &typeDef : typeDefs)
os << " class " << typeDef.getCppClassName() << ";\n";
// Declare all the typedefs.
for (const TypeDef &typeDef : typeDefs)
emitTypeDefDecl(typeDef, os);
}
return false;
}
//===----------------------------------------------------------------------===//
// GEN: TypeDef list
//===----------------------------------------------------------------------===//
static void emitTypeDefList(SmallVectorImpl<TypeDef> &typeDefs,
raw_ostream &os) {
IfDefScope scope("GET_TYPEDEF_LIST", os);
for (auto *i = typeDefs.begin(); i != typeDefs.end(); i++) {
os << i->getDialect().getCppNamespace() << "::" << i->getCppClassName();
if (i < typeDefs.end() - 1)
os << ",\n";
else
os << "\n";
}
}
//===----------------------------------------------------------------------===//
// GEN: TypeDef definitions
//===----------------------------------------------------------------------===//
/// Beginning of storage class.
/// {0}: Storage class namespace.
/// {1}: Storage class c++ name.
/// {2}: Parameters parameters.
/// {3}: Parameter initializer string.
/// {4}: Parameter name list.
/// {5}: Parameter types.
static const char *const typeDefStorageClassBegin = R"(
namespace {0} {{
struct {1} : public ::mlir::TypeStorage {{
{1} ({2})
: {3} {{ }
/// The hash key for this storage is a pair of the integer and type params.
using KeyTy = std::tuple<{5}>;
/// Define the comparison function for the key type.
bool operator==(const KeyTy &key) const {{
return key == KeyTy({4});
}
)";
/// The storage class' constructor template.
/// {0}: storage class name.
static const char *const typeDefStorageClassConstructorBegin = R"(
/// Define a construction method for creating a new instance of this storage.
static {0} *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &key) {{
)";
/// The storage class' constructor return template.
/// {0}: storage class name.
/// {1}: list of parameters.
static const char *const typeDefStorageClassConstructorReturn = R"(
return new (allocator.allocate<{0}>())
{0}({1});
}
)";
/// Use tgfmt to emit custom allocation code for each parameter, if necessary.
static void emitParameterAllocationCode(TypeDef &typeDef, raw_ostream &os) {
SmallVector<TypeParameter, 4> parameters;
typeDef.getParameters(parameters);
auto fmtCtxt = FmtContext().addSubst("_allocator", "allocator");
for (TypeParameter &parameter : parameters) {
auto allocCode = parameter.getAllocator();
if (allocCode) {
fmtCtxt.withSelf(parameter.getName());
fmtCtxt.addSubst("_dst", parameter.getName());
os << " " << tgfmt(*allocCode, &fmtCtxt) << "\n";
}
}
}
/// Emit the storage class code for type 'typeDef'.
/// This includes (in-order):
/// 1) typeDefStorageClassBegin, which includes:
/// - The class constructor.
/// - The KeyTy definition.
/// - The equality (==) operator.
/// 2) The hashKey method.
/// 3) The construct method.
/// 4) The list of parameters as the storage class member variables.
static void emitStorageClass(TypeDef typeDef, raw_ostream &os) {
SmallVector<TypeParameter, 4> parameters;
typeDef.getParameters(parameters);
// Initialize a bunch of variables to be used later on.
auto parameterNames = map_range(
parameters, [](TypeParameter parameter) { return parameter.getName(); });
auto parameterTypes = map_range(parameters, [](TypeParameter parameter) {
return parameter.getCppType();
});
auto parameterList = join(parameterNames, ", ");
auto parameterTypeList = join(parameterTypes, ", ");
// 1) Emit most of the storage class up until the hashKey body.
os << formatv(typeDefStorageClassBegin, typeDef.getStorageNamespace(),
typeDef.getStorageClassName(),
TypeParamCommaFormatter(
TypeParamCommaFormatter::EmitFormat::TypeNamePairs,
parameters, /*prependComma=*/false),
TypeParamCommaFormatter(
TypeParamCommaFormatter::EmitFormat::TypeNameInitializer,
parameters, /*prependComma=*/false),
parameterList, parameterTypeList);
// 2) Emit the haskKey method.
os << " static ::llvm::hash_code hashKey(const KeyTy &key) {\n";
// Extract each parameter from the key.
for (size_t i = 0, e = parameters.size(); i < e; ++i)
os << llvm::formatv(" const auto &{0} = std::get<{1}>(key);\n",
parameters[i].getName(), i);
// Then combine them all. This requires all the parameters types to have a
// hash_value defined.
os << llvm::formatv(
" return ::llvm::hash_combine({0});\n }\n",
TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams,
parameters, /* prependComma */ false));
// 3) Emit the construct method.
if (typeDef.hasStorageCustomConstructor()) {
// If user wants to build the storage constructor themselves, declare it
// here and then they can write the definition elsewhere.
os << " static " << typeDef.getStorageClassName()
<< " *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy "
"&key);\n";
} else {
// If not, autogenerate one.
// First, unbox the parameters.
os << formatv(typeDefStorageClassConstructorBegin,
typeDef.getStorageClassName());
for (size_t i = 0; i < parameters.size(); ++i) {
os << formatv(" auto {0} = std::get<{1}>(key);\n",
parameters[i].getName(), i);
}
// Second, reassign the parameter variables with allocation code, if it's
// specified.
emitParameterAllocationCode(typeDef, os);
// Last, return an allocated copy.
os << formatv(typeDefStorageClassConstructorReturn,
typeDef.getStorageClassName(), parameterList);
}
// 4) Emit the parameters as storage class members.
for (auto parameter : parameters) {
os << " " << parameter.getCppType() << " " << parameter.getName()
<< ";\n";
}
os << " };\n";
os << "} // namespace " << typeDef.getStorageNamespace() << "\n";
}
/// Emit the parser and printer for a particular type, if they're specified.
void emitParserPrinter(TypeDef typeDef, raw_ostream &os) {
// Emit the printer code, if specified.
if (auto printerCode = typeDef.getPrinterCode()) {
// Both the mnenomic and printerCode must be defined (for parity with
// parserCode).
os << "void " << typeDef.getCppClassName()
<< "::print(::mlir::DialectAsmPrinter &printer) const {\n";
if (*printerCode == "") {
// If no code specified, emit error.
PrintFatalError(typeDef.getLoc(),
typeDef.getName() +
": printer (if specified) must have non-empty code");
}
auto fmtCtxt = FmtContext().addSubst("_printer", "printer");
os << tgfmt(*printerCode, &fmtCtxt) << "\n}\n";
}
// emit a parser, if specified.
if (auto parserCode = typeDef.getParserCode()) {
// The mnenomic must be defined so the dispatcher knows how to dispatch.
os << "::mlir::Type " << typeDef.getCppClassName()
<< "::parse(::mlir::MLIRContext *context, ::mlir::DialectAsmParser &"
"parser) "
"{\n";
if (*parserCode == "") {
// if no code specified, emit error.
PrintFatalError(typeDef.getLoc(),
typeDef.getName() +
": parser (if specified) must have non-empty code");
}
auto fmtCtxt =
FmtContext().addSubst("_parser", "parser").addSubst("_ctxt", "context");
os << tgfmt(*parserCode, &fmtCtxt) << "\n}\n";
}
}
/// Replace all instances of 'from' to 'to' in `str` and return the new string.
static std::string replaceInStr(std::string str, StringRef from, StringRef to) {
size_t pos = 0;
while ((pos = str.find(from.data(), pos, from.size())) != std::string::npos)
str.replace(pos, from.size(), to.data(), to.size());
return str;
}
/// Emit the builders for the given type.
static void emitTypeBuilderDefs(const TypeDef &typeDef, raw_ostream &os,
ArrayRef<TypeParameter> typeDefParams) {
bool genCheckedMethods = typeDef.genVerifyDecl();
StringRef typeClass = typeDef.getCppClassName();
if (!typeDef.skipDefaultBuilders()) {
os << llvm::formatv(
"{0} {0}::get(::mlir::MLIRContext *context{1}) {{\n"
" return Base::get(context{2});\n}\n",
typeClass,
TypeParamCommaFormatter(
TypeParamCommaFormatter::EmitFormat::TypeNamePairs, typeDefParams),
TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams,
typeDefParams));
if (genCheckedMethods) {
os << llvm::formatv(
"{0} {0}::getChecked("
"llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, "
"::mlir::MLIRContext *context{1}) {{\n"
" return Base::getChecked(emitError, context{2});\n}\n",
typeClass,
TypeParamCommaFormatter(
TypeParamCommaFormatter::EmitFormat::TypeNamePairs,
typeDefParams),
TypeParamCommaFormatter(
TypeParamCommaFormatter::EmitFormat::JustParams, typeDefParams));
}
}
auto builderFmtCtx =
FmtContext().addSubst("_ctxt", "context").addSubst("_get", "Base::get");
auto inferredCtxBuilderFmtCtx = FmtContext().addSubst("_get", "Base::get");
auto checkedBuilderFmtCtx = FmtContext().addSubst("_ctxt", "context");
// Generate the builders specified by the user.
for (const TypeBuilder &builder : typeDef.getBuilders()) {
Optional<StringRef> body = builder.getBody();
if (!body)
continue;
std::string paramStr;
llvm::raw_string_ostream paramOS(paramStr);
llvm::interleaveComma(builder.getParameters(), paramOS,
[&](const TypeBuilder::Parameter &param) {
// Note: TypeBuilder parameters are guaranteed to
// have names.
paramOS << param.getCppType() << " "
<< *param.getName();
});
paramOS.flush();
// Emit the `get` variant of the builder.
os << llvm::formatv("{0} {0}::get(", typeClass);
if (!builder.hasInferredContextParameter()) {
os << "::mlir::MLIRContext *context";
if (!paramStr.empty())
os << ", ";
os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr,
tgfmt(*body, &builderFmtCtx).str());
} else {
os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr,
tgfmt(*body, &inferredCtxBuilderFmtCtx).str());
}
// Emit the `getChecked` variant of the builder.
if (genCheckedMethods) {
os << llvm::formatv("{0} "
"{0}::getChecked(llvm::function_ref<::mlir::"
"InFlightDiagnostic()> emitErrorFn",
typeClass);
std::string checkedBody =
replaceInStr(body->str(), "$_get(", "Base::getChecked(emitErrorFn, ");
if (!builder.hasInferredContextParameter()) {
os << ", ::mlir::MLIRContext *context";
checkedBody = tgfmt(checkedBody, &checkedBuilderFmtCtx).str();
}
if (!paramStr.empty())
os << ", ";
os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr, checkedBody);
}
}
}
/// Print all the typedef-specific definition code.
static void emitTypeDefDef(const TypeDef &typeDef, raw_ostream &os) {
NamespaceEmitter ns(os, typeDef.getDialect());
SmallVector<TypeParameter, 4> parameters;
typeDef.getParameters(parameters);
if (!parameters.empty()) {
// Emit the storage class, if requested and necessary.
if (typeDef.genStorageClass())
emitStorageClass(typeDef, os);
// Emit the builders for this type.
emitTypeBuilderDefs(typeDef, os, parameters);
// Generate accessor definitions only if we also generate the storage class.
// Otherwise, let the user define the exact accessor definition.
if (typeDef.genAccessors() && typeDef.genStorageClass()) {
// Emit the parameter accessors.
for (const TypeParameter &parameter : parameters) {
SmallString<16> name = parameter.getName();
name[0] = llvm::toUpper(name[0]);
os << formatv("{0} {3}::get{1}() const { return getImpl()->{2}; }\n",
parameter.getCppType(), name, parameter.getName(),
typeDef.getCppClassName());
}
}
}
// If mnemonic is specified maybe print definitions for the parser and printer
// code, if they're specified.
if (typeDef.getMnemonic())
emitParserPrinter(typeDef, os);
}
/// Emit the dialect printer/parser dispatcher. User's code should call these
/// functions from their dialect's print/parse methods.
static void emitParsePrintDispatch(ArrayRef<TypeDef> types, raw_ostream &os) {
if (llvm::none_of(types, [](const TypeDef &type) {
return type.getMnemonic().hasValue();
})) {
return;
}
// The parser dispatch is just a list of if-elses, matching on the
// mnemonic and calling the class's parse function.
os << "static ::mlir::Type generatedTypeParser(::mlir::MLIRContext *"
"context, ::mlir::DialectAsmParser &parser, "
"::llvm::StringRef mnemonic) {\n";
for (const TypeDef &type : types) {
if (type.getMnemonic()) {
os << formatv(" if (mnemonic == {0}::{1}::getMnemonic()) return "
"{0}::{1}::",
type.getDialect().getCppNamespace(),
type.getCppClassName());
// If the type has no parameters and no parser code, just invoke a normal
// `get`.
if (type.getNumParameters() == 0 && !type.getParserCode())
os << "get(context);\n";
else
os << "parse(context, parser);\n";
}
}
os << " return ::mlir::Type();\n";
os << "}\n\n";
// The printer dispatch uses llvm::TypeSwitch to find and call the correct
// printer.
os << "static ::mlir::LogicalResult generatedTypePrinter(::mlir::Type "
"type, "
"::mlir::DialectAsmPrinter &printer) {\n"
<< " return ::llvm::TypeSwitch<::mlir::Type, "
"::mlir::LogicalResult>(type)\n";
for (const TypeDef &type : types) {
if (Optional<StringRef> mnemonic = type.getMnemonic()) {
StringRef cppNamespace = type.getDialect().getCppNamespace();
StringRef cppClassName = type.getCppClassName();
os << formatv(" .Case<{0}::{1}>([&]({0}::{1} t) {{\n ",
cppNamespace, cppClassName);
// If the type has no parameters and no printer code, just print the
// mnemonic.
if (type.getNumParameters() == 0 && !type.getPrinterCode())
os << formatv("printer << {0}::{1}::getMnemonic();", cppNamespace,
cppClassName);
else
os << "t.print(printer);";
os << "\n return ::mlir::success();\n })\n";
}
}
os << " .Default([](::mlir::Type) { return ::mlir::failure(); });\n"
<< "}\n\n";
}
/// Entry point for typedef definitions.
static bool emitTypeDefDefs(const llvm::RecordKeeper &recordKeeper,
raw_ostream &os) {
emitSourceFileHeader("TypeDef Definitions", os);
SmallVector<TypeDef, 16> typeDefs;
findAllTypeDefs(recordKeeper, typeDefs);
emitTypeDefList(typeDefs, os);
IfDefScope scope("GET_TYPEDEF_CLASSES", os);
emitParsePrintDispatch(typeDefs, os);
for (const TypeDef &typeDef : typeDefs)
emitTypeDefDef(typeDef, os);
return false;
}
//===----------------------------------------------------------------------===//
// GEN: TypeDef registration hooks
//===----------------------------------------------------------------------===//
static mlir::GenRegistration
genTypeDefDefs("gen-typedef-defs", "Generate TypeDef definitions",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
return emitTypeDefDefs(records, os);
});
static mlir::GenRegistration
genTypeDefDecls("gen-typedef-decls", "Generate TypeDef declarations",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
return emitTypeDefDecls(records, os);
});