[mlir][spirv] Add basic definitions for supporting availability

SPIR-V has a few mechanisms to control op availability: version,
extension, and capabilities. These mechanisms are considered as
different availability classes.

This commit introduces basic definitions for modelling SPIR-V
availability classes. Specifically, an `Availability` class is
added to SPIRVBase.td, along with two subclasses: MinVersion
and MaxVersion for versioning. SPV_Op is extended to take a
list of `Availability`. Each `Availability` instance carries
information for generating op interfaces for the corresponding
availability class and also the concrete availability
requirements.

With the availability spec on ops, we can now auto-generate the
op interfaces of all SPIR-V availability classes and also
synthesize the op's implementations of these interfaces. The
interface generation is done via new TableGen backends
-gen-avail-interface-{decls|defs}. The op's implementation is
done via -gen-spirv-avail-impls.

Differential Revision: https://reviews.llvm.org/D71930
This commit is contained in:
Lei Zhang 2019-12-27 16:24:33 -05:00
parent c3dbd782f1
commit b30d87a90b
15 changed files with 728 additions and 23 deletions

View File

@ -1,8 +1,3 @@
set(LLVM_TARGET_DEFINITIONS SPIRVLowering.td)
mlir_tablegen(SPIRVLowering.h.inc -gen-struct-attr-decls)
mlir_tablegen(SPIRVLowering.cpp.inc -gen-struct-attr-defs)
add_public_tablegen_target(MLIRSPIRVLoweringStructGen)
add_mlir_dialect(SPIRVOps SPIRVOps)
set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
@ -10,6 +5,12 @@ mlir_tablegen(SPIRVEnums.h.inc -gen-enum-decls)
mlir_tablegen(SPIRVEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRSPIRVEnumsIncGen)
set(LLVM_TARGET_DEFINITIONS SPIRVOps.td)
mlir_tablegen(SPIRVAvailability.h.inc -gen-avail-interface-decls)
mlir_tablegen(SPIRVAvailability.cpp.inc -gen-avail-interface-defs)
mlir_tablegen(SPIRVOpAvailabilityImpl.inc -gen-spirv-avail-impls)
add_public_tablegen_target(MLIRSPIRVAvailabilityIncGen)
set(LLVM_TARGET_DEFINITIONS SPIRVOps.td)
mlir_tablegen(SPIRVSerialization.inc -gen-spirv-serialization)
add_public_tablegen_target(MLIRSPIRVSerializationGen)
@ -17,3 +18,8 @@ add_public_tablegen_target(MLIRSPIRVSerializationGen)
set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
mlir_tablegen(SPIRVOpUtils.inc -gen-spirv-op-utils)
add_public_tablegen_target(MLIRSPIRVOpUtilsGen)
set(LLVM_TARGET_DEFINITIONS SPIRVLowering.td)
mlir_tablegen(SPIRVLowering.h.inc -gen-struct-attr-decls)
mlir_tablegen(SPIRVLowering.cpp.inc -gen-struct-attr-defs)
add_public_tablegen_target(MLIRSPIRVLoweringStructGen)

View File

@ -120,6 +120,13 @@ def SPV_AtomicCompareExchangeWeakOp : SPV_Op<"AtomicCompareExchangeWeak", []> {
```
}];
let availability = [
MinVersion<SPV_V_1_0>,
MaxVersion<SPV_V_1_3>,
Extension<[]>,
Capability<[SPV_C_Kernel]>
];
let arguments = (ins
SPV_AnyPtr:$pointer,
SPV_ScopeAttr:$memory_scope,

View File

@ -0,0 +1,86 @@
//===- SPIRVAvailability.td - Op Availability Base file ----*- tablegen -*-===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef SPIRV_AVAILABILITY
#define SPIRV_AVAILABILITY
include "mlir/IR/OpBase.td"
//===----------------------------------------------------------------------===//
// Op availaility definitions
//===----------------------------------------------------------------------===//
// The base class for defining op availability dimensions.
class Availability {
// The following are fields for controlling the generated C++ OpInterface.
// The name for the generated C++ OpInterface subclass.
string interfaceName = ?;
// The documentation for the generated C++ OpInterface subclass.
string interfaceDescription = "";
// The following are fields for controlling the query function signature.
// The query function's return type in the generated C++ OpInterface subclass.
string queryFnRetType = ?;
// The query function's name in the generated C++ OpInterface subclass.
string queryFnName = ?;
// The following are fields for controlling the query function implementation.
// The logic for merging two availability requirements. This is used to derive
// the final availability requirement when, for example, an op has two
// operands and these two operands have different availability requirements.
//
// The code should use `$overall` as the placeholder for the final requirement
// and `$instance` for the current availability requirement instance.
code mergeAction = ?;
// The initializer for the final availability requirement.
string initializer = ?;
// An availability instance's type.
string instanceType = ?;
// The following are fields for a concrete availability instance.
// The availability requirement carried by a concrete instance.
string instance = ?;
}
class MinVersionBase<string name, I32EnumAttr scheme, I32EnumAttrCase min>
: Availability {
let interfaceName = name;
let queryFnRetType = scheme.returnType;
let queryFnName = "getMinVersion";
let mergeAction = "$overall = static_cast<" # scheme.returnType # ">("
"std::max($overall, $instance))";
let initializer = "static_cast<" # scheme.returnType # ">(uint32_t(0))";
let instanceType = scheme.cppNamespace # "::" # scheme.className;
let instance = scheme.cppNamespace # "::" # scheme.className # "::" #
min.symbol;
}
class MaxVersionBase<string name, I32EnumAttr scheme, I32EnumAttrCase max>
: Availability {
let interfaceName = name;
let queryFnRetType = scheme.returnType;
let queryFnName = "getMaxVersion";
let mergeAction = "$overall = static_cast<" # scheme.returnType # ">("
"std::min($overall, $instance))";
let initializer = "static_cast<" # scheme.returnType # ">(~uint32_t(0))";
let instanceType = scheme.cppNamespace # "::" # scheme.className;
let instance = scheme.cppNamespace # "::" # scheme.className # "::" #
max.symbol;
}
#endif // SPIRV_AVAILABILITY

View File

@ -16,6 +16,7 @@
#define SPIRV_BASE
include "mlir/IR/OpBase.td"
include "mlir/Dialect/SPIRV/SPIRVAvailability.td"
//===----------------------------------------------------------------------===//
// SPIR-V dialect definitions
@ -45,6 +46,142 @@ def SPV_Dialect : Dialect {
let cppNamespace = "spirv";
}
//===----------------------------------------------------------------------===//
// SPIR-V availability definitions
//===----------------------------------------------------------------------===//
def SPV_V_1_0 : I32EnumAttrCase<"V_1_0", 0>;
def SPV_V_1_1 : I32EnumAttrCase<"V_1_1", 1>;
def SPV_V_1_2 : I32EnumAttrCase<"V_1_2", 2>;
def SPV_V_1_3 : I32EnumAttrCase<"V_1_3", 3>;
def SPV_V_1_4 : I32EnumAttrCase<"V_1_4", 4>;
def SPV_V_1_5 : I32EnumAttrCase<"V_1_5", 5>;
def SPV_VersionAttr : I32EnumAttr<"Version", "valid SPIR-V version", [
SPV_V_1_0, SPV_V_1_1, SPV_V_1_2, SPV_V_1_3, SPV_V_1_4, SPV_V_1_5]> {
let cppNamespace = "::mlir::spirv";
}
class MinVersion<I32EnumAttrCase min> : MinVersionBase<
"QueryMinVersionInterface", SPV_VersionAttr, min> {
let interfaceDescription = [{
Querying interface for minimal required SPIR-V version.
This interface provides a `getMinVersion()` method to query the minimal
required version for the implementing SPIR-V operation. The returned value
is a `mlir::spirv::Version` enumerant.
}];
}
class MaxVersion<I32EnumAttrCase max> : MaxVersionBase<
"QueryMaxVersionInterface", SPV_VersionAttr, max> {
let interfaceDescription = [{
Querying interface for maximal supported SPIR-V version.
This interface provides a `getMaxVersion()` method to query the maximal
supported version for the implementing SPIR-V operation. The returned value
is a `mlir::spirv::Version` enumerant.
}];
}
class Extension<list<StrEnumAttrCase> extensions> : Availability {
let interfaceName = "QueryExtensionInterface";
let interfaceDescription = [{
Querying interface for required SPIR-V extensions.
This interface provides a `getExtensions()` method to query the required
extensions for the implementing SPIR-V operation. The returned value
is a nested vector whose element is `mlir::spirv::Extension`s. The outer
vector's elements (which are vectors) should be interpreted as conjunction
while the innner vector's elements (which are `mlir::spirv::Extension`s)
should be interpreted as disjunction. For example, given
```
{{Extension::A, Extension::B}, {Extension::C}, {{Extension::D, Extension::E}}
```
The operation instance is available when (`Extension::A` OR `Extension::B`)
AND (`Extension::C`) AND (`Extension::D` OR `Extension::E`) is enabled.
}];
// TODO(antiagainst): Using SmallVector<SmallVector<...>> is an anti-pattern.
// Find a better way for this.
let queryFnRetType = "::llvm::SmallVector<::llvm::SmallVector<"
"::mlir::spirv::Extension, 1>, 1>";
let queryFnName = "getExtensions";
let mergeAction = !if(
!empty(extensions), "", "$overall.emplace_back($instance)");
let initializer = "{}";
let instanceType = "::llvm::SmallVector<::mlir::spirv::Extension, 1>";
// Compose all capabilities as an C++ initializer list
let instance = "std::initializer_list<::mlir::spirv::Extension>{" #
StrJoin<!foreach(
ext, extensions,
"::mlir::spirv::Extension::" # ext.symbol)>.result #
"}";
}
class Capability<list<I32EnumAttrCase> capabilities> : Availability {
let interfaceName = "QueryCapabilityInterface";
let interfaceDescription = [{
Querying interface for required SPIR-V capabilities.
This interface provides a `getCapabilities()` method to query the required
capabilities for the implementing SPIR-V operation. The returned value
is a neted vector whose element is `mlir::spirv::Capability`s. The outer
vector's elements (which are vectors) should be interpreted as conjunction
while the innner vector's elements (which are `mlir::spirv::Capability`s)
should be interpreted as disjunction. For example, given
```
{{Capability::A, Capability::B}, {Capability::C}, {{Capability::D, Capability::E}}
```
The operation instance is available when (`Capability::A` OR `Capability::B`)
AND (`Capability::C`) AND (`Capability::D` OR `Capability::E`) is enabled.
}];
let queryFnRetType = "::llvm::SmallVector<::llvm::SmallVector<"
"::mlir::spirv::Capability, 1>, 1>";
let queryFnName = "getCapabilities";
let mergeAction = !if(
!empty(capabilities), "", "$overall.emplace_back($instance)");
let initializer = "{}";
let instanceType = "::llvm::SmallVector<::mlir::spirv::Capability, 1>";
// Compose all capabilities as an C++ initializer list
let instance = "std::initializer_list<::mlir::spirv::Capability>{" #
StrJoin<!foreach(
cap, capabilities,
"::mlir::spirv::Capability::" # cap.symbol)>.result #
"}";
}
// TODO(antiagainst): the following interfaces definitions are duplicating with
// the above. Remove them once we are able to support dialect-specific contents
// in ODS.
def QueryMinVersionInterface : OpInterface<"QueryMinVersionInterface"> {
let methods = [InterfaceMethod<"", "::mlir::spirv::Version", "getMinVersion">];
}
def QueryMaxVersionInterface : OpInterface<"QueryMaxVersionInterface"> {
let methods = [InterfaceMethod<"", "::mlir::spirv::Version", "getMaxVersion">];
}
def QueryExtensionInterface : OpInterface<"QueryExtensionInterface"> {
let methods = [InterfaceMethod<
"",
"::llvm::SmallVector<::llvm::SmallVector<::mlir::spirv::Extension, 1>, 1>",
"getExtensions">];
}
def QueryCapabilityInterface : OpInterface<"QueryCapabilityInterface"> {
let methods = [InterfaceMethod<
"",
"::llvm::SmallVector<::llvm::SmallVector<::mlir::spirv::Capability, 1>, 1>",
"getCapabilities">];
}
//===----------------------------------------------------------------------===//
// SPIR-V extension definitions
//===----------------------------------------------------------------------===//
@ -1216,7 +1353,22 @@ def SPV_OpcodeAttr :
// Base class for all SPIR-V ops.
class SPV_Op<string mnemonic, list<OpTrait> traits = []> :
Op<SPV_Dialect, mnemonic, traits> {
Op<SPV_Dialect, mnemonic, !listconcat(traits, [
// TODO(antiagainst): We don't need all of the following traits for
// every op; only the suitabble ones should be added automatically
// after ODS supports dialect-specific contents.
DeclareOpInterfaceMethods<QueryMinVersionInterface>,
DeclareOpInterfaceMethods<QueryMaxVersionInterface>,
DeclareOpInterfaceMethods<QueryExtensionInterface>,
DeclareOpInterfaceMethods<QueryCapabilityInterface>
])> {
// Availability specification for this op itself.
list<Availability> availability = [
MinVersion<SPV_V_1_0>,
MaxVersion<SPV_V_1_5>,
Extension<[]>,
Capability<[]>
];
// For each SPIR-V op, the following static functions need to be defined
// in SPVOps.cpp:

View File

@ -53,6 +53,13 @@ def SPV_GroupNonUniformBallotOp : SPV_Op<"GroupNonUniformBallot", []> {
```
}];
let availability = [
MinVersion<SPV_V_1_3>,
MaxVersion<SPV_V_1_5>,
Extension<[]>,
Capability<[SPV_C_GroupNonUniformBallot]>
];
let arguments = (ins
SPV_ScopeAttr:$execution_scope,
SPV_Bool:$predicate

View File

@ -21,18 +21,23 @@ class OpBuilder;
namespace spirv {
// TableGen'erated operation interfaces for querying versions, extensions, and
// capabilities.
#include "mlir/Dialect/SPIRV/SPIRVAvailability.h.inc"
// TablenGen'erated operation declarations.
#define GET_OP_CLASSES
#include "mlir/Dialect/SPIRV/SPIRVOps.h.inc"
/// Following methods are auto-generated.
///
/// Get the name used in the Op to refer to an enum value of the given
/// `EnumClass`.
/// template <typename EnumClass> StringRef attributeName();
///
/// Get the function that can be used to symbolize an enum value.
/// template <typename EnumClass>
/// Optional<EnumClass> (*)(StringRef) symbolizeEnum();
// TableGen'erated helper functions.
//
// Get the name used in the Op to refer to an enum value of the given
// `EnumClass`.
// template <typename EnumClass> StringRef attributeName();
//
// Get the function that can be used to symbolize an enum value.
// template <typename EnumClass>
// Optional<EnumClass> (*)(StringRef) symbolizeEnum();
#include "mlir/Dialect/SPIRV/SPIRVOpUtils.inc"
} // end namespace spirv

View File

@ -15,6 +15,7 @@ add_llvm_library(MLIRSPIRV
)
add_dependencies(MLIRSPIRV
MLIRSPIRVAvailabilityIncGen
MLIRSPIRVCanonicalizationIncGen
MLIRSPIRVEnumsIncGen
MLIRSPIRVLoweringStructGen

View File

@ -3063,8 +3063,16 @@ static LogicalResult verify(spirv::VariableOp varOp) {
namespace mlir {
namespace spirv {
// TableGen'erated operation interfaces for querying versions, extensions, and
// capabilities.
#include "mlir/Dialect/SPIRV/SPIRVAvailability.cpp.inc"
// TablenGen'erated operation definitions.
#define GET_OP_CLASSES
#include "mlir/Dialect/SPIRV/SPIRVOps.cpp.inc"
// TableGen'erated operation availability interface implementations.
#include "mlir/Dialect/SPIRV/SPIRVOpAvailabilityImpl.inc"
} // namespace spirv
} // namespace mlir

View File

@ -1,3 +1,4 @@
add_subdirectory(Dialect)
add_subdirectory(EDSC)
add_subdirectory(mlir-cpu-runner)
add_subdirectory(SDBM)

View File

@ -0,0 +1 @@
add_subdirectory(SPIRV)

View File

@ -0,0 +1,14 @@
add_llvm_library(MLIRSPIRVTestPasses
TestAvailability.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
)
target_link_libraries(MLIRSPIRVTestPasses
MLIRIR
MLIRPass
MLIRSPIRV
MLIRSupport
)

View File

@ -0,0 +1,73 @@
//===- TestAvailability.cpp - Pass to test SPIR-V op availability ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/IR/Function.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
namespace {
/// A pass for testing SPIR-V op availability.
struct TestAvailability : public FunctionPass<TestAvailability> {
void runOnFunction() override;
};
} // end anonymous namespace
void TestAvailability::runOnFunction() {
auto f = getFunction();
llvm::outs() << f.getName() << "\n";
Dialect *spvDialect = getContext().getRegisteredDialect("spv");
f.getOperation()->walk([&](Operation *op) {
if (op->getDialect() != spvDialect)
return WalkResult::advance();
auto &os = llvm::outs();
if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op))
os << " min version: "
<< spirv::stringifyVersion(minVersion.getMinVersion()) << "\n";
if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op))
os << " max version: "
<< spirv::stringifyVersion(maxVersion.getMaxVersion()) << "\n";
if (auto extension = dyn_cast<spirv::QueryExtensionInterface>(op)) {
os << " extensions: [";
for (const auto &exts : extension.getExtensions()) {
os << " [";
interleaveComma(exts, os, [&](spirv::Extension ext) {
os << spirv::stringifyExtension(ext);
});
os << "]";
}
os << " ]\n";
}
if (auto capability = dyn_cast<spirv::QueryCapabilityInterface>(op)) {
os << " capabilities: [";
for (const auto &caps : capability.getCapabilities()) {
os << " [";
interleaveComma(caps, os, [&](spirv::Capability cap) {
os << spirv::stringifyCapability(cap);
});
os << "]";
}
os << " ]\n";
}
os.flush();
return WalkResult::advance();
});
}
static PassRegistration<TestAvailability> pass("test-spirv-op-availability",
"Test SPIR-V op availability");

View File

@ -0,0 +1,31 @@
// RUN: mlir-opt -disable-pass-threading -test-spirv-op-availability %s | FileCheck %s
// CHECK-LABEL: iadd
func @iadd(%arg: i32) -> i32 {
// CHECK: min version: V_1_0
// CHECK: max version: V_1_5
// CHECK: extensions: [ ]
// CHECK: capabilities: [ ]
%0 = spv.IAdd %arg, %arg: i32
return %0: i32
}
// CHECK: atomic_compare_exchange_weak
func @atomic_compare_exchange_weak(%ptr: !spv.ptr<i32, Workgroup>, %value: i32, %comparator: i32) -> i32 {
// CHECK: min version: V_1_0
// CHECK: max version: V_1_3
// CHECK: extensions: [ ]
// CHECK: capabilities: [ [Kernel] ]
%0 = spv.AtomicCompareExchangeWeak "Workgroup" "Release" "Acquire" %ptr, %value, %comparator: !spv.ptr<i32, Workgroup>
return %0: i32
}
// CHECK-LABEL: subgroup_ballot
func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
// CHECK: min version: V_1_3
// CHECK: max version: V_1_5
// CHECK: extensions: [ ]
// CHECK: capabilities: [ [GroupNonUniformBallot] ]
%0 = spv.GroupNonUniformBallot "Workgroup" %predicate : vector<4xi32>
return %0: vector<4xi32>
}

View File

@ -41,6 +41,7 @@ set(LIBS
MLIRROCDLIR
MLIRSPIRV
MLIRStandardToSPIRVTransforms
MLIRSPIRVTestPasses
MLIRSPIRVTransforms
MLIRStandardOps
MLIRStandardToLLVM

View File

@ -13,6 +13,7 @@
#include "mlir/Support/StringExtras.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/Sequence.h"
@ -43,6 +44,233 @@ using mlir::tblgen::NamedAttribute;
using mlir::tblgen::NamedTypeConstraint;
using mlir::tblgen::Operator;
//===----------------------------------------------------------------------===//
// Availability Wrapper Class
//===----------------------------------------------------------------------===//
namespace {
// Wrapper class with helper methods for accessing availability defined in
// TableGen.
class Availability {
public:
explicit Availability(const Record *def);
// Returns the name of the direct TableGen class for this availability
// instance.
StringRef getClass() const;
// Returns the generated C++ interface's class name.
StringRef getInterfaceClassName() const;
// Returns the generated C++ interface's description.
StringRef getInterfaceDescription() const;
// Returns the name of the query function insided the generated C++ interface.
StringRef getQueryFnName() const;
// Returns the return type of the query function insided the generated C++
// interface.
StringRef getQueryFnRetType() const;
// Returns the code for merging availability requirements.
StringRef getMergeActionCode() const;
// Returns the initializer expression for initializing the final availability
// requirements.
StringRef getMergeInitializer() const;
// Returns the C++ type for an availability instance.
StringRef getMergeInstanceType() const;
// Returns the concrete availability instance carried in this case.
StringRef getMergeInstance() const;
private:
// The TableGen definition of this availability.
const llvm::Record *def;
};
} // namespace
Availability::Availability(const llvm::Record *def) : def(def) {
assert(def->isSubClassOf("Availability") &&
"must be subclass of TableGen 'Availability' class");
}
StringRef Availability::getClass() const {
SmallVector<Record *, 1> parentClass;
def->getDirectSuperClasses(parentClass);
if (parentClass.size() != 1) {
PrintFatalError(def->getLoc(),
"expected to only have one direct superclass");
}
return parentClass.front()->getName();
}
StringRef Availability::getInterfaceClassName() const {
return def->getValueAsString("interfaceName");
}
StringRef Availability::getInterfaceDescription() const {
return def->getValueAsString("interfaceDescription");
}
StringRef Availability::getQueryFnRetType() const {
return def->getValueAsString("queryFnRetType");
}
StringRef Availability::getQueryFnName() const {
return def->getValueAsString("queryFnName");
}
StringRef Availability::getMergeActionCode() const {
return def->getValueAsString("mergeAction");
}
StringRef Availability::getMergeInitializer() const {
return def->getValueAsString("initializer");
}
StringRef Availability::getMergeInstanceType() const {
return def->getValueAsString("instanceType");
}
StringRef Availability::getMergeInstance() const {
return def->getValueAsString("instance");
}
//===----------------------------------------------------------------------===//
// Availability Interface Definitions AutoGen
//===----------------------------------------------------------------------===//
static void emitInterfaceDef(const Availability &availability,
raw_ostream &os) {
StringRef methodName = availability.getQueryFnName();
os << availability.getQueryFnRetType() << " "
<< availability.getInterfaceClassName() << "::" << methodName << "() {\n"
<< " return getImpl()->" << methodName << "(getOperation());\n"
<< "}\n";
}
static bool emitInterfaceDefs(const RecordKeeper &recordKeeper,
raw_ostream &os) {
llvm::emitSourceFileHeader("Availability Interface Definitions", os);
auto defs = recordKeeper.getAllDerivedDefinitions("Availability");
SmallVector<const Record *, 1> handledClasses;
for (const Record *def : defs) {
SmallVector<Record *, 1> parent;
def->getDirectSuperClasses(parent);
if (parent.size() != 1) {
PrintFatalError(def->getLoc(),
"expected to only have one direct superclass");
}
if (llvm::is_contained(handledClasses, parent.front()))
continue;
Availability availability(def);
emitInterfaceDef(availability, os);
handledClasses.push_back(parent.front());
}
return false;
}
//===----------------------------------------------------------------------===//
// Availability Interface Declarations AutoGen
//===----------------------------------------------------------------------===//
static void emitConceptDecl(const Availability &availability, raw_ostream &os) {
os << " class Concept {\n"
<< " public:\n"
<< " virtual ~Concept() = default;\n"
<< " virtual " << availability.getQueryFnRetType() << " "
<< availability.getQueryFnName() << "(Operation *tblgen_opaque_op) = 0;\n"
<< " };\n";
}
static void emitModelDecl(const Availability &availability, raw_ostream &os) {
os << " template<typename ConcreteOp>\n";
os << " class Model : public Concept {\n"
<< " public:\n"
<< " " << availability.getQueryFnRetType() << " "
<< availability.getQueryFnName()
<< "(Operation *tblgen_opaque_op) final {\n"
<< " auto op = llvm::cast<ConcreteOp>(tblgen_opaque_op);\n"
<< " (void)op;\n"
// Forward to the method on the concrete operation type.
<< " return op." << availability.getQueryFnName() << "();\n"
<< " }\n"
<< " };\n";
}
static void emitInterfaceDecl(const Availability &availability,
raw_ostream &os) {
StringRef interfaceName = availability.getInterfaceClassName();
std::string interfaceTraitsName = formatv("{0}Traits", interfaceName);
// Emit the traits struct containing the concept and model declarations.
os << "namespace detail {\n"
<< "struct " << interfaceTraitsName << " {\n";
emitConceptDecl(availability, os);
os << '\n';
emitModelDecl(availability, os);
os << "};\n} // end namespace detail\n\n";
// Emit the main interface class declaration.
os << "/*\n" << availability.getInterfaceDescription().trim() << "\n*/\n";
os << llvm::formatv("class {0} : public OpInterface<{1}, detail::{2}> {\n"
"public:\n"
" using OpInterface<{1}, detail::{2}>::OpInterface;\n",
interfaceName, interfaceName, interfaceTraitsName);
// Emit query function declaration.
os << " " << availability.getQueryFnRetType() << " "
<< availability.getQueryFnName() << "();\n";
os << "};\n\n";
}
static bool emitInterfaceDecls(const RecordKeeper &recordKeeper,
raw_ostream &os) {
llvm::emitSourceFileHeader("Availability Interface Declarations", os);
auto defs = recordKeeper.getAllDerivedDefinitions("Availability");
SmallVector<const Record *, 4> handledClasses;
for (const Record *def : defs) {
SmallVector<Record *, 1> parent;
def->getDirectSuperClasses(parent);
if (parent.size() != 1) {
PrintFatalError(def->getLoc(),
"expected to only have one direct superclass");
}
if (llvm::is_contained(handledClasses, parent.front()))
continue;
Availability avail(def);
emitInterfaceDecl(avail, os);
handledClasses.push_back(parent.front());
}
return false;
}
//===----------------------------------------------------------------------===//
// Availability Interface Hook Registration
//===----------------------------------------------------------------------===//
// Registers the operation interface generator to mlir-tblgen.
static mlir::GenRegistration
genInterfaceDecls("gen-avail-interface-decls",
"Generate availability interface declarations",
[](const RecordKeeper &records, raw_ostream &os) {
return emitInterfaceDecls(records, os);
});
// Registers the operation interface generator to mlir-tblgen.
static mlir::GenRegistration
genInterfaceDefs("gen-avail-interface-defs",
"Generate op interface definitions",
[](const RecordKeeper &records, raw_ostream &os) {
return emitInterfaceDefs(records, os);
});
//===----------------------------------------------------------------------===//
// Serialization AutoGen
//===----------------------------------------------------------------------===//
@ -650,6 +878,17 @@ static bool emitSerializationFns(const RecordKeeper &recordKeeper,
return false;
}
//===----------------------------------------------------------------------===//
// Serialization Hook Registration
//===----------------------------------------------------------------------===//
static mlir::GenRegistration genSerialization(
"gen-spirv-serialization",
"Generate SPIR-V (de)serialization utilities and functions",
[](const RecordKeeper &records, raw_ostream &os) {
return emitSerializationFns(records, os);
});
//===----------------------------------------------------------------------===//
// Op Utils AutoGen
//===----------------------------------------------------------------------===//
@ -707,19 +946,92 @@ static bool emitOpUtils(const RecordKeeper &recordKeeper, raw_ostream &os) {
}
//===----------------------------------------------------------------------===//
// Hook Registration
// Op Utils Hook Registration
//===----------------------------------------------------------------------===//
static mlir::GenRegistration genSerialization(
"gen-spirv-serialization",
"Generate SPIR-V (de)serialization utilities and functions",
[](const RecordKeeper &records, raw_ostream &os) {
return emitSerializationFns(records, os);
});
static mlir::GenRegistration
genOpUtils("gen-spirv-op-utils",
"Generate SPIR-V operation utility definitions",
[](const RecordKeeper &records, raw_ostream &os) {
return emitOpUtils(records, os);
});
//===----------------------------------------------------------------------===//
// SPIR-V Availability Impl AutoGen
//===----------------------------------------------------------------------===//
// Returns the availability spec of the given `def`.
std::vector<Availability> getAvailabilities(const Record &def) {
std::vector<Availability> availabilities;
if (auto *availListInit = def.getValueAsListInit("availability")) {
availabilities.reserve(availListInit->size());
for (auto *availInit : *availListInit)
availabilities.emplace_back(
llvm::cast<llvm::DefInit>(availInit)->getDef());
}
return availabilities;
}
static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
mlir::tblgen::FmtContext fctx;
fctx.addSubst("overall", "overall");
std::vector<Availability> opAvailabilities =
getAvailabilities(srcOp.getDef());
// First collect all availablity classes this op should implement.
// All availablity instances keep information for the generated interface and
// the instance's specific requirement. Here we remember a random instance so
// we can get the information regarding the generated interface.
llvm::StringMap<Availability> availClasses;
for (const Availability &avail : opAvailabilities)
availClasses.try_emplace(avail.getClass(), avail);
// Then generate implementation for each availability class.
for (const auto &availClass : availClasses) {
StringRef availClassName = availClass.getKey();
Availability avail = availClass.getValue();
// Generate the implementation method signature.
os << formatv("{0} {1}::{2}() {{\n", avail.getQueryFnRetType(),
srcOp.getCppClassName(), avail.getQueryFnName());
// Create the variable for the final requirement and initialize it.
os << formatv(" {0} overall = {1};\n", avail.getQueryFnRetType(),
avail.getMergeInitializer());
// Update with the op's specific availability spec.
for (const Availability &avail : opAvailabilities)
if (avail.getClass() == availClassName) {
os << " "
<< tgfmt(avail.getMergeActionCode(),
&fctx.addSubst("instance", avail.getMergeInstance()))
<< ";\n";
}
os << " return overall;\n";
os << "}\n";
}
}
static bool emitAvailabilityImpl(const RecordKeeper &recordKeeper,
raw_ostream &os) {
llvm::emitSourceFileHeader("SPIR-V Op Availability Implementations", os);
auto defs = recordKeeper.getAllDerivedDefinitions("SPV_Op");
for (const auto *def : defs) {
Operator op(def);
emitAvailabilityImpl(op, os);
}
return false;
}
//===----------------------------------------------------------------------===//
// Op Availability Implementation Hook Registration
//===----------------------------------------------------------------------===//
static mlir::GenRegistration
genOpAvailabilityImpl("gen-spirv-avail-impls",
"Generate SPIR-V operation utility definitions",
[](const RecordKeeper &records, raw_ostream &os) {
return emitAvailabilityImpl(records, os);
});