// Forked from OSchip/llvm-project (312 lines, 11 KiB, C++).
//===- OpInterfacesGen.cpp - MLIR op interface utility generator ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpInterfacesGen generates the C++ declarations, C++ definitions, and
// Markdown documentation for operation interfaces.
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "DocGenUtilities.h"
|
|
#include "mlir/Support/STLExtras.h"
|
|
#include "mlir/TableGen/Format.h"
|
|
#include "mlir/TableGen/GenInfo.h"
|
|
#include "mlir/TableGen/OpInterfaces.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
#include "llvm/ADT/StringExtras.h"
|
|
#include "llvm/Support/FormatVariadic.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
#include "llvm/TableGen/Error.h"
|
|
#include "llvm/TableGen/Record.h"
|
|
#include "llvm/TableGen/TableGenBackend.h"
|
|
|
|
using namespace llvm;
|
|
using namespace mlir;
|
|
using mlir::tblgen::OpInterface;
|
|
using mlir::tblgen::OpInterfaceMethod;
|
|
|
|
// Emit the method name and argument list for the given method. If
|
|
// 'addOperationArg' is true, then an Operation* argument is added to the
|
|
// beginning of the argument list.
|
|
static void emitMethodNameAndArgs(const OpInterfaceMethod &method,
|
|
raw_ostream &os, bool addOperationArg) {
|
|
os << method.getName() << '(';
|
|
if (addOperationArg)
|
|
os << "Operation *tablegen_opaque_op" << (method.arg_empty() ? "" : ", ");
|
|
interleaveComma(method.getArguments(), os,
|
|
[&](const OpInterfaceMethod::Argument &arg) {
|
|
os << arg.type << " " << arg.name;
|
|
});
|
|
os << ')';
|
|
}
|
|
|
|
// Get an array of all OpInterface definitions but exclude those subclassing
|
|
// "DeclareOpInterfaceMethods".
|
|
static std::vector<Record *>
|
|
getAllOpInterfaceDefinitions(const RecordKeeper &recordKeeper) {
|
|
std::vector<Record *> defs =
|
|
recordKeeper.getAllDerivedDefinitions("OpInterface");
|
|
|
|
llvm::erase_if(defs, [](const Record *def) {
|
|
return def->isSubClassOf("DeclareOpInterfaceMethods");
|
|
});
|
|
return defs;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
// GEN: Interface definitions
//===----------------------------------------------------------------------===//
|
|
|
|
// Emits the out-of-line definition of every method of the interface class.
// Each definition forwards to the same-named method on `getImpl()`; for
// non-static methods `getOperation()` is passed as the leading argument.
static void emitInterfaceDef(OpInterface &interface, raw_ostream &os) {
  const StringRef interfaceName = interface.getName();

  for (auto &method : interface.getMethods()) {
    // Signature: `<ret> <Interface>::<name>(<args>)`.
    os << method.getReturnType() << " " << interfaceName << "::";
    emitMethodNameAndArgs(method, os, /*addOperationArg=*/false);

    // Body: forward the call to the type-erased implementation.
    os << " {\n return getImpl()->" << method.getName() << '(';
    if (!method.isStatic()) {
      os << "getOperation()";
      if (!method.arg_empty())
        os << ", ";
    }
    interleaveComma(method.getArguments(), os,
                    [&os](const OpInterfaceMethod::Argument &param) {
                      os << param.name;
                    });
    os << ");\n }\n";
  }
}
|
|
|
|
static bool emitInterfaceDefs(const RecordKeeper &recordKeeper,
|
|
raw_ostream &os) {
|
|
llvm::emitSourceFileHeader("Operation Interface Definitions", os);
|
|
|
|
for (const auto *def : getAllOpInterfaceDefinitions(recordKeeper)) {
|
|
OpInterface interface(def);
|
|
emitInterfaceDef(interface, os);
|
|
}
|
|
return false;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
// GEN: Interface declarations
//===----------------------------------------------------------------------===//
|
|
|
|
// Emits the abstract `Concept` class: a defaulted virtual destructor plus one
// pure virtual method per interface method. Non-static methods receive the
// opaque operation pointer as an extra leading argument.
static void emitConceptDecl(OpInterface &interface, raw_ostream &os) {
  os << " class Concept {\n";
  os << " public:\n";
  os << " virtual ~Concept() = default;\n";

  for (auto &method : interface.getMethods()) {
    const bool takesOperation = !method.isStatic();
    os << " virtual " << method.getReturnType() << " ";
    emitMethodNameAndArgs(method, os, /*addOperationArg=*/takesOperation);
    os << " = 0;\n";
  }

  os << " };\n";
}
|
|
|
|
// Emits the `Model<ConcreteOp>` template implementing `Concept` for a concrete
// op. Each `final` override either splices in the method's user-provided body
// verbatim, or forwards the call to the op instance (non-static methods) or to
// `ConcreteOp` itself (static methods).
static void emitModelDecl(OpInterface &interface, raw_ostream &os) {
  os << " template<typename ConcreteOp>\n"
     << " class Model : public Concept {\npublic:\n";

  for (auto &method : interface.getMethods()) {
    const bool isStaticMethod = method.isStatic();

    // Override signature; non-static methods take the opaque operation.
    os << " " << method.getReturnType() << " ";
    emitMethodNameAndArgs(method, os, /*addOperationArg=*/!isStaticMethod);
    os << " final {\n";

    // Recover the concrete op. The `(void)` cast keeps an unused `op` from
    // triggering warnings when a provided body does not reference it.
    if (!isStaticMethod)
      os << " auto op = llvm::cast<ConcreteOp>(tablegen_opaque_op);\n"
         << " (void)op;\n";

    // A user-provided body replaces the forwarding call entirely.
    if (auto body = method.getBody()) {
      os << *body << "\n }\n";
      continue;
    }

    const char *callee = isStaticMethod ? "ConcreteOp::" : "op.";
    os << " return " << callee << method.getName() << '(';
    interleaveComma(method.getArguments(), os,
                    [&os](const OpInterfaceMethod::Argument &param) {
                      os << param.name;
                    });
    os << ");\n }\n";
  }

  os << " };\n";
}
|
|
|
|
// Emits the `Trait<ConcreteOp>` template nested in the interface class. It
// carries the default implementation of each method that has one and, when
// the interface specifies verification code, a static `verifyTrait` hook whose
// body is that code formatted with the op placeholder bound to `op`.
// Aborts via PrintFatalError if an interface method is named `verifyTrait`,
// since that name is reserved for the generated hook.
static void emitTraitDecl(OpInterface &interface, raw_ostream &os,
                          StringRef interfaceName,
                          StringRef interfaceTraitsName) {
  os << " template <typename ConcreteOp>\n ";
  os << llvm::formatv("struct Trait : public OpInterface<{0},"
                      " detail::{1}>::Trait<ConcreteOp> {{\n",
                      interfaceName, interfaceTraitsName);

  for (auto &method : interface.getMethods()) {
    if (method.getName() == "verifyTrait")
      PrintFatalError(
          formatv("'verifyTrait' method cannot be specified as interface "
                  "method for '{0}'; set 'verify' on OpInterfaceTrait instead",
                  interfaceName));

    // Only methods with a default implementation appear in the trait.
    auto defaultImpl = method.getDefaultImplementation();
    if (!defaultImpl)
      continue;

    os << " ";
    if (method.isStatic())
      os << "static ";
    os << method.getReturnType() << " ";
    emitMethodNameAndArgs(method, os, /*addOperationArg=*/false);
    os << " {\n" << defaultImpl.getValue() << " }\n";
  }

  if (auto verify = interface.getVerify()) {
    tblgen::FmtContext verifyCtx;
    verifyCtx.withOp("op");
    os << " static LogicalResult verifyTrait(Operation* op) {\n"
       << std::string(tblgen::tgfmt(*verify, &verifyCtx)) << "\n }\n";
  }

  os << " };\n";
}
|
|
|
|
// Emits the complete C++ declaration of one op interface:
//  * `detail::<Name>InterfaceTraits`, holding the `Concept` and `Model`
//    classes used for type erasure;
//  * the `<Name>` class deriving from `OpInterface`, containing its nested
//    `Trait`, one declaration per interface method, and any extra class
//    declarations supplied by the interface record.
static void emitInterfaceDecl(OpInterface &interface, raw_ostream &os) {
  StringRef interfaceName = interface.getName();
  auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str();

  // Emit the traits struct containing the concept and model declarations.
  os << "namespace detail {\n"
     << "struct " << interfaceTraitsName << " {\n";
  emitConceptDecl(interface, os);
  emitModelDecl(interface, os);
  os << "};\n} // end namespace detail\n";

  // Emit the main interface class declaration. formatv treats a lone `{` as
  // the start of a replacement field, so the class's opening brace must be
  // escaped as `{{`.
  os << llvm::formatv("class {0} : public OpInterface<{0}, detail::{1}> {{\n"
                      "public:\n"
                      " using OpInterface<{0}, detail::{1}>::OpInterface;\n",
                      interfaceName, interfaceTraitsName);

  // Emit the derived trait for the interface.
  emitTraitDecl(interface, os, interfaceName, interfaceTraitsName);

  // Insert the method declarations; their definitions come from the
  // "gen-op-interface-defs" backend.
  for (auto &method : interface.getMethods()) {
    os << " " << method.getReturnType() << " ";
    emitMethodNameAndArgs(method, os, /*addOperationArg=*/false);
    os << ";\n";
  }

  // Emit any extra declarations.
  if (Optional<StringRef> extraDecls = interface.getExtraClassDeclaration())
    os << *extraDecls << "\n";

  os << "};\n";
}
|
|
|
|
static bool emitInterfaceDecls(const RecordKeeper &recordKeeper,
|
|
raw_ostream &os) {
|
|
llvm::emitSourceFileHeader("Operation Interface Declarations", os);
|
|
|
|
for (const auto *def : getAllOpInterfaceDefinitions(recordKeeper)) {
|
|
OpInterface interface(def);
|
|
emitInterfaceDecl(interface, os);
|
|
}
|
|
return false;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
// GEN: Interface documentation
//===----------------------------------------------------------------------===//
|
|
|
|
/// Emit a string corresponding to a C++ type, followed by a space if necessary.
|
|
static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) {
|
|
type = type.trim();
|
|
os << type;
|
|
if (type.back() != '&' && type.back() != '*')
|
|
os << " ";
|
|
return os;
|
|
}
|
|
|
|
static void emitInterfaceDoc(const Record &interfaceDef, raw_ostream &os) {
|
|
OpInterface interface(&interfaceDef);
|
|
|
|
// Emit the interface name followed by the description.
|
|
os << "## " << interface.getName() << " (" << interfaceDef.getName() << ")";
|
|
if (auto description = interface.getDescription())
|
|
mlir::tblgen::emitDescription(*description, os);
|
|
|
|
// Emit the methods required by the interface.
|
|
os << "\n### Methods:\n";
|
|
for (const auto &method : interface.getMethods()) {
|
|
// Emit the method name.
|
|
os << "#### `" << method.getName() << "`\n\n```c++\n";
|
|
|
|
// Emit the method signature.
|
|
if (method.isStatic())
|
|
os << "static ";
|
|
emitCPPType(method.getReturnType(), os) << method.getName() << '(';
|
|
interleaveComma(method.getArguments(), os,
|
|
[&](const OpInterfaceMethod::Argument &arg) {
|
|
emitCPPType(arg.type, os) << arg.name;
|
|
});
|
|
os << ");\n```\n";
|
|
|
|
// Emit the description.
|
|
if (auto description = method.getDescription())
|
|
mlir::tblgen::emitDescription(*description, os);
|
|
|
|
// If the body is not provided, this method must be provided by the
|
|
// operation.
|
|
if (!method.getBody())
|
|
os << "\nNOTE: This method *must* be implemented by the operation.\n\n";
|
|
}
|
|
}
|
|
|
|
static bool emitInterfaceDocs(const RecordKeeper &recordKeeper,
|
|
raw_ostream &os) {
|
|
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
|
|
os << "# Operation Interface definition\n";
|
|
|
|
for (const auto *def : getAllOpInterfaceDefinitions(recordKeeper))
|
|
emitInterfaceDoc(*def, os);
|
|
return false;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
// GEN: Interface registration hooks
//===----------------------------------------------------------------------===//
|
|
|
|
// Registers the operation interface generator to mlir-tblgen.
|
|
static mlir::GenRegistration
|
|
genInterfaceDecls("gen-op-interface-decls",
|
|
"Generate op interface declarations",
|
|
[](const RecordKeeper &records, raw_ostream &os) {
|
|
return emitInterfaceDecls(records, os);
|
|
});
|
|
|
|
// Registers the operation interface generator to mlir-tblgen.
|
|
static mlir::GenRegistration
|
|
genInterfaceDefs("gen-op-interface-defs",
|
|
"Generate op interface definitions",
|
|
[](const RecordKeeper &records, raw_ostream &os) {
|
|
return emitInterfaceDefs(records, os);
|
|
});
|
|
|
|
// Registers the operation interface document generator to mlir-tblgen.
|
|
static mlir::GenRegistration
|
|
genInterfaceDocs("gen-op-interface-doc",
|
|
"Generate op interface documentation",
|
|
[](const RecordKeeper &records, raw_ostream &os) {
|
|
return emitInterfaceDocs(records, os);
|
|
});
|