forked from OSchip/llvm-project
[mlir] Generate Dialect constructors in .cpp instead of .h
By generating in the .h file, we were forcing dialects to include a lot of additional header files because: * Fields of the dialect, e.g. std::unique_ptr<>, were unable to use forward declarations. * Dependent dialects are loaded in the constructor, requiring the full definition of each dependent dialect (which, depending on the file structure of the dialect, may include the operations). By generating in the .cpp we get much faster builds, and also better align with the rest of the code base. Fixes #55044 Differential Revision: https://reviews.llvm.org/D124297
This commit is contained in:
parent
a48300aee5
commit
f3ebf828dc
|
@ -87,16 +87,9 @@ findSelectedDialect(ArrayRef<const llvm::Record *> dialectDefs) {
|
||||||
///
|
///
|
||||||
/// {0}: The name of the dialect class.
|
/// {0}: The name of the dialect class.
|
||||||
/// {1}: The dialect namespace.
|
/// {1}: The dialect namespace.
|
||||||
/// {2}: initialization code that is emitted in the ctor body before calling
|
|
||||||
/// initialize()
|
|
||||||
static const char *const dialectDeclBeginStr = R"(
|
static const char *const dialectDeclBeginStr = R"(
|
||||||
class {0} : public ::mlir::Dialect {
|
class {0} : public ::mlir::Dialect {
|
||||||
explicit {0}(::mlir::MLIRContext *context)
|
explicit {0}(::mlir::MLIRContext *context);
|
||||||
: ::mlir::Dialect(getDialectNamespace(), context,
|
|
||||||
::mlir::TypeID::get<{0}>()) {{
|
|
||||||
{2}
|
|
||||||
initialize();
|
|
||||||
}
|
|
||||||
|
|
||||||
void initialize();
|
void initialize();
|
||||||
friend class ::mlir::MLIRContext;
|
friend class ::mlir::MLIRContext;
|
||||||
|
@ -190,23 +183,13 @@ emitDialectDecl(Dialect &dialect,
|
||||||
const iterator_range<DialectFilterIterator> &dialectAttrs,
|
const iterator_range<DialectFilterIterator> &dialectAttrs,
|
||||||
const iterator_range<DialectFilterIterator> &dialectTypes,
|
const iterator_range<DialectFilterIterator> &dialectTypes,
|
||||||
raw_ostream &os) {
|
raw_ostream &os) {
|
||||||
/// Build the list of dependent dialects
|
|
||||||
std::string dependentDialectRegistrations;
|
|
||||||
{
|
|
||||||
llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
|
|
||||||
for (StringRef dependentDialect : dialect.getDependentDialects())
|
|
||||||
dialectsOs << llvm::formatv(dialectRegistrationTemplate,
|
|
||||||
dependentDialect);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Emit all nested namespaces.
|
// Emit all nested namespaces.
|
||||||
{
|
{
|
||||||
NamespaceEmitter nsEmitter(os, dialect);
|
NamespaceEmitter nsEmitter(os, dialect);
|
||||||
|
|
||||||
// Emit the start of the decl.
|
// Emit the start of the decl.
|
||||||
std::string cppName = dialect.getCppClassName();
|
std::string cppName = dialect.getCppClassName();
|
||||||
os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(),
|
os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName());
|
||||||
dependentDialectRegistrations);
|
|
||||||
|
|
||||||
// Check for any attributes/types registered to this dialect. If there are,
|
// Check for any attributes/types registered to this dialect. If there are,
|
||||||
// add the hooks for parsing/printing.
|
// add the hooks for parsing/printing.
|
||||||
|
@ -262,6 +245,19 @@ static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
|
||||||
// GEN: Dialect definitions
|
// GEN: Dialect definitions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// The code block to generate a dialect constructor definition.
|
||||||
|
///
|
||||||
|
/// {0}: The name of the dialect class.
|
||||||
|
/// {1}: initialization code that is emitted in the ctor body before calling
|
||||||
|
/// initialize().
|
||||||
|
static const char *const dialectConstructorStr = R"(
|
||||||
|
{0}::{0}(::mlir::MLIRContext *context)
|
||||||
|
: ::mlir::Dialect(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>()) {{
|
||||||
|
{1}
|
||||||
|
initialize();
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
/// The code block to generate a default desturctor definition.
|
/// The code block to generate a default desturctor definition.
|
||||||
///
|
///
|
||||||
/// {0}: The name of the dialect class.
|
/// {0}: The name of the dialect class.
|
||||||
|
@ -271,16 +267,30 @@ static const char *const dialectDestructorStr = R"(
|
||||||
)";
|
)";
|
||||||
|
|
||||||
static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
|
static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
|
||||||
|
std::string cppClassName = dialect.getCppClassName();
|
||||||
|
|
||||||
// Emit the TypeID explicit specializations to have a single symbol def.
|
// Emit the TypeID explicit specializations to have a single symbol def.
|
||||||
if (!dialect.getCppNamespace().empty())
|
if (!dialect.getCppNamespace().empty())
|
||||||
os << "MLIR_DEFINE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace()
|
os << "MLIR_DEFINE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace()
|
||||||
<< "::" << dialect.getCppClassName() << ")\n";
|
<< "::" << cppClassName << ")\n";
|
||||||
|
|
||||||
// Emit all nested namespaces.
|
// Emit all nested namespaces.
|
||||||
NamespaceEmitter nsEmitter(os, dialect);
|
NamespaceEmitter nsEmitter(os, dialect);
|
||||||
|
|
||||||
|
/// Build the list of dependent dialects.
|
||||||
|
std::string dependentDialectRegistrations;
|
||||||
|
{
|
||||||
|
llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
|
||||||
|
for (StringRef dependentDialect : dialect.getDependentDialects())
|
||||||
|
dialectsOs << llvm::formatv(dialectRegistrationTemplate,
|
||||||
|
dependentDialect);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emit the constructor and destructor.
|
||||||
|
os << llvm::formatv(dialectConstructorStr, cppClassName,
|
||||||
|
dependentDialectRegistrations);
|
||||||
if (!dialect.hasNonDefaultDestructor())
|
if (!dialect.hasNonDefaultDestructor())
|
||||||
os << llvm::formatv(dialectDestructorStr, dialect.getCppClassName());
|
os << llvm::formatv(dialectDestructorStr, cppClassName);
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper,
|
static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper,
|
||||||
|
|
Loading…
Reference in New Issue