forked from OSchip/llvm-project
[mlir][spirv] Add a field for client API in target environment
SPIR-V can be directly consumed by APIs like Vulkan and OpenCL, where we can use the capability list to diffferentiate. It can also be used as a compilation target to transcompile to shading languages like WGSL to target WebGPU. We have no way to tell that with just the capability list, so we cannot perform certain transformations only applicable to those targets thus far. So this commit add a field in the target environment to indicate the client API for such purposes. Reviewed By: kuhar Differential Revision: https://reviews.llvm.org/D138732
This commit is contained in:
parent
410c1f6269
commit
e672f5126f
|
@ -138,9 +138,11 @@ public:
|
|||
using Base::Base;
|
||||
|
||||
/// Gets a TargetEnvAttr instance.
|
||||
static TargetEnvAttr get(VerCapExtAttr triple, Vendor vendorID,
|
||||
DeviceType deviceType, uint32_t deviceId,
|
||||
ResourceLimitsAttr limits);
|
||||
static TargetEnvAttr get(VerCapExtAttr triple, ResourceLimitsAttr limits,
|
||||
ClientAPI clientAPI = ClientAPI::Unknown,
|
||||
Vendor vendorID = Vendor::Unknown,
|
||||
DeviceType deviceType = DeviceType::Unknown,
|
||||
uint32_t deviceId = kUnknownDeviceID);
|
||||
|
||||
/// Returns the attribute kind's name (without the 'spirv.' prefix).
|
||||
static StringRef getKindName();
|
||||
|
@ -161,6 +163,9 @@ public:
|
|||
/// Returns the target capabilities as an integer array attribute.
|
||||
ArrayAttr getCapabilitiesAttr();
|
||||
|
||||
/// Returns the client API.
|
||||
ClientAPI getClientAPI() const;
|
||||
|
||||
/// Returns the vendor ID.
|
||||
Vendor getVendorID() const;
|
||||
|
||||
|
|
|
@ -267,7 +267,7 @@ def SPIRV_DT_IntegratedGPU : I32EnumAttrCase<"IntegratedGPU", 2>;
|
|||
// An accelerator other than GPU or CPU
|
||||
def SPIRV_DT_Other : I32EnumAttrCase<"Other", 3>;
|
||||
// Information missing.
|
||||
def SPIRV_DT_Unknown : I32EnumAttrCase<"Unknown", 4>;
|
||||
def SPIRV_DT_Unknown : I32EnumAttrCase<"Unknown", 0xffffffff>;
|
||||
|
||||
def SPIRV_DeviceTypeAttr : SPIRV_I32EnumAttr<
|
||||
"DeviceType", "valid SPIR-V device types", "device_type", [
|
||||
|
@ -283,7 +283,7 @@ def SPIRV_V_Intel : I32EnumAttrCase<"Intel", 4>;
|
|||
def SPIRV_V_NVIDIA : I32EnumAttrCase<"NVIDIA", 5>;
|
||||
def SPIRV_V_Qualcomm : I32EnumAttrCase<"Qualcomm", 6>;
|
||||
def SPIRV_V_SwiftShader : I32EnumAttrCase<"SwiftShader", 7>;
|
||||
def SPIRV_V_Unknown : I32EnumAttrCase<"Unknown", 0xff>;
|
||||
def SPIRV_V_Unknown : I32EnumAttrCase<"Unknown", 0xffffffff>;
|
||||
|
||||
def SPIRV_VendorAttr : SPIRV_I32EnumAttr<
|
||||
"Vendor", "recognized SPIR-V vendor strings", "vendor", [
|
||||
|
@ -292,6 +292,18 @@ def SPIRV_VendorAttr : SPIRV_I32EnumAttr<
|
|||
SPIRV_V_Unknown
|
||||
]>;
|
||||
|
||||
def SPIRV_CA_Metal : I32EnumAttrCase<"Metal", 0>;
|
||||
def SPIRV_CA_OpenCL : I32EnumAttrCase<"OpenCL", 1>;
|
||||
def SPIRV_CA_Vulkan : I32EnumAttrCase<"Vulkan", 2>;
|
||||
def SPIRV_CA_WebGPU : I32EnumAttrCase<"WebGPU", 3>;
|
||||
def SPIRV_CA_Unknown : I32EnumAttrCase<"Unknown", 0xffffffff>;
|
||||
|
||||
def SPIRV_ClientAPIAttr : SPIRV_I32EnumAttr<
|
||||
"ClientAPI", "recognized SPIR-V client APIs", "client_api", [
|
||||
SPIRV_CA_Metal, SPIRV_CA_OpenCL, SPIRV_CA_Vulkan, SPIRV_CA_WebGPU,
|
||||
SPIRV_CA_Unknown
|
||||
]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SPIR-V extension definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -82,17 +82,18 @@ struct VerCapExtAttributeStorage : public AttributeStorage {
|
|||
};
|
||||
|
||||
struct TargetEnvAttributeStorage : public AttributeStorage {
|
||||
using KeyTy = std::tuple<Attribute, Vendor, DeviceType, uint32_t, Attribute>;
|
||||
using KeyTy =
|
||||
std::tuple<Attribute, ClientAPI, Vendor, DeviceType, uint32_t, Attribute>;
|
||||
|
||||
TargetEnvAttributeStorage(Attribute triple, Vendor vendorID,
|
||||
DeviceType deviceType, uint32_t deviceID,
|
||||
Attribute limits)
|
||||
: triple(triple), limits(limits), vendorID(vendorID),
|
||||
deviceType(deviceType), deviceID(deviceID) {}
|
||||
TargetEnvAttributeStorage(Attribute triple, ClientAPI clientAPI,
|
||||
Vendor vendorID, DeviceType deviceType,
|
||||
uint32_t deviceID, Attribute limits)
|
||||
: triple(triple), limits(limits), clientAPI(clientAPI),
|
||||
vendorID(vendorID), deviceType(deviceType), deviceID(deviceID) {}
|
||||
|
||||
bool operator==(const KeyTy &key) const {
|
||||
return key ==
|
||||
std::make_tuple(triple, vendorID, deviceType, deviceID, limits);
|
||||
return key == std::make_tuple(triple, clientAPI, vendorID, deviceType,
|
||||
deviceID, limits);
|
||||
}
|
||||
|
||||
static TargetEnvAttributeStorage *
|
||||
|
@ -100,11 +101,12 @@ struct TargetEnvAttributeStorage : public AttributeStorage {
|
|||
return new (allocator.allocate<TargetEnvAttributeStorage>())
|
||||
TargetEnvAttributeStorage(std::get<0>(key), std::get<1>(key),
|
||||
std::get<2>(key), std::get<3>(key),
|
||||
std::get<4>(key));
|
||||
std::get<4>(key), std::get<5>(key));
|
||||
}
|
||||
|
||||
Attribute triple;
|
||||
Attribute limits;
|
||||
ClientAPI clientAPI;
|
||||
Vendor vendorID;
|
||||
DeviceType deviceType;
|
||||
uint32_t deviceID;
|
||||
|
@ -282,14 +284,13 @@ spirv::VerCapExtAttr::verify(function_ref<InFlightDiagnostic()> emitError,
|
|||
// TargetEnvAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
spirv::TargetEnvAttr spirv::TargetEnvAttr::get(spirv::VerCapExtAttr triple,
|
||||
Vendor vendorID,
|
||||
DeviceType deviceType,
|
||||
uint32_t deviceID,
|
||||
ResourceLimitsAttr limits) {
|
||||
spirv::TargetEnvAttr spirv::TargetEnvAttr::get(
|
||||
spirv::VerCapExtAttr triple, ResourceLimitsAttr limits, ClientAPI clientAPI,
|
||||
Vendor vendorID, DeviceType deviceType, uint32_t deviceID) {
|
||||
assert(triple && limits && "expected valid triple and limits");
|
||||
MLIRContext *context = triple.getContext();
|
||||
return Base::get(context, triple, vendorID, deviceType, deviceID, limits);
|
||||
return Base::get(context, triple, clientAPI, vendorID, deviceType, deviceID,
|
||||
limits);
|
||||
}
|
||||
|
||||
StringRef spirv::TargetEnvAttr::getKindName() { return "target_env"; }
|
||||
|
@ -318,6 +319,10 @@ ArrayAttr spirv::TargetEnvAttr::getCapabilitiesAttr() {
|
|||
return getTripleAttr().getCapabilitiesAttr();
|
||||
}
|
||||
|
||||
spirv::ClientAPI spirv::TargetEnvAttr::getClientAPI() const {
|
||||
return getImpl()->clientAPI;
|
||||
}
|
||||
|
||||
spirv::Vendor spirv::TargetEnvAttr::getVendorID() const {
|
||||
return getImpl()->vendorID;
|
||||
}
|
||||
|
@ -523,6 +528,22 @@ static Attribute parseTargetEnvAttr(DialectAsmParser &parser) {
|
|||
if (parser.parseAttribute(tripleAttr) || parser.parseComma())
|
||||
return {};
|
||||
|
||||
auto clientAPI = spirv::ClientAPI::Unknown;
|
||||
if (succeeded(parser.parseOptionalKeyword("api"))) {
|
||||
if (parser.parseEqual())
|
||||
return {};
|
||||
auto loc = parser.getCurrentLocation();
|
||||
StringRef apiStr;
|
||||
if (parser.parseKeyword(&apiStr))
|
||||
return {};
|
||||
if (auto apiSymbol = spirv::symbolizeClientAPI(apiStr))
|
||||
clientAPI = *apiSymbol;
|
||||
else
|
||||
parser.emitError(loc, "unknown client API: ") << apiStr;
|
||||
if (parser.parseComma())
|
||||
return {};
|
||||
}
|
||||
|
||||
// Parse [vendor[:device-type[:device-id]]]
|
||||
Vendor vendorID = Vendor::Unknown;
|
||||
DeviceType deviceType = DeviceType::Unknown;
|
||||
|
@ -531,22 +552,20 @@ static Attribute parseTargetEnvAttr(DialectAsmParser &parser) {
|
|||
auto loc = parser.getCurrentLocation();
|
||||
StringRef vendorStr;
|
||||
if (succeeded(parser.parseOptionalKeyword(&vendorStr))) {
|
||||
if (auto vendorSymbol = spirv::symbolizeVendor(vendorStr)) {
|
||||
if (auto vendorSymbol = spirv::symbolizeVendor(vendorStr))
|
||||
vendorID = *vendorSymbol;
|
||||
} else {
|
||||
else
|
||||
parser.emitError(loc, "unknown vendor: ") << vendorStr;
|
||||
}
|
||||
|
||||
if (succeeded(parser.parseOptionalColon())) {
|
||||
loc = parser.getCurrentLocation();
|
||||
StringRef deviceTypeStr;
|
||||
if (parser.parseKeyword(&deviceTypeStr))
|
||||
return {};
|
||||
if (auto deviceTypeSymbol = spirv::symbolizeDeviceType(deviceTypeStr)) {
|
||||
if (auto deviceTypeSymbol = spirv::symbolizeDeviceType(deviceTypeStr))
|
||||
deviceType = *deviceTypeSymbol;
|
||||
} else {
|
||||
else
|
||||
parser.emitError(loc, "unknown device type: ") << deviceTypeStr;
|
||||
}
|
||||
|
||||
if (succeeded(parser.parseOptionalColon())) {
|
||||
loc = parser.getCurrentLocation();
|
||||
|
@ -563,8 +582,8 @@ static Attribute parseTargetEnvAttr(DialectAsmParser &parser) {
|
|||
if (parser.parseAttribute(limitsAttr) || parser.parseGreater())
|
||||
return {};
|
||||
|
||||
return spirv::TargetEnvAttr::get(tripleAttr, vendorID, deviceType, deviceID,
|
||||
limitsAttr);
|
||||
return spirv::TargetEnvAttr::get(tripleAttr, limitsAttr, clientAPI, vendorID,
|
||||
deviceType, deviceID);
|
||||
}
|
||||
|
||||
Attribute SPIRVDialect::parseAttribute(DialectAsmParser &parser,
|
||||
|
@ -616,6 +635,9 @@ static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer) {
|
|||
static void print(spirv::TargetEnvAttr targetEnv, DialectAsmPrinter &printer) {
|
||||
printer << spirv::TargetEnvAttr::getKindName() << "<#spirv.";
|
||||
print(targetEnv.getTripleAttr(), printer);
|
||||
auto clientAPI = targetEnv.getClientAPI();
|
||||
if (clientAPI != spirv::ClientAPI::Unknown)
|
||||
printer << ", api=" << clientAPI;
|
||||
spirv::Vendor vendorID = targetEnv.getVendorID();
|
||||
spirv::DeviceType deviceType = targetEnv.getDeviceType();
|
||||
uint32_t deviceID = targetEnv.getDeviceID();
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/FunctionInterfaces.h"
|
||||
|
@ -170,10 +171,10 @@ spirv::TargetEnvAttr spirv::getDefaultTargetEnv(MLIRContext *context) {
|
|||
auto triple = spirv::VerCapExtAttr::get(spirv::Version::V_1_0,
|
||||
{spirv::Capability::Shader},
|
||||
ArrayRef<Extension>(), context);
|
||||
return spirv::TargetEnvAttr::get(triple, spirv::Vendor::Unknown,
|
||||
spirv::DeviceType::Unknown,
|
||||
spirv::TargetEnvAttr::kUnknownDeviceID,
|
||||
spirv::getDefaultResourceLimits(context));
|
||||
return spirv::TargetEnvAttr::get(
|
||||
triple, spirv::getDefaultResourceLimits(context),
|
||||
spirv::ClientAPI::Unknown, spirv::Vendor::Unknown,
|
||||
spirv::DeviceType::Unknown, spirv::TargetEnvAttr::kUnknownDeviceID);
|
||||
}
|
||||
|
||||
spirv::TargetEnvAttr spirv::lookupTargetEnv(Operation *op) {
|
||||
|
|
|
@ -118,6 +118,24 @@ func.func @target_env() attributes {
|
|||
|
||||
// -----
|
||||
|
||||
func.func @target_env_client_api() attributes {
|
||||
// CHECK: spirv.target_env = #spirv.target_env<
|
||||
// CHECK-SAME: #spirv.vce<v1.0, [], []>,
|
||||
// CHECK-SAME: api=Metal,
|
||||
// CHECK-SAME: #spirv.resource_limits<>>
|
||||
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, api=Metal, #spirv.resource_limits<>>
|
||||
} { return }
|
||||
|
||||
// -----
|
||||
|
||||
func.func @target_env_client_api() attributes {
|
||||
// CHECK: spirv.target_env = #spirv.target_env
|
||||
// CHECK-NOT: api=
|
||||
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, api=Unknown, #spirv.resource_limits<>>
|
||||
} { return }
|
||||
|
||||
// -----
|
||||
|
||||
func.func @target_env_vendor_id() attributes {
|
||||
// CHECK: spirv.target_env = #spirv.target_env<
|
||||
// CHECK-SAME: #spirv.vce<v1.0, [], []>,
|
||||
|
@ -148,6 +166,17 @@ func.func @target_env_vendor_id_device_type_device_id() attributes {
|
|||
|
||||
// -----
|
||||
|
||||
func.func @target_env_client_api_vendor_id_device_type_device_id() attributes {
|
||||
// CHECK: spirv.target_env = #spirv.target_env<
|
||||
// CHECK-SAME: #spirv.vce<v1.0, [], []>,
|
||||
// CHECK-SAME: api=Vulkan,
|
||||
// CHECK-SAME: Qualcomm:IntegratedGPU:100925441,
|
||||
// CHECK-SAME: #spirv.resource_limits<>>
|
||||
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, api=Vulkan, Qualcomm:IntegratedGPU:0x6040001, #spirv.resource_limits<>>
|
||||
} { return }
|
||||
|
||||
// -----
|
||||
|
||||
func.func @target_env_extra_fields() attributes {
|
||||
// expected-error @+3 {{expected '>'}}
|
||||
spirv.target_env = #spirv.target_env<
|
||||
|
|
Loading…
Reference in New Issue