forked from OSchip/llvm-project
250 lines
8.7 KiB
C++
250 lines
8.7 KiB
C++
//===- OpInterfacesGen.cpp - MLIR op interface utility generator ----------===//
|
|
//
|
|
// Copyright 2019 The MLIR Authors.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
// =============================================================================
|
|
//
|
|
// OpInterfacesGen generates definitions for operation interfaces.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Support/STLExtras.h"
|
|
#include "mlir/TableGen/GenInfo.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
#include "llvm/ADT/StringExtras.h"
|
|
#include "llvm/Support/FormatVariadic.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
#include "llvm/TableGen/Error.h"
|
|
#include "llvm/TableGen/Record.h"
|
|
#include "llvm/TableGen/TableGenBackend.h"
|
|
|
|
using namespace llvm;
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
// This struct represents a single method argument.
|
|
struct MethodArgument {
|
|
StringRef type, name;
|
|
};
|
|
|
|
// Wrapper class around a single interface method.
|
|
class OpInterfaceMethod {
|
|
public:
|
|
explicit OpInterfaceMethod(const llvm::Record *def) : def(def) {
|
|
llvm::DagInit *args = def->getValueAsDag("arguments");
|
|
for (unsigned i = 0, e = args->getNumArgs(); i != e; ++i) {
|
|
arguments.push_back(
|
|
{llvm::cast<llvm::StringInit>(args->getArg(i))->getValue(),
|
|
args->getArgNameStr(i)});
|
|
}
|
|
}
|
|
|
|
// Return the return type of this method.
|
|
StringRef getReturnType() const {
|
|
return def->getValueAsString("returnType");
|
|
}
|
|
|
|
// Return the name of this method.
|
|
StringRef getName() const { return def->getValueAsString("name"); }
|
|
|
|
// Return if this method is static.
|
|
bool isStatic() const { return def->isSubClassOf("StaticInterfaceMethod"); }
|
|
|
|
// Return the body for this method if it has one.
|
|
llvm::Optional<StringRef> getBody() const {
|
|
auto value = def->getValueAsString("body");
|
|
return value.empty() ? llvm::Optional<StringRef>() : value;
|
|
}
|
|
|
|
// Arguments.
|
|
ArrayRef<MethodArgument> getArguments() const { return arguments; }
|
|
bool arg_empty() const { return arguments.empty(); }
|
|
|
|
protected:
|
|
// The TableGen definition of this method.
|
|
const llvm::Record *def;
|
|
|
|
// The arguments of this method.
|
|
SmallVector<MethodArgument, 2> arguments;
|
|
};
|
|
|
|
// Wrapper class with helper methods for accessing OpInterfaces defined in
|
|
// TableGen.
|
|
class OpInterface {
|
|
public:
|
|
explicit OpInterface(const llvm::Record *def) : def(def) {
|
|
auto *listInit = dyn_cast<llvm::ListInit>(def->getValueInit("methods"));
|
|
for (llvm::Init *init : listInit->getValues())
|
|
methods.emplace_back(cast<llvm::DefInit>(init)->getDef());
|
|
}
|
|
|
|
// Return the name of this interface.
|
|
StringRef getName() const { return def->getValueAsString("cppClassName"); }
|
|
|
|
// Return the methods of this interface.
|
|
ArrayRef<OpInterfaceMethod> getMethods() const { return methods; }
|
|
|
|
protected:
|
|
// The TableGen definition of this interface.
|
|
const llvm::Record *def;
|
|
|
|
// The methods of this interface.
|
|
SmallVector<OpInterfaceMethod, 8> methods;
|
|
};
|
|
} // end anonymous namespace
|
|
|
|
// Emit the method name and argument list for the given method. If
|
|
// 'addOperationArg' is true, then an Operation* argument is added to the
|
|
// beginning of the argument list.
|
|
static void emitMethodNameAndArgs(const OpInterfaceMethod &method,
|
|
raw_ostream &os, bool addOperationArg) {
|
|
os << method.getName() << '(';
|
|
if (addOperationArg)
|
|
os << "Operation *tablegen_opaque_op" << (method.arg_empty() ? "" : ", ");
|
|
interleaveComma(method.getArguments(), os, [&](const MethodArgument &arg) {
|
|
os << arg.type << " " << arg.name;
|
|
});
|
|
os << ')';
|
|
}
|
|
|
|
static void emitInterfaceDef(const Record &interfaceDef, raw_ostream &os) {
|
|
OpInterface interface(&interfaceDef);
|
|
StringRef interfaceName = interface.getName();
|
|
|
|
// Insert the method definitions.
|
|
for (auto &method : interface.getMethods()) {
|
|
os << method.getReturnType() << " " << interfaceName << "::";
|
|
emitMethodNameAndArgs(method, os, /*addOperationArg=*/false);
|
|
|
|
// Forward to the method on the concrete operation type.
|
|
os << " {\n return getImpl()->" << method.getName() << '(';
|
|
if (!method.isStatic())
|
|
os << "getOperation()" << (method.arg_empty() ? "" : ", ");
|
|
interleaveComma(method.getArguments(), os,
|
|
[&](const MethodArgument &arg) { os << arg.name; });
|
|
os << ");\n }\n";
|
|
}
|
|
}
|
|
|
|
static bool emitInterfaceDefs(const RecordKeeper &recordKeeper,
|
|
raw_ostream &os) {
|
|
llvm::emitSourceFileHeader("Operation Interface Definitions", os);
|
|
|
|
auto defs = recordKeeper.getAllDerivedDefinitions("OpInterface");
|
|
for (const auto *def : defs)
|
|
emitInterfaceDef(*def, os);
|
|
return false;
|
|
}
|
|
|
|
// Emits the `Concept` class for the interface: a class with a defaulted
// virtual destructor and one pure virtual method per interface method.
// Non-static methods receive a leading `Operation *` parameter.
static void emitConceptDecl(OpInterface &interface, raw_ostream &os) {
  os << " class Concept {\n";
  os << " public:\n";
  os << " virtual ~Concept() = default;\n";

  // One pure virtual declaration per interface method.
  for (const OpInterfaceMethod &method : interface.getMethods()) {
    const bool needsOperationArg = !method.isStatic();
    os << " virtual " << method.getReturnType() << " ";
    emitMethodNameAndArgs(method, os, /*addOperationArg=*/needsOperationArg);
    os << " = 0;\n";
  }

  os << " };\n";
}
|
|
|
|
// Emits the `Model<ConcreteOp>` class template for the interface. It derives
// from `Concept` and overrides each method with either the body supplied in
// TableGen or a call forwarded to the concrete operation type.
static void emitModelDecl(OpInterface &interface, raw_ostream &os) {
  os << " template<typename ConcreteOp>\n"
     << " class Model : public Concept {\npublic:\n";

  for (const OpInterfaceMethod &method : interface.getMethods()) {
    const bool isStatic = method.isStatic();

    // Signature of the override.
    os << " " << method.getReturnType() << " ";
    emitMethodNameAndArgs(method, os, /*addOperationArg=*/!isStatic);
    os << " final {\n";

    // Non-static methods get access to the concrete op through `op`; the
    // `(void)op` keeps the generated code warning-free if `op` goes unused.
    if (!isStatic) {
      os << " auto op = llvm::cast<ConcreteOp>(tablegen_opaque_op);\n";
      os << " (void)op;\n";
    }

    if (auto body = method.getBody()) {
      // A body was provided in TableGen: emit it verbatim.
      os << *body << "\n }\n";
    } else {
      // No body: forward to the same-named method on the concrete op type.
      os << " return ";
      if (isStatic)
        os << "ConcreteOp::";
      else
        os << "op.";
      os << method.getName() << '(';
      auto printArgName = [&](const MethodArgument &argument) {
        os << argument.name;
      };
      interleaveComma(method.getArguments(), os, printArgName);
      os << ");\n }\n";
    }
  }

  os << " };\n";
}
|
|
|
|
static void emitInterfaceDecl(const Record &interfaceDef, raw_ostream &os) {
|
|
OpInterface interface(&interfaceDef);
|
|
StringRef interfaceName = interface.getName();
|
|
auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str();
|
|
|
|
// Emit the traits struct containing the concept and model declarations.
|
|
os << "namespace detail {\n"
|
|
<< "struct " << interfaceTraitsName << " {\n";
|
|
emitConceptDecl(interface, os);
|
|
emitModelDecl(interface, os);
|
|
os << "};\n} // end namespace detail\n";
|
|
|
|
// Emit the main interface class declaration.
|
|
os << llvm::formatv("class {0} : public OpInterface<{1}, detail::{2}> {\n"
|
|
"public:\n"
|
|
" using OpInterface<{1}, detail::{2}>::OpInterface;\n",
|
|
interfaceName, interfaceName, interfaceTraitsName);
|
|
|
|
// Insert the method declarations.
|
|
for (auto &method : interface.getMethods()) {
|
|
os << " " << method.getReturnType() << " ";
|
|
emitMethodNameAndArgs(method, os, /*addOperationArg=*/false);
|
|
os << ";\n";
|
|
}
|
|
os << "};\n";
|
|
}
|
|
|
|
static bool emitInterfaceDecls(const RecordKeeper &recordKeeper,
|
|
raw_ostream &os) {
|
|
llvm::emitSourceFileHeader("Operation Interface Declarations", os);
|
|
|
|
auto defs = recordKeeper.getAllDerivedDefinitions("OpInterface");
|
|
for (const auto *def : defs)
|
|
emitInterfaceDecl(*def, os);
|
|
return false;
|
|
}
|
|
|
|
// Registers the operation interface generator to mlir-tblgen.
|
|
static mlir::GenRegistration
|
|
genInterfaceDecls("gen-op-interface-decls",
|
|
"Generate op interface declarations",
|
|
[](const RecordKeeper &records, raw_ostream &os) {
|
|
return emitInterfaceDecls(records, os);
|
|
});
|
|
|
|
// Registers the operation interface generator to mlir-tblgen.
|
|
static mlir::GenRegistration
|
|
genInterfaceDefs("gen-op-interface-defs",
|
|
"Generate op interface definitions",
|
|
[](const RecordKeeper &records, raw_ostream &os) {
|
|
return emitInterfaceDefs(records, os);
|
|
});
|