[mlir] Add an 'cppNamespace' field to availability

This allows us to generate interfaces in a namespace,
following other TableGen'erated code.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D108311
This commit is contained in:
Lei Zhang 2021-10-05 09:32:35 -04:00
parent c483140f3c
commit 83e074a0c6
5 changed files with 35 additions and 13 deletions

View File

@ -19,6 +19,8 @@ include "mlir/IR/OpBase.td"
class Availability {
// The following are fields for controlling the generated C++ OpInterface.
// The namespace for the generated C++ OpInterface subclass.
string cppNamespace = ?;
// The name for the generated C++ OpInterface subclass.
string interfaceName = ?;
// The documentation for the generated C++ OpInterface subclass.

View File

@ -125,6 +125,7 @@ def SPV_VersionAttr : SPV_I32EnumAttr<"Version", "valid SPIR-V version", [
class MinVersion<I32EnumAttrCase min> : MinVersionBase<
"QueryMinVersionInterface", SPV_VersionAttr, min> {
let cppNamespace = "::mlir::spirv";
let interfaceDescription = [{
Querying interface for minimal required SPIR-V version.
@ -136,6 +137,7 @@ class MinVersion<I32EnumAttrCase min> : MinVersionBase<
class MaxVersion<I32EnumAttrCase max> : MaxVersionBase<
"QueryMaxVersionInterface", SPV_VersionAttr, max> {
let cppNamespace = "::mlir::spirv";
let interfaceDescription = [{
Querying interface for maximal supported SPIR-V version.
@ -146,6 +148,7 @@ class MaxVersion<I32EnumAttrCase max> : MaxVersionBase<
}
class Extension<list<StrEnumAttrCase> extensions> : Availability {
let cppNamespace = "::mlir::spirv";
let interfaceName = "QueryExtensionInterface";
let interfaceDescription = [{
Querying interface for required SPIR-V extensions.
@ -189,6 +192,7 @@ class Extension<list<StrEnumAttrCase> extensions> : Availability {
}
class Capability<list<I32EnumAttrCase> capabilities> : Availability {
let cppNamespace = "::mlir::spirv";
let interfaceName = "QueryCapabilityInterface";
let interfaceDescription = [{
Querying interface for required SPIR-V capabilities.

View File

@ -22,15 +22,15 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/Support/PointerLikeTypeTraits.h"
// TableGen'erated operation interfaces for querying versions, extensions, and
// capabilities.
#include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.h.inc"
namespace mlir {
class OpBuilder;
namespace spirv {
class VerCapExtAttr;
// TableGen'erated operation interfaces for querying versions, extensions, and
// capabilities.
#include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.h.inc"
} // namespace spirv
} // namespace mlir

View File

@ -3915,14 +3915,9 @@ static LogicalResult verify(spirv::PtrAccessChainOp accessChainOp) {
return verifyAccessChain(accessChainOp, accessChainOp.indices());
}
namespace mlir {
namespace spirv {
// TableGen'erated operation interfaces for querying versions, extensions, and
// capabilities.
#include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc"
} // namespace spirv
} // namespace mlir
// TablenGen'erated operation definitions.
#define GET_OP_CLASSES
@ -3932,6 +3927,5 @@ namespace mlir {
namespace spirv {
// TableGen'erated operation availability interface implementations.
#include "mlir/Dialect/SPIRV/IR/SPIRVOpAvailabilityImpl.inc"
} // namespace spirv
} // namespace mlir

View File

@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
@ -45,6 +46,7 @@ using mlir::tblgen::EnumAttr;
using mlir::tblgen::EnumAttrCase;
using mlir::tblgen::NamedAttribute;
using mlir::tblgen::NamedTypeConstraint;
using mlir::tblgen::NamespaceEmitter;
using mlir::tblgen::Operator;
//===----------------------------------------------------------------------===//
@ -62,6 +64,9 @@ public:
// instance.
StringRef getClass() const;
// Returns the generated C++ interface's class namespace.
StringRef getInterfaceClassNamespace() const;
// Returns the generated C++ interface's class name.
StringRef getInterfaceClassName() const;
@ -91,6 +96,9 @@ public:
// Returns the concrete availability instance carried in this case.
StringRef getMergeInstance() const;
// Returns the underlying LLVM TableGen Record.
const llvm::Record *getDef() const { return def; }
private:
// The TableGen definition of this availability.
const llvm::Record *def;
@ -112,6 +120,10 @@ StringRef Availability::getClass() const {
return parentClass.front()->getName();
}
StringRef Availability::getInterfaceClassNamespace() const {
return def->getValueAsString("cppNamespace");
}
StringRef Availability::getInterfaceClassName() const {
return def->getValueAsString("interfaceName");
}
@ -168,9 +180,16 @@ std::vector<Availability> getAvailabilities(const Record &def) {
static void emitInterfaceDef(const Availability &availability,
raw_ostream &os) {
os << availability.getQueryFnRetType() << " ";
StringRef cppNamespace = availability.getInterfaceClassNamespace();
cppNamespace.consume_front("::");
if (!cppNamespace.empty())
os << cppNamespace << "::";
StringRef methodName = availability.getQueryFnName();
os << availability.getQueryFnRetType() << " "
<< availability.getInterfaceClassName() << "::" << methodName << "() {\n"
os << availability.getInterfaceClassName() << "::" << methodName << "() {\n"
<< " return getImpl()->" << methodName << "(getImpl(), getOperation());\n"
<< "}\n";
}
@ -237,13 +256,16 @@ static void emitInterfaceDecl(const Availability &availability,
std::string interfaceTraitsName =
std::string(formatv("{0}Traits", interfaceName));
StringRef cppNamespace = availability.getInterfaceClassNamespace();
NamespaceEmitter nsEmitter(os, cppNamespace);
// Emit the traits struct containing the concept and model declarations.
os << "namespace detail {\n"
<< "struct " << interfaceTraitsName << " {\n";
emitConceptDecl(availability, os);
os << '\n';
emitModelDecl(availability, os);
os << "};\n} // end namespace detail\n\n";
os << "};\n} // namespace detail\n\n";
// Emit the main interface class declaration.
os << "/*\n" << availability.getInterfaceDescription().trim() << "\n*/\n";