forked from OSchip/llvm-project
309 lines
9.7 KiB
C++
309 lines
9.7 KiB
C++
//===- StructsGen.cpp - MLIR struct utility generator ---------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// StructsGen generates common utility functions for grouping attributes into a
|
|
// set of structured data.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
|
|
|
|
using llvm::raw_ostream;
|
|
using llvm::Record;
|
|
using llvm::RecordKeeper;
|
|
using llvm::StringRef;
|
|
using mlir::tblgen::FmtContext;
|
|
using mlir::tblgen::StructAttr;
|
|
|
|
static void
|
|
emitStructClass(const Record &structDef, StringRef structName,
|
|
llvm::ArrayRef<mlir::tblgen::StructFieldAttr> fields,
|
|
StringRef description, raw_ostream &os) {
|
|
const char *structInfo = R"(
|
|
// {0}
|
|
class {1} : public ::mlir::DictionaryAttr)";
|
|
const char *structInfoEnd = R"( {
|
|
public:
|
|
using ::mlir::DictionaryAttr::DictionaryAttr;
|
|
static bool classof(::mlir::Attribute attr);
|
|
)";
|
|
os << formatv(structInfo, description, structName) << structInfoEnd;
|
|
|
|
// Declares a constructor function for the tablegen structure.
|
|
// TblgenStruct::get(MLIRContext context, Type1 Field1, Type2 Field2, ...);
|
|
const char *getInfoDecl = " static {0} get(\n";
|
|
const char *getInfoDeclArg = " {0} {1},\n";
|
|
const char *getInfoDeclEnd = " ::mlir::MLIRContext* context);\n\n";
|
|
|
|
os << llvm::formatv(getInfoDecl, structName);
|
|
|
|
for (auto field : fields) {
|
|
auto name = field.getName();
|
|
auto type = field.getType();
|
|
auto storage = type.getStorageType();
|
|
os << llvm::formatv(getInfoDeclArg, storage, name);
|
|
}
|
|
os << getInfoDeclEnd;
|
|
|
|
// Declares an accessor for the fields owned by the tablegen structure.
|
|
// namespace::storage TblgenStruct::field1() const;
|
|
const char *fieldInfo = R"( {0} {1}() const;
|
|
)";
|
|
for (auto field : fields) {
|
|
auto name = field.getName();
|
|
auto type = field.getType();
|
|
auto storage = type.getStorageType();
|
|
os << formatv(fieldInfo, storage, name);
|
|
}
|
|
|
|
os << "};\n\n";
|
|
}
|
|
|
|
static void emitStructDecl(const Record &structDef, raw_ostream &os) {
|
|
StructAttr structAttr(&structDef);
|
|
StringRef structName = structAttr.getStructClassName();
|
|
StringRef cppNamespace = structAttr.getCppNamespace();
|
|
StringRef description = structAttr.getSummary();
|
|
auto fields = structAttr.getAllFields();
|
|
|
|
// Wrap in the appropriate namespace.
|
|
llvm::SmallVector<StringRef, 2> namespaces;
|
|
llvm::SplitString(cppNamespace, namespaces, "::");
|
|
|
|
for (auto ns : namespaces)
|
|
os << "namespace " << ns << " {\n";
|
|
|
|
// Emit the struct class definition
|
|
emitStructClass(structDef, structName, fields, description, os);
|
|
|
|
// Close the declared namespace.
|
|
for (auto ns : namespaces)
|
|
os << "} // namespace " << ns << "\n";
|
|
}
|
|
|
|
static bool emitStructDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
|
llvm::emitSourceFileHeader("Struct Utility Declarations", os);
|
|
|
|
auto defs = recordKeeper.getAllDerivedDefinitions("StructAttr");
|
|
for (const auto *def : defs) {
|
|
emitStructDecl(*def, os);
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
/// Emits the out-of-line definition of `StructName::get(...)`.
///
/// The generated factory takes one parameter per field plus the MLIRContext.
/// It collects each field into a ::mlir::NamedAttribute list keyed by the
/// field's name, builds a ::mlir::DictionaryAttr from that list, and returns
/// it `dyn_cast` to the struct class. In the generated code:
///   * a required field is `assert`ed to be non-null before it is added;
///   * an optional or default-valued field is added only when non-null.
///
/// \param structName name of the generated C++ class.
/// \param fields     the struct's fields, in declaration order.
/// \param os         output stream receiving the definition.
static void emitFactoryDef(llvm::StringRef structName,
                           llvm::ArrayRef<mlir::tblgen::StructFieldAttr> fields,
                           raw_ostream &os) {
  // Signature: "{0} {0}::get(" then one parameter line per field, then the
  // context parameter and the opening brace of the body.
  const char *signatureOpen = "{0} {0}::get(\n";
  const char *signatureParam = " {0} {1},\n";
  const char *signatureClose = " ::mlir::MLIRContext* context) {";

  os << llvm::formatv(signatureOpen, structName);
  for (const auto &field : fields)
    os << llvm::formatv(signatureParam, field.getType().getStorageType(),
                        field.getName());
  os << signatureClose;

  // Entry storage; the SmallVector's inline capacity is the field count.
  const char *entryStorage = R"(
::llvm::SmallVector<::mlir::NamedAttribute, {0}> fields;
)";
  os << llvm::formatv(entryStorage, fields.size());

  // {0} = field name, which is also the generated parameter name.
  const char *requiredEntry = R"(
assert({0});
auto {0}_id = ::mlir::StringAttr::get(context, "{0}");
fields.emplace_back({0}_id, {0});
)";
  const char *nullableEntry = R"(
if ({0}) {
auto {0}_id = ::mlir::StringAttr::get(context, "{0}");
fields.emplace_back({0}_id, {0});
}
)";
  for (const auto &field : fields) {
    auto type = field.getType();
    const bool mayBeAbsent = type.isOptional() || type.hasDefaultValue();
    os << llvm::formatv(mayBeAbsent ? nullableEntry : requiredEntry,
                        field.getName());
  }

  // Build the dictionary and view it as the struct class.
  const char *bodyClose = R"(
::mlir::Attribute dict = ::mlir::DictionaryAttr::get(context, fields);
return dict.dyn_cast<{0}>();
}
)";
  os << llvm::formatv(bodyClose, structName);
}
|
|
|
|
/// Emits the definition of `StructName::classof(::mlir::Attribute)`.
///
/// The generated predicate returns false for a null attribute or one that is
/// not a ::mlir::DictionaryAttr. Otherwise it looks up every field by name:
///   * a required field must be present and satisfy its constraint;
///   * an optional/default-valued field, when present, must satisfy its
///     constraint, and when absent is counted in `num_absent_attrs`.
/// Finally, the dictionary size plus the absent count must equal the number of
/// declared fields, which rejects dictionaries carrying extra entries.
///
/// \param structName name of the generated C++ class.
/// \param fields     the struct's fields, in declaration order.
/// \param os         output stream receiving the definition.
static void emitClassofDef(llvm::StringRef structName,
                           llvm::ArrayRef<mlir::tblgen::StructFieldAttr> fields,
                           raw_ostream &os) {
  const char *signature = R"(
bool {0}::classof(::mlir::Attribute attr))";
  // Contains no '{' or '}', so streaming it directly equals formatting it.
  const char *prologue = R"(
if (!attr)
return false;
auto derived = attr.dyn_cast<::mlir::DictionaryAttr>();
if (!derived)
return false;
int num_absent_attrs = 0;
)";
  os << llvm::formatv(signature, structName) << " {" << prologue;

  // {0} = field name (also the generated local), {1} = constraint condition.
  const char *requiredCheck = R"(
auto {0} = derived.get("{0}");
if (!{0} || !({1}))
return false;
)";
  const char *nullableCheck = R"(
auto {0} = derived.get("{0}");
if (!{0})
++num_absent_attrs;
else if (!({1}))
return false;
)";

  FmtContext fctx;
  for (const auto &field : fields) {
    auto name = field.getName();
    auto type = field.getType();
    // Instantiate the constraint's condition template with the generated
    // local variable bound as the format context's "self".
    std::string condition =
        std::string(tgfmt(type.getConditionTemplate(), &fctx.withSelf(name)));
    const char *check = (type.isOptional() || type.hasDefaultValue())
                            ? nullableCheck
                            : requiredCheck;
    os << llvm::formatv(check, name, condition);
  }

  const char *epilogue = R"(
return derived.size() + num_absent_attrs == {0};
}
)";
  os << llvm::formatv(epilogue, fields.size());
}
|
|
|
|
static void
|
|
emitAccessorDef(llvm::StringRef structName,
|
|
llvm::ArrayRef<mlir::tblgen::StructFieldAttr> fields,
|
|
raw_ostream &os) {
|
|
const char *fieldInfo = R"(
|
|
{0} {2}::{1}() const {
|
|
auto derived = this->cast<::mlir::DictionaryAttr>();
|
|
auto {1} = derived.get("{1}");
|
|
assert({1} && "attribute not found.");
|
|
assert({1}.isa<{0}>() && "incorrect Attribute type found.");
|
|
return {1}.cast<{0}>();
|
|
}
|
|
)";
|
|
const char *fieldInfoOptional = R"(
|
|
{0} {2}::{1}() const {
|
|
auto derived = this->cast<::mlir::DictionaryAttr>();
|
|
auto {1} = derived.get("{1}");
|
|
if (!{1})
|
|
return nullptr;
|
|
assert({1}.isa<{0}>() && "incorrect Attribute type found.");
|
|
return {1}.cast<{0}>();
|
|
}
|
|
)";
|
|
const char *fieldInfoDefaultValued = R"(
|
|
{0} {2}::{1}() const {
|
|
auto derived = this->cast<::mlir::DictionaryAttr>();
|
|
auto {1} = derived.get("{1}");
|
|
if (!{1}) {
|
|
::mlir::Builder builder(getContext());
|
|
return {3};
|
|
}
|
|
assert({1}.isa<{0}>() && "incorrect Attribute type found.");
|
|
return {1}.cast<{0}>();
|
|
}
|
|
)";
|
|
FmtContext fmtCtx;
|
|
fmtCtx.withBuilder("builder");
|
|
|
|
for (auto field : fields) {
|
|
auto name = field.getName();
|
|
auto type = field.getType();
|
|
auto storage = type.getStorageType();
|
|
if (type.isOptional()) {
|
|
os << llvm::formatv(fieldInfoOptional, storage, name, structName);
|
|
} else if (type.hasDefaultValue()) {
|
|
std::string defaultValue = tgfmt(type.getConstBuilderTemplate(), &fmtCtx,
|
|
type.getDefaultValue());
|
|
os << llvm::formatv(fieldInfoDefaultValued, storage, name, structName,
|
|
defaultValue);
|
|
} else {
|
|
os << llvm::formatv(fieldInfo, storage, name, structName);
|
|
}
|
|
}
|
|
}
|
|
|
|
static void emitStructDef(const Record &structDef, raw_ostream &os) {
|
|
StructAttr structAttr(&structDef);
|
|
StringRef cppNamespace = structAttr.getCppNamespace();
|
|
StringRef structName = structAttr.getStructClassName();
|
|
mlir::tblgen::FmtContext ctx;
|
|
auto fields = structAttr.getAllFields();
|
|
|
|
llvm::SmallVector<StringRef, 2> namespaces;
|
|
llvm::SplitString(cppNamespace, namespaces, "::");
|
|
|
|
for (auto ns : namespaces)
|
|
os << "namespace " << ns << " {\n";
|
|
|
|
emitFactoryDef(structName, fields, os);
|
|
emitClassofDef(structName, fields, os);
|
|
emitAccessorDef(structName, fields, os);
|
|
|
|
for (auto ns : llvm::reverse(namespaces))
|
|
os << "} // namespace " << ns << "\n";
|
|
}
|
|
|
|
static bool emitStructDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
|
llvm::emitSourceFileHeader("Struct Utility Definitions", os);
|
|
|
|
auto defs = recordKeeper.getAllDerivedDefinitions("StructAttr");
|
|
for (const auto *def : defs)
|
|
emitStructDef(*def, os);
|
|
|
|
return false;
|
|
}
|
|
|
|
// Registers the struct utility generator to mlir-tblgen.
|
|
static mlir::GenRegistration
|
|
genStructDecls("gen-struct-attr-decls",
|
|
"Generate struct utility declarations",
|
|
[](const RecordKeeper &records, raw_ostream &os) {
|
|
return emitStructDecls(records, os);
|
|
});
|
|
|
|
// Registers the struct utility generator to mlir-tblgen.
|
|
static mlir::GenRegistration
|
|
genStructDefs("gen-struct-attr-defs", "Generate struct utility definitions",
|
|
[](const RecordKeeper &records, raw_ostream &os) {
|
|
return emitStructDefs(records, os);
|
|
});
|