forked from OSchip/llvm-project
[mlir][spirv] Add basic definitions for supporting availability
SPIR-V has a few mechanisms to control op availability: version, extension, and capabilities. These mechanisms are considered as different availability classes. This commit introduces basic definitions for modelling SPIR-V availability classes. Specifically, an `Availability` class is added to SPIRVBase.td, along with two subclasses: MinVersion and MaxVersion for versioning. SPV_Op is extended to take a list of `Availability`. Each `Availability` instance carries information for generating op interfaces for the corresponding availability class and also the concrete availability requirements. With the availability spec on ops, we can now auto-generate the op interfaces of all SPIR-V availability classes and also synthesize the op's implementations of these interfaces. The interface generation is done via new TableGen backends -gen-avail-interface-{decls|defs}. The op's implementation is done via -gen-spirv-avail-impls. Differential Revision: https://reviews.llvm.org/D71930
This commit is contained in:
parent
c3dbd782f1
commit
b30d87a90b
|
@ -1,8 +1,3 @@
|
|||
set(LLVM_TARGET_DEFINITIONS SPIRVLowering.td)
|
||||
mlir_tablegen(SPIRVLowering.h.inc -gen-struct-attr-decls)
|
||||
mlir_tablegen(SPIRVLowering.cpp.inc -gen-struct-attr-defs)
|
||||
add_public_tablegen_target(MLIRSPIRVLoweringStructGen)
|
||||
|
||||
add_mlir_dialect(SPIRVOps SPIRVOps)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
|
||||
|
@ -10,6 +5,12 @@ mlir_tablegen(SPIRVEnums.h.inc -gen-enum-decls)
|
|||
mlir_tablegen(SPIRVEnums.cpp.inc -gen-enum-defs)
|
||||
add_public_tablegen_target(MLIRSPIRVEnumsIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS SPIRVOps.td)
|
||||
mlir_tablegen(SPIRVAvailability.h.inc -gen-avail-interface-decls)
|
||||
mlir_tablegen(SPIRVAvailability.cpp.inc -gen-avail-interface-defs)
|
||||
mlir_tablegen(SPIRVOpAvailabilityImpl.inc -gen-spirv-avail-impls)
|
||||
add_public_tablegen_target(MLIRSPIRVAvailabilityIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS SPIRVOps.td)
|
||||
mlir_tablegen(SPIRVSerialization.inc -gen-spirv-serialization)
|
||||
add_public_tablegen_target(MLIRSPIRVSerializationGen)
|
||||
|
@ -17,3 +18,8 @@ add_public_tablegen_target(MLIRSPIRVSerializationGen)
|
|||
set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
|
||||
mlir_tablegen(SPIRVOpUtils.inc -gen-spirv-op-utils)
|
||||
add_public_tablegen_target(MLIRSPIRVOpUtilsGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS SPIRVLowering.td)
|
||||
mlir_tablegen(SPIRVLowering.h.inc -gen-struct-attr-decls)
|
||||
mlir_tablegen(SPIRVLowering.cpp.inc -gen-struct-attr-defs)
|
||||
add_public_tablegen_target(MLIRSPIRVLoweringStructGen)
|
||||
|
|
|
@ -120,6 +120,13 @@ def SPV_AtomicCompareExchangeWeakOp : SPV_Op<"AtomicCompareExchangeWeak", []> {
|
|||
```
|
||||
}];
|
||||
|
||||
let availability = [
|
||||
MinVersion<SPV_V_1_0>,
|
||||
MaxVersion<SPV_V_1_3>,
|
||||
Extension<[]>,
|
||||
Capability<[SPV_C_Kernel]>
|
||||
];
|
||||
|
||||
let arguments = (ins
|
||||
SPV_AnyPtr:$pointer,
|
||||
SPV_ScopeAttr:$memory_scope,
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
//===- SPIRVAvailability.td - Op Availability Base file ----*- tablegen -*-===//
|
||||
//
|
||||
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef SPIRV_AVAILABILITY
|
||||
#define SPIRV_AVAILABILITY
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Op availaility definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// The base class for defining op availability dimensions.
|
||||
class Availability {
|
||||
// The following are fields for controlling the generated C++ OpInterface.
|
||||
|
||||
// The name for the generated C++ OpInterface subclass.
|
||||
string interfaceName = ?;
|
||||
// The documentation for the generated C++ OpInterface subclass.
|
||||
string interfaceDescription = "";
|
||||
|
||||
// The following are fields for controlling the query function signature.
|
||||
|
||||
// The query function's return type in the generated C++ OpInterface subclass.
|
||||
string queryFnRetType = ?;
|
||||
// The query function's name in the generated C++ OpInterface subclass.
|
||||
string queryFnName = ?;
|
||||
|
||||
// The following are fields for controlling the query function implementation.
|
||||
|
||||
// The logic for merging two availability requirements. This is used to derive
|
||||
// the final availability requirement when, for example, an op has two
|
||||
// operands and these two operands have different availability requirements.
|
||||
//
|
||||
// The code should use `$overall` as the placeholder for the final requirement
|
||||
// and `$instance` for the current availability requirement instance.
|
||||
code mergeAction = ?;
|
||||
// The initializer for the final availability requirement.
|
||||
string initializer = ?;
|
||||
// An availability instance's type.
|
||||
string instanceType = ?;
|
||||
|
||||
// The following are fields for a concrete availability instance.
|
||||
|
||||
// The availability requirement carried by a concrete instance.
|
||||
string instance = ?;
|
||||
}
|
||||
|
||||
class MinVersionBase<string name, I32EnumAttr scheme, I32EnumAttrCase min>
|
||||
: Availability {
|
||||
let interfaceName = name;
|
||||
|
||||
let queryFnRetType = scheme.returnType;
|
||||
let queryFnName = "getMinVersion";
|
||||
|
||||
let mergeAction = "$overall = static_cast<" # scheme.returnType # ">("
|
||||
"std::max($overall, $instance))";
|
||||
let initializer = "static_cast<" # scheme.returnType # ">(uint32_t(0))";
|
||||
let instanceType = scheme.cppNamespace # "::" # scheme.className;
|
||||
|
||||
let instance = scheme.cppNamespace # "::" # scheme.className # "::" #
|
||||
min.symbol;
|
||||
}
|
||||
|
||||
class MaxVersionBase<string name, I32EnumAttr scheme, I32EnumAttrCase max>
|
||||
: Availability {
|
||||
let interfaceName = name;
|
||||
|
||||
let queryFnRetType = scheme.returnType;
|
||||
let queryFnName = "getMaxVersion";
|
||||
|
||||
let mergeAction = "$overall = static_cast<" # scheme.returnType # ">("
|
||||
"std::min($overall, $instance))";
|
||||
let initializer = "static_cast<" # scheme.returnType # ">(~uint32_t(0))";
|
||||
let instanceType = scheme.cppNamespace # "::" # scheme.className;
|
||||
|
||||
let instance = scheme.cppNamespace # "::" # scheme.className # "::" #
|
||||
max.symbol;
|
||||
}
|
||||
|
||||
#endif // SPIRV_AVAILABILITY
|
|
@ -16,6 +16,7 @@
|
|||
#define SPIRV_BASE
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/SPIRV/SPIRVAvailability.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SPIR-V dialect definitions
|
||||
|
@ -45,6 +46,142 @@ def SPV_Dialect : Dialect {
|
|||
let cppNamespace = "spirv";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SPIR-V availability definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SPV_V_1_0 : I32EnumAttrCase<"V_1_0", 0>;
|
||||
def SPV_V_1_1 : I32EnumAttrCase<"V_1_1", 1>;
|
||||
def SPV_V_1_2 : I32EnumAttrCase<"V_1_2", 2>;
|
||||
def SPV_V_1_3 : I32EnumAttrCase<"V_1_3", 3>;
|
||||
def SPV_V_1_4 : I32EnumAttrCase<"V_1_4", 4>;
|
||||
def SPV_V_1_5 : I32EnumAttrCase<"V_1_5", 5>;
|
||||
|
||||
def SPV_VersionAttr : I32EnumAttr<"Version", "valid SPIR-V version", [
|
||||
SPV_V_1_0, SPV_V_1_1, SPV_V_1_2, SPV_V_1_3, SPV_V_1_4, SPV_V_1_5]> {
|
||||
let cppNamespace = "::mlir::spirv";
|
||||
}
|
||||
|
||||
class MinVersion<I32EnumAttrCase min> : MinVersionBase<
|
||||
"QueryMinVersionInterface", SPV_VersionAttr, min> {
|
||||
let interfaceDescription = [{
|
||||
Querying interface for minimal required SPIR-V version.
|
||||
|
||||
This interface provides a `getMinVersion()` method to query the minimal
|
||||
required version for the implementing SPIR-V operation. The returned value
|
||||
is a `mlir::spirv::Version` enumerant.
|
||||
}];
|
||||
}
|
||||
|
||||
class MaxVersion<I32EnumAttrCase max> : MaxVersionBase<
|
||||
"QueryMaxVersionInterface", SPV_VersionAttr, max> {
|
||||
let interfaceDescription = [{
|
||||
Querying interface for maximal supported SPIR-V version.
|
||||
|
||||
This interface provides a `getMaxVersion()` method to query the maximal
|
||||
supported version for the implementing SPIR-V operation. The returned value
|
||||
is a `mlir::spirv::Version` enumerant.
|
||||
}];
|
||||
}
|
||||
|
||||
class Extension<list<StrEnumAttrCase> extensions> : Availability {
|
||||
let interfaceName = "QueryExtensionInterface";
|
||||
let interfaceDescription = [{
|
||||
Querying interface for required SPIR-V extensions.
|
||||
|
||||
This interface provides a `getExtensions()` method to query the required
|
||||
extensions for the implementing SPIR-V operation. The returned value
|
||||
is a nested vector whose element is `mlir::spirv::Extension`s. The outer
|
||||
vector's elements (which are vectors) should be interpreted as conjunction
|
||||
while the innner vector's elements (which are `mlir::spirv::Extension`s)
|
||||
should be interpreted as disjunction. For example, given
|
||||
|
||||
```
|
||||
{{Extension::A, Extension::B}, {Extension::C}, {{Extension::D, Extension::E}}
|
||||
```
|
||||
|
||||
The operation instance is available when (`Extension::A` OR `Extension::B`)
|
||||
AND (`Extension::C`) AND (`Extension::D` OR `Extension::E`) is enabled.
|
||||
}];
|
||||
|
||||
// TODO(antiagainst): Using SmallVector<SmallVector<...>> is an anti-pattern.
|
||||
// Find a better way for this.
|
||||
let queryFnRetType = "::llvm::SmallVector<::llvm::SmallVector<"
|
||||
"::mlir::spirv::Extension, 1>, 1>";
|
||||
let queryFnName = "getExtensions";
|
||||
|
||||
let mergeAction = !if(
|
||||
!empty(extensions), "", "$overall.emplace_back($instance)");
|
||||
let initializer = "{}";
|
||||
let instanceType = "::llvm::SmallVector<::mlir::spirv::Extension, 1>";
|
||||
|
||||
// Compose all capabilities as an C++ initializer list
|
||||
let instance = "std::initializer_list<::mlir::spirv::Extension>{" #
|
||||
StrJoin<!foreach(
|
||||
ext, extensions,
|
||||
"::mlir::spirv::Extension::" # ext.symbol)>.result #
|
||||
"}";
|
||||
}
|
||||
|
||||
class Capability<list<I32EnumAttrCase> capabilities> : Availability {
|
||||
let interfaceName = "QueryCapabilityInterface";
|
||||
let interfaceDescription = [{
|
||||
Querying interface for required SPIR-V capabilities.
|
||||
|
||||
This interface provides a `getCapabilities()` method to query the required
|
||||
capabilities for the implementing SPIR-V operation. The returned value
|
||||
is a neted vector whose element is `mlir::spirv::Capability`s. The outer
|
||||
vector's elements (which are vectors) should be interpreted as conjunction
|
||||
while the innner vector's elements (which are `mlir::spirv::Capability`s)
|
||||
should be interpreted as disjunction. For example, given
|
||||
|
||||
```
|
||||
{{Capability::A, Capability::B}, {Capability::C}, {{Capability::D, Capability::E}}
|
||||
```
|
||||
|
||||
The operation instance is available when (`Capability::A` OR `Capability::B`)
|
||||
AND (`Capability::C`) AND (`Capability::D` OR `Capability::E`) is enabled.
|
||||
}];
|
||||
|
||||
let queryFnRetType = "::llvm::SmallVector<::llvm::SmallVector<"
|
||||
"::mlir::spirv::Capability, 1>, 1>";
|
||||
let queryFnName = "getCapabilities";
|
||||
|
||||
let mergeAction = !if(
|
||||
!empty(capabilities), "", "$overall.emplace_back($instance)");
|
||||
let initializer = "{}";
|
||||
let instanceType = "::llvm::SmallVector<::mlir::spirv::Capability, 1>";
|
||||
|
||||
// Compose all capabilities as an C++ initializer list
|
||||
let instance = "std::initializer_list<::mlir::spirv::Capability>{" #
|
||||
StrJoin<!foreach(
|
||||
cap, capabilities,
|
||||
"::mlir::spirv::Capability::" # cap.symbol)>.result #
|
||||
"}";
|
||||
}
|
||||
|
||||
// TODO(antiagainst): the following interfaces definitions are duplicating with
|
||||
// the above. Remove them once we are able to support dialect-specific contents
|
||||
// in ODS.
|
||||
def QueryMinVersionInterface : OpInterface<"QueryMinVersionInterface"> {
|
||||
let methods = [InterfaceMethod<"", "::mlir::spirv::Version", "getMinVersion">];
|
||||
}
|
||||
def QueryMaxVersionInterface : OpInterface<"QueryMaxVersionInterface"> {
|
||||
let methods = [InterfaceMethod<"", "::mlir::spirv::Version", "getMaxVersion">];
|
||||
}
|
||||
def QueryExtensionInterface : OpInterface<"QueryExtensionInterface"> {
|
||||
let methods = [InterfaceMethod<
|
||||
"",
|
||||
"::llvm::SmallVector<::llvm::SmallVector<::mlir::spirv::Extension, 1>, 1>",
|
||||
"getExtensions">];
|
||||
}
|
||||
def QueryCapabilityInterface : OpInterface<"QueryCapabilityInterface"> {
|
||||
let methods = [InterfaceMethod<
|
||||
"",
|
||||
"::llvm::SmallVector<::llvm::SmallVector<::mlir::spirv::Capability, 1>, 1>",
|
||||
"getCapabilities">];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SPIR-V extension definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1216,7 +1353,22 @@ def SPV_OpcodeAttr :
|
|||
|
||||
// Base class for all SPIR-V ops.
|
||||
class SPV_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<SPV_Dialect, mnemonic, traits> {
|
||||
Op<SPV_Dialect, mnemonic, !listconcat(traits, [
|
||||
// TODO(antiagainst): We don't need all of the following traits for
|
||||
// every op; only the suitabble ones should be added automatically
|
||||
// after ODS supports dialect-specific contents.
|
||||
DeclareOpInterfaceMethods<QueryMinVersionInterface>,
|
||||
DeclareOpInterfaceMethods<QueryMaxVersionInterface>,
|
||||
DeclareOpInterfaceMethods<QueryExtensionInterface>,
|
||||
DeclareOpInterfaceMethods<QueryCapabilityInterface>
|
||||
])> {
|
||||
// Availability specification for this op itself.
|
||||
list<Availability> availability = [
|
||||
MinVersion<SPV_V_1_0>,
|
||||
MaxVersion<SPV_V_1_5>,
|
||||
Extension<[]>,
|
||||
Capability<[]>
|
||||
];
|
||||
|
||||
// For each SPIR-V op, the following static functions need to be defined
|
||||
// in SPVOps.cpp:
|
||||
|
|
|
@ -53,6 +53,13 @@ def SPV_GroupNonUniformBallotOp : SPV_Op<"GroupNonUniformBallot", []> {
|
|||
```
|
||||
}];
|
||||
|
||||
let availability = [
|
||||
MinVersion<SPV_V_1_3>,
|
||||
MaxVersion<SPV_V_1_5>,
|
||||
Extension<[]>,
|
||||
Capability<[SPV_C_GroupNonUniformBallot]>
|
||||
];
|
||||
|
||||
let arguments = (ins
|
||||
SPV_ScopeAttr:$execution_scope,
|
||||
SPV_Bool:$predicate
|
||||
|
|
|
@ -21,18 +21,23 @@ class OpBuilder;
|
|||
|
||||
namespace spirv {
|
||||
|
||||
// TableGen'erated operation interfaces for querying versions, extensions, and
|
||||
// capabilities.
|
||||
#include "mlir/Dialect/SPIRV/SPIRVAvailability.h.inc"
|
||||
|
||||
// TablenGen'erated operation declarations.
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/SPIRV/SPIRVOps.h.inc"
|
||||
|
||||
/// Following methods are auto-generated.
|
||||
///
|
||||
/// Get the name used in the Op to refer to an enum value of the given
|
||||
/// `EnumClass`.
|
||||
/// template <typename EnumClass> StringRef attributeName();
|
||||
///
|
||||
/// Get the function that can be used to symbolize an enum value.
|
||||
/// template <typename EnumClass>
|
||||
/// Optional<EnumClass> (*)(StringRef) symbolizeEnum();
|
||||
// TableGen'erated helper functions.
|
||||
//
|
||||
// Get the name used in the Op to refer to an enum value of the given
|
||||
// `EnumClass`.
|
||||
// template <typename EnumClass> StringRef attributeName();
|
||||
//
|
||||
// Get the function that can be used to symbolize an enum value.
|
||||
// template <typename EnumClass>
|
||||
// Optional<EnumClass> (*)(StringRef) symbolizeEnum();
|
||||
#include "mlir/Dialect/SPIRV/SPIRVOpUtils.inc"
|
||||
|
||||
} // end namespace spirv
|
||||
|
|
|
@ -15,6 +15,7 @@ add_llvm_library(MLIRSPIRV
|
|||
)
|
||||
|
||||
add_dependencies(MLIRSPIRV
|
||||
MLIRSPIRVAvailabilityIncGen
|
||||
MLIRSPIRVCanonicalizationIncGen
|
||||
MLIRSPIRVEnumsIncGen
|
||||
MLIRSPIRVLoweringStructGen
|
||||
|
|
|
@ -3063,8 +3063,16 @@ static LogicalResult verify(spirv::VariableOp varOp) {
|
|||
namespace mlir {
|
||||
namespace spirv {
|
||||
|
||||
// TableGen'erated operation interfaces for querying versions, extensions, and
|
||||
// capabilities.
|
||||
#include "mlir/Dialect/SPIRV/SPIRVAvailability.cpp.inc"
|
||||
|
||||
// TablenGen'erated operation definitions.
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/SPIRV/SPIRVOps.cpp.inc"
|
||||
|
||||
// TableGen'erated operation availability interface implementations.
|
||||
#include "mlir/Dialect/SPIRV/SPIRVOpAvailabilityImpl.inc"
|
||||
|
||||
} // namespace spirv
|
||||
} // namespace mlir
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
add_subdirectory(Dialect)
|
||||
add_subdirectory(EDSC)
|
||||
add_subdirectory(mlir-cpu-runner)
|
||||
add_subdirectory(SDBM)
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
add_subdirectory(SPIRV)
|
|
@ -0,0 +1,14 @@
|
|||
add_llvm_library(MLIRSPIRVTestPasses
|
||||
TestAvailability.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
|
||||
)
|
||||
|
||||
target_link_libraries(MLIRSPIRVTestPasses
|
||||
MLIRIR
|
||||
MLIRPass
|
||||
MLIRSPIRV
|
||||
MLIRSupport
|
||||
)
|
|
@ -0,0 +1,73 @@
|
|||
//===- TestAvailability.cpp - Pass to test SPIR-V op availability ---------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
/// A pass for testing SPIR-V op availability.
|
||||
struct TestAvailability : public FunctionPass<TestAvailability> {
|
||||
void runOnFunction() override;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
void TestAvailability::runOnFunction() {
|
||||
auto f = getFunction();
|
||||
llvm::outs() << f.getName() << "\n";
|
||||
|
||||
Dialect *spvDialect = getContext().getRegisteredDialect("spv");
|
||||
|
||||
f.getOperation()->walk([&](Operation *op) {
|
||||
if (op->getDialect() != spvDialect)
|
||||
return WalkResult::advance();
|
||||
|
||||
auto &os = llvm::outs();
|
||||
|
||||
if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op))
|
||||
os << " min version: "
|
||||
<< spirv::stringifyVersion(minVersion.getMinVersion()) << "\n";
|
||||
|
||||
if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op))
|
||||
os << " max version: "
|
||||
<< spirv::stringifyVersion(maxVersion.getMaxVersion()) << "\n";
|
||||
|
||||
if (auto extension = dyn_cast<spirv::QueryExtensionInterface>(op)) {
|
||||
os << " extensions: [";
|
||||
for (const auto &exts : extension.getExtensions()) {
|
||||
os << " [";
|
||||
interleaveComma(exts, os, [&](spirv::Extension ext) {
|
||||
os << spirv::stringifyExtension(ext);
|
||||
});
|
||||
os << "]";
|
||||
}
|
||||
os << " ]\n";
|
||||
}
|
||||
|
||||
if (auto capability = dyn_cast<spirv::QueryCapabilityInterface>(op)) {
|
||||
os << " capabilities: [";
|
||||
for (const auto &caps : capability.getCapabilities()) {
|
||||
os << " [";
|
||||
interleaveComma(caps, os, [&](spirv::Capability cap) {
|
||||
os << spirv::stringifyCapability(cap);
|
||||
});
|
||||
os << "]";
|
||||
}
|
||||
os << " ]\n";
|
||||
}
|
||||
os.flush();
|
||||
|
||||
return WalkResult::advance();
|
||||
});
|
||||
}
|
||||
|
||||
static PassRegistration<TestAvailability> pass("test-spirv-op-availability",
|
||||
"Test SPIR-V op availability");
|
|
@ -0,0 +1,31 @@
|
|||
// RUN: mlir-opt -disable-pass-threading -test-spirv-op-availability %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: iadd
|
||||
func @iadd(%arg: i32) -> i32 {
|
||||
// CHECK: min version: V_1_0
|
||||
// CHECK: max version: V_1_5
|
||||
// CHECK: extensions: [ ]
|
||||
// CHECK: capabilities: [ ]
|
||||
%0 = spv.IAdd %arg, %arg: i32
|
||||
return %0: i32
|
||||
}
|
||||
|
||||
// CHECK: atomic_compare_exchange_weak
|
||||
func @atomic_compare_exchange_weak(%ptr: !spv.ptr<i32, Workgroup>, %value: i32, %comparator: i32) -> i32 {
|
||||
// CHECK: min version: V_1_0
|
||||
// CHECK: max version: V_1_3
|
||||
// CHECK: extensions: [ ]
|
||||
// CHECK: capabilities: [ [Kernel] ]
|
||||
%0 = spv.AtomicCompareExchangeWeak "Workgroup" "Release" "Acquire" %ptr, %value, %comparator: !spv.ptr<i32, Workgroup>
|
||||
return %0: i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: subgroup_ballot
|
||||
func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
|
||||
// CHECK: min version: V_1_3
|
||||
// CHECK: max version: V_1_5
|
||||
// CHECK: extensions: [ ]
|
||||
// CHECK: capabilities: [ [GroupNonUniformBallot] ]
|
||||
%0 = spv.GroupNonUniformBallot "Workgroup" %predicate : vector<4xi32>
|
||||
return %0: vector<4xi32>
|
||||
}
|
|
@ -41,6 +41,7 @@ set(LIBS
|
|||
MLIRROCDLIR
|
||||
MLIRSPIRV
|
||||
MLIRStandardToSPIRVTransforms
|
||||
MLIRSPIRVTestPasses
|
||||
MLIRSPIRVTransforms
|
||||
MLIRStandardOps
|
||||
MLIRStandardToLLVM
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
|
||||
#include "mlir/Support/StringExtras.h"
|
||||
#include "mlir/TableGen/Attribute.h"
|
||||
#include "mlir/TableGen/Format.h"
|
||||
#include "mlir/TableGen/GenInfo.h"
|
||||
#include "mlir/TableGen/Operator.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
|
@ -43,6 +44,233 @@ using mlir::tblgen::NamedAttribute;
|
|||
using mlir::tblgen::NamedTypeConstraint;
|
||||
using mlir::tblgen::Operator;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Availability Wrapper Class
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
// Wrapper class with helper methods for accessing availability defined in
|
||||
// TableGen.
|
||||
class Availability {
|
||||
public:
|
||||
explicit Availability(const Record *def);
|
||||
|
||||
// Returns the name of the direct TableGen class for this availability
|
||||
// instance.
|
||||
StringRef getClass() const;
|
||||
|
||||
// Returns the generated C++ interface's class name.
|
||||
StringRef getInterfaceClassName() const;
|
||||
|
||||
// Returns the generated C++ interface's description.
|
||||
StringRef getInterfaceDescription() const;
|
||||
|
||||
// Returns the name of the query function insided the generated C++ interface.
|
||||
StringRef getQueryFnName() const;
|
||||
|
||||
// Returns the return type of the query function insided the generated C++
|
||||
// interface.
|
||||
StringRef getQueryFnRetType() const;
|
||||
|
||||
// Returns the code for merging availability requirements.
|
||||
StringRef getMergeActionCode() const;
|
||||
|
||||
// Returns the initializer expression for initializing the final availability
|
||||
// requirements.
|
||||
StringRef getMergeInitializer() const;
|
||||
|
||||
// Returns the C++ type for an availability instance.
|
||||
StringRef getMergeInstanceType() const;
|
||||
|
||||
// Returns the concrete availability instance carried in this case.
|
||||
StringRef getMergeInstance() const;
|
||||
|
||||
private:
|
||||
// The TableGen definition of this availability.
|
||||
const llvm::Record *def;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
Availability::Availability(const llvm::Record *def) : def(def) {
|
||||
assert(def->isSubClassOf("Availability") &&
|
||||
"must be subclass of TableGen 'Availability' class");
|
||||
}
|
||||
|
||||
StringRef Availability::getClass() const {
|
||||
SmallVector<Record *, 1> parentClass;
|
||||
def->getDirectSuperClasses(parentClass);
|
||||
if (parentClass.size() != 1) {
|
||||
PrintFatalError(def->getLoc(),
|
||||
"expected to only have one direct superclass");
|
||||
}
|
||||
return parentClass.front()->getName();
|
||||
}
|
||||
|
||||
StringRef Availability::getInterfaceClassName() const {
|
||||
return def->getValueAsString("interfaceName");
|
||||
}
|
||||
|
||||
StringRef Availability::getInterfaceDescription() const {
|
||||
return def->getValueAsString("interfaceDescription");
|
||||
}
|
||||
|
||||
StringRef Availability::getQueryFnRetType() const {
|
||||
return def->getValueAsString("queryFnRetType");
|
||||
}
|
||||
|
||||
StringRef Availability::getQueryFnName() const {
|
||||
return def->getValueAsString("queryFnName");
|
||||
}
|
||||
|
||||
StringRef Availability::getMergeActionCode() const {
|
||||
return def->getValueAsString("mergeAction");
|
||||
}
|
||||
|
||||
StringRef Availability::getMergeInitializer() const {
|
||||
return def->getValueAsString("initializer");
|
||||
}
|
||||
|
||||
StringRef Availability::getMergeInstanceType() const {
|
||||
return def->getValueAsString("instanceType");
|
||||
}
|
||||
|
||||
StringRef Availability::getMergeInstance() const {
|
||||
return def->getValueAsString("instance");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Availability Interface Definitions AutoGen
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void emitInterfaceDef(const Availability &availability,
|
||||
raw_ostream &os) {
|
||||
StringRef methodName = availability.getQueryFnName();
|
||||
os << availability.getQueryFnRetType() << " "
|
||||
<< availability.getInterfaceClassName() << "::" << methodName << "() {\n"
|
||||
<< " return getImpl()->" << methodName << "(getOperation());\n"
|
||||
<< "}\n";
|
||||
}
|
||||
|
||||
static bool emitInterfaceDefs(const RecordKeeper &recordKeeper,
|
||||
raw_ostream &os) {
|
||||
llvm::emitSourceFileHeader("Availability Interface Definitions", os);
|
||||
|
||||
auto defs = recordKeeper.getAllDerivedDefinitions("Availability");
|
||||
SmallVector<const Record *, 1> handledClasses;
|
||||
for (const Record *def : defs) {
|
||||
SmallVector<Record *, 1> parent;
|
||||
def->getDirectSuperClasses(parent);
|
||||
if (parent.size() != 1) {
|
||||
PrintFatalError(def->getLoc(),
|
||||
"expected to only have one direct superclass");
|
||||
}
|
||||
if (llvm::is_contained(handledClasses, parent.front()))
|
||||
continue;
|
||||
|
||||
Availability availability(def);
|
||||
emitInterfaceDef(availability, os);
|
||||
handledClasses.push_back(parent.front());
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Availability Interface Declarations AutoGen
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void emitConceptDecl(const Availability &availability, raw_ostream &os) {
|
||||
os << " class Concept {\n"
|
||||
<< " public:\n"
|
||||
<< " virtual ~Concept() = default;\n"
|
||||
<< " virtual " << availability.getQueryFnRetType() << " "
|
||||
<< availability.getQueryFnName() << "(Operation *tblgen_opaque_op) = 0;\n"
|
||||
<< " };\n";
|
||||
}
|
||||
|
||||
static void emitModelDecl(const Availability &availability, raw_ostream &os) {
|
||||
os << " template<typename ConcreteOp>\n";
|
||||
os << " class Model : public Concept {\n"
|
||||
<< " public:\n"
|
||||
<< " " << availability.getQueryFnRetType() << " "
|
||||
<< availability.getQueryFnName()
|
||||
<< "(Operation *tblgen_opaque_op) final {\n"
|
||||
<< " auto op = llvm::cast<ConcreteOp>(tblgen_opaque_op);\n"
|
||||
<< " (void)op;\n"
|
||||
// Forward to the method on the concrete operation type.
|
||||
<< " return op." << availability.getQueryFnName() << "();\n"
|
||||
<< " }\n"
|
||||
<< " };\n";
|
||||
}
|
||||
|
||||
static void emitInterfaceDecl(const Availability &availability,
|
||||
raw_ostream &os) {
|
||||
StringRef interfaceName = availability.getInterfaceClassName();
|
||||
std::string interfaceTraitsName = formatv("{0}Traits", interfaceName);
|
||||
|
||||
// Emit the traits struct containing the concept and model declarations.
|
||||
os << "namespace detail {\n"
|
||||
<< "struct " << interfaceTraitsName << " {\n";
|
||||
emitConceptDecl(availability, os);
|
||||
os << '\n';
|
||||
emitModelDecl(availability, os);
|
||||
os << "};\n} // end namespace detail\n\n";
|
||||
|
||||
// Emit the main interface class declaration.
|
||||
os << "/*\n" << availability.getInterfaceDescription().trim() << "\n*/\n";
|
||||
os << llvm::formatv("class {0} : public OpInterface<{1}, detail::{2}> {\n"
|
||||
"public:\n"
|
||||
" using OpInterface<{1}, detail::{2}>::OpInterface;\n",
|
||||
interfaceName, interfaceName, interfaceTraitsName);
|
||||
|
||||
// Emit query function declaration.
|
||||
os << " " << availability.getQueryFnRetType() << " "
|
||||
<< availability.getQueryFnName() << "();\n";
|
||||
os << "};\n\n";
|
||||
}
|
||||
|
||||
static bool emitInterfaceDecls(const RecordKeeper &recordKeeper,
|
||||
raw_ostream &os) {
|
||||
llvm::emitSourceFileHeader("Availability Interface Declarations", os);
|
||||
|
||||
auto defs = recordKeeper.getAllDerivedDefinitions("Availability");
|
||||
SmallVector<const Record *, 4> handledClasses;
|
||||
for (const Record *def : defs) {
|
||||
SmallVector<Record *, 1> parent;
|
||||
def->getDirectSuperClasses(parent);
|
||||
if (parent.size() != 1) {
|
||||
PrintFatalError(def->getLoc(),
|
||||
"expected to only have one direct superclass");
|
||||
}
|
||||
if (llvm::is_contained(handledClasses, parent.front()))
|
||||
continue;
|
||||
|
||||
Availability avail(def);
|
||||
emitInterfaceDecl(avail, os);
|
||||
handledClasses.push_back(parent.front());
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Availability Interface Hook Registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Registers the operation interface generator to mlir-tblgen.
|
||||
static mlir::GenRegistration
|
||||
genInterfaceDecls("gen-avail-interface-decls",
|
||||
"Generate availability interface declarations",
|
||||
[](const RecordKeeper &records, raw_ostream &os) {
|
||||
return emitInterfaceDecls(records, os);
|
||||
});
|
||||
|
||||
// Registers the operation interface generator to mlir-tblgen.
|
||||
static mlir::GenRegistration
|
||||
genInterfaceDefs("gen-avail-interface-defs",
|
||||
"Generate op interface definitions",
|
||||
[](const RecordKeeper &records, raw_ostream &os) {
|
||||
return emitInterfaceDefs(records, os);
|
||||
});
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Serialization AutoGen
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -650,6 +878,17 @@ static bool emitSerializationFns(const RecordKeeper &recordKeeper,
|
|||
return false;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Serialization Hook Registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static mlir::GenRegistration genSerialization(
|
||||
"gen-spirv-serialization",
|
||||
"Generate SPIR-V (de)serialization utilities and functions",
|
||||
[](const RecordKeeper &records, raw_ostream &os) {
|
||||
return emitSerializationFns(records, os);
|
||||
});
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Op Utils AutoGen
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -707,19 +946,92 @@ static bool emitOpUtils(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Hook Registration
|
||||
// Op Utils Hook Registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static mlir::GenRegistration genSerialization(
|
||||
"gen-spirv-serialization",
|
||||
"Generate SPIR-V (de)serialization utilities and functions",
|
||||
[](const RecordKeeper &records, raw_ostream &os) {
|
||||
return emitSerializationFns(records, os);
|
||||
});
|
||||
|
||||
static mlir::GenRegistration
|
||||
genOpUtils("gen-spirv-op-utils",
|
||||
"Generate SPIR-V operation utility definitions",
|
||||
[](const RecordKeeper &records, raw_ostream &os) {
|
||||
return emitOpUtils(records, os);
|
||||
});
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SPIR-V Availability Impl AutoGen
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Returns the availability spec of the given `def`.
|
||||
std::vector<Availability> getAvailabilities(const Record &def) {
|
||||
std::vector<Availability> availabilities;
|
||||
if (auto *availListInit = def.getValueAsListInit("availability")) {
|
||||
availabilities.reserve(availListInit->size());
|
||||
for (auto *availInit : *availListInit)
|
||||
availabilities.emplace_back(
|
||||
llvm::cast<llvm::DefInit>(availInit)->getDef());
|
||||
}
|
||||
return availabilities;
|
||||
}
|
||||
|
||||
static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
|
||||
mlir::tblgen::FmtContext fctx;
|
||||
fctx.addSubst("overall", "overall");
|
||||
|
||||
std::vector<Availability> opAvailabilities =
|
||||
getAvailabilities(srcOp.getDef());
|
||||
|
||||
// First collect all availablity classes this op should implement.
|
||||
// All availablity instances keep information for the generated interface and
|
||||
// the instance's specific requirement. Here we remember a random instance so
|
||||
// we can get the information regarding the generated interface.
|
||||
llvm::StringMap<Availability> availClasses;
|
||||
for (const Availability &avail : opAvailabilities)
|
||||
availClasses.try_emplace(avail.getClass(), avail);
|
||||
|
||||
// Then generate implementation for each availability class.
|
||||
for (const auto &availClass : availClasses) {
|
||||
StringRef availClassName = availClass.getKey();
|
||||
Availability avail = availClass.getValue();
|
||||
|
||||
// Generate the implementation method signature.
|
||||
os << formatv("{0} {1}::{2}() {{\n", avail.getQueryFnRetType(),
|
||||
srcOp.getCppClassName(), avail.getQueryFnName());
|
||||
|
||||
// Create the variable for the final requirement and initialize it.
|
||||
os << formatv(" {0} overall = {1};\n", avail.getQueryFnRetType(),
|
||||
avail.getMergeInitializer());
|
||||
|
||||
// Update with the op's specific availability spec.
|
||||
for (const Availability &avail : opAvailabilities)
|
||||
if (avail.getClass() == availClassName) {
|
||||
os << " "
|
||||
<< tgfmt(avail.getMergeActionCode(),
|
||||
&fctx.addSubst("instance", avail.getMergeInstance()))
|
||||
<< ";\n";
|
||||
}
|
||||
os << " return overall;\n";
|
||||
os << "}\n";
|
||||
}
|
||||
}
|
||||
|
||||
static bool emitAvailabilityImpl(const RecordKeeper &recordKeeper,
|
||||
raw_ostream &os) {
|
||||
llvm::emitSourceFileHeader("SPIR-V Op Availability Implementations", os);
|
||||
|
||||
auto defs = recordKeeper.getAllDerivedDefinitions("SPV_Op");
|
||||
for (const auto *def : defs) {
|
||||
Operator op(def);
|
||||
emitAvailabilityImpl(op, os);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Op Availability Implementation Hook Registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static mlir::GenRegistration
|
||||
genOpAvailabilityImpl("gen-spirv-avail-impls",
|
||||
"Generate SPIR-V operation utility definitions",
|
||||
[](const RecordKeeper &records, raw_ostream &os) {
|
||||
return emitAvailabilityImpl(records, os);
|
||||
});
|
||||
|
|
Loading…
Reference in New Issue