[mlir][spirv] Add a field for client API in target environment

SPIR-V can be directly consumed by APIs like Vulkan and OpenCL,
where we can use the capability list to diffferentiate. It can
also be used as a compilation target to transcompile to shading
languages like WGSL to target WebGPU. We have no way to tell
that with just the capability list, so we cannot perform certain
transformations only applicable to those targets thus far. So
this commit add a field in the target environment to indicate
the client API for such purposes.

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D138732
This commit is contained in:
Lei Zhang 2022-11-25 21:31:12 +00:00
parent 410c1f6269
commit e672f5126f
5 changed files with 101 additions and 32 deletions

View File

@ -138,9 +138,11 @@ public:
using Base::Base;
/// Gets a TargetEnvAttr instance.
static TargetEnvAttr get(VerCapExtAttr triple, Vendor vendorID,
DeviceType deviceType, uint32_t deviceId,
ResourceLimitsAttr limits);
static TargetEnvAttr get(VerCapExtAttr triple, ResourceLimitsAttr limits,
ClientAPI clientAPI = ClientAPI::Unknown,
Vendor vendorID = Vendor::Unknown,
DeviceType deviceType = DeviceType::Unknown,
uint32_t deviceId = kUnknownDeviceID);
/// Returns the attribute kind's name (without the 'spirv.' prefix).
static StringRef getKindName();
@ -161,6 +163,9 @@ public:
/// Returns the target capabilities as an integer array attribute.
ArrayAttr getCapabilitiesAttr();
/// Returns the client API.
ClientAPI getClientAPI() const;
/// Returns the vendor ID.
Vendor getVendorID() const;

View File

@ -267,7 +267,7 @@ def SPIRV_DT_IntegratedGPU : I32EnumAttrCase<"IntegratedGPU", 2>;
// An accelerator other than GPU or CPU
def SPIRV_DT_Other : I32EnumAttrCase<"Other", 3>;
// Information missing.
def SPIRV_DT_Unknown : I32EnumAttrCase<"Unknown", 4>;
def SPIRV_DT_Unknown : I32EnumAttrCase<"Unknown", 0xffffffff>;
def SPIRV_DeviceTypeAttr : SPIRV_I32EnumAttr<
"DeviceType", "valid SPIR-V device types", "device_type", [
@ -283,7 +283,7 @@ def SPIRV_V_Intel : I32EnumAttrCase<"Intel", 4>;
def SPIRV_V_NVIDIA : I32EnumAttrCase<"NVIDIA", 5>;
def SPIRV_V_Qualcomm : I32EnumAttrCase<"Qualcomm", 6>;
def SPIRV_V_SwiftShader : I32EnumAttrCase<"SwiftShader", 7>;
def SPIRV_V_Unknown : I32EnumAttrCase<"Unknown", 0xff>;
def SPIRV_V_Unknown : I32EnumAttrCase<"Unknown", 0xffffffff>;
def SPIRV_VendorAttr : SPIRV_I32EnumAttr<
"Vendor", "recognized SPIR-V vendor strings", "vendor", [
@ -292,6 +292,18 @@ def SPIRV_VendorAttr : SPIRV_I32EnumAttr<
SPIRV_V_Unknown
]>;
def SPIRV_CA_Metal : I32EnumAttrCase<"Metal", 0>;
def SPIRV_CA_OpenCL : I32EnumAttrCase<"OpenCL", 1>;
def SPIRV_CA_Vulkan : I32EnumAttrCase<"Vulkan", 2>;
def SPIRV_CA_WebGPU : I32EnumAttrCase<"WebGPU", 3>;
def SPIRV_CA_Unknown : I32EnumAttrCase<"Unknown", 0xffffffff>;
def SPIRV_ClientAPIAttr : SPIRV_I32EnumAttr<
"ClientAPI", "recognized SPIR-V client APIs", "client_api", [
SPIRV_CA_Metal, SPIRV_CA_OpenCL, SPIRV_CA_Vulkan, SPIRV_CA_WebGPU,
SPIRV_CA_Unknown
]>;
//===----------------------------------------------------------------------===//
// SPIR-V extension definitions
//===----------------------------------------------------------------------===//

View File

@ -82,17 +82,18 @@ struct VerCapExtAttributeStorage : public AttributeStorage {
};
struct TargetEnvAttributeStorage : public AttributeStorage {
using KeyTy = std::tuple<Attribute, Vendor, DeviceType, uint32_t, Attribute>;
using KeyTy =
std::tuple<Attribute, ClientAPI, Vendor, DeviceType, uint32_t, Attribute>;
TargetEnvAttributeStorage(Attribute triple, Vendor vendorID,
DeviceType deviceType, uint32_t deviceID,
Attribute limits)
: triple(triple), limits(limits), vendorID(vendorID),
deviceType(deviceType), deviceID(deviceID) {}
TargetEnvAttributeStorage(Attribute triple, ClientAPI clientAPI,
Vendor vendorID, DeviceType deviceType,
uint32_t deviceID, Attribute limits)
: triple(triple), limits(limits), clientAPI(clientAPI),
vendorID(vendorID), deviceType(deviceType), deviceID(deviceID) {}
bool operator==(const KeyTy &key) const {
return key ==
std::make_tuple(triple, vendorID, deviceType, deviceID, limits);
return key == std::make_tuple(triple, clientAPI, vendorID, deviceType,
deviceID, limits);
}
static TargetEnvAttributeStorage *
@ -100,11 +101,12 @@ struct TargetEnvAttributeStorage : public AttributeStorage {
return new (allocator.allocate<TargetEnvAttributeStorage>())
TargetEnvAttributeStorage(std::get<0>(key), std::get<1>(key),
std::get<2>(key), std::get<3>(key),
std::get<4>(key));
std::get<4>(key), std::get<5>(key));
}
Attribute triple;
Attribute limits;
ClientAPI clientAPI;
Vendor vendorID;
DeviceType deviceType;
uint32_t deviceID;
@ -282,14 +284,13 @@ spirv::VerCapExtAttr::verify(function_ref<InFlightDiagnostic()> emitError,
// TargetEnvAttr
//===----------------------------------------------------------------------===//
spirv::TargetEnvAttr spirv::TargetEnvAttr::get(spirv::VerCapExtAttr triple,
Vendor vendorID,
DeviceType deviceType,
uint32_t deviceID,
ResourceLimitsAttr limits) {
spirv::TargetEnvAttr spirv::TargetEnvAttr::get(
spirv::VerCapExtAttr triple, ResourceLimitsAttr limits, ClientAPI clientAPI,
Vendor vendorID, DeviceType deviceType, uint32_t deviceID) {
assert(triple && limits && "expected valid triple and limits");
MLIRContext *context = triple.getContext();
return Base::get(context, triple, vendorID, deviceType, deviceID, limits);
return Base::get(context, triple, clientAPI, vendorID, deviceType, deviceID,
limits);
}
StringRef spirv::TargetEnvAttr::getKindName() { return "target_env"; }
@ -318,6 +319,10 @@ ArrayAttr spirv::TargetEnvAttr::getCapabilitiesAttr() {
return getTripleAttr().getCapabilitiesAttr();
}
spirv::ClientAPI spirv::TargetEnvAttr::getClientAPI() const {
return getImpl()->clientAPI;
}
spirv::Vendor spirv::TargetEnvAttr::getVendorID() const {
return getImpl()->vendorID;
}
@ -523,6 +528,22 @@ static Attribute parseTargetEnvAttr(DialectAsmParser &parser) {
if (parser.parseAttribute(tripleAttr) || parser.parseComma())
return {};
auto clientAPI = spirv::ClientAPI::Unknown;
if (succeeded(parser.parseOptionalKeyword("api"))) {
if (parser.parseEqual())
return {};
auto loc = parser.getCurrentLocation();
StringRef apiStr;
if (parser.parseKeyword(&apiStr))
return {};
if (auto apiSymbol = spirv::symbolizeClientAPI(apiStr))
clientAPI = *apiSymbol;
else
parser.emitError(loc, "unknown client API: ") << apiStr;
if (parser.parseComma())
return {};
}
// Parse [vendor[:device-type[:device-id]]]
Vendor vendorID = Vendor::Unknown;
DeviceType deviceType = DeviceType::Unknown;
@ -531,22 +552,20 @@ static Attribute parseTargetEnvAttr(DialectAsmParser &parser) {
auto loc = parser.getCurrentLocation();
StringRef vendorStr;
if (succeeded(parser.parseOptionalKeyword(&vendorStr))) {
if (auto vendorSymbol = spirv::symbolizeVendor(vendorStr)) {
if (auto vendorSymbol = spirv::symbolizeVendor(vendorStr))
vendorID = *vendorSymbol;
} else {
else
parser.emitError(loc, "unknown vendor: ") << vendorStr;
}
if (succeeded(parser.parseOptionalColon())) {
loc = parser.getCurrentLocation();
StringRef deviceTypeStr;
if (parser.parseKeyword(&deviceTypeStr))
return {};
if (auto deviceTypeSymbol = spirv::symbolizeDeviceType(deviceTypeStr)) {
if (auto deviceTypeSymbol = spirv::symbolizeDeviceType(deviceTypeStr))
deviceType = *deviceTypeSymbol;
} else {
else
parser.emitError(loc, "unknown device type: ") << deviceTypeStr;
}
if (succeeded(parser.parseOptionalColon())) {
loc = parser.getCurrentLocation();
@ -563,8 +582,8 @@ static Attribute parseTargetEnvAttr(DialectAsmParser &parser) {
if (parser.parseAttribute(limitsAttr) || parser.parseGreater())
return {};
return spirv::TargetEnvAttr::get(tripleAttr, vendorID, deviceType, deviceID,
limitsAttr);
return spirv::TargetEnvAttr::get(tripleAttr, limitsAttr, clientAPI, vendorID,
deviceType, deviceID);
}
Attribute SPIRVDialect::parseAttribute(DialectAsmParser &parser,
@ -616,6 +635,9 @@ static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer) {
static void print(spirv::TargetEnvAttr targetEnv, DialectAsmPrinter &printer) {
printer << spirv::TargetEnvAttr::getKindName() << "<#spirv.";
print(targetEnv.getTripleAttr(), printer);
auto clientAPI = targetEnv.getClientAPI();
if (clientAPI != spirv::ClientAPI::Unknown)
printer << ", api=" << clientAPI;
spirv::Vendor vendorID = targetEnv.getVendorID();
spirv::DeviceType deviceType = targetEnv.getDeviceType();
uint32_t deviceID = targetEnv.getDeviceID();

View File

@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/FunctionInterfaces.h"
@ -170,10 +171,10 @@ spirv::TargetEnvAttr spirv::getDefaultTargetEnv(MLIRContext *context) {
auto triple = spirv::VerCapExtAttr::get(spirv::Version::V_1_0,
{spirv::Capability::Shader},
ArrayRef<Extension>(), context);
return spirv::TargetEnvAttr::get(triple, spirv::Vendor::Unknown,
spirv::DeviceType::Unknown,
spirv::TargetEnvAttr::kUnknownDeviceID,
spirv::getDefaultResourceLimits(context));
return spirv::TargetEnvAttr::get(
triple, spirv::getDefaultResourceLimits(context),
spirv::ClientAPI::Unknown, spirv::Vendor::Unknown,
spirv::DeviceType::Unknown, spirv::TargetEnvAttr::kUnknownDeviceID);
}
spirv::TargetEnvAttr spirv::lookupTargetEnv(Operation *op) {

View File

@ -118,6 +118,24 @@ func.func @target_env() attributes {
// -----
func.func @target_env_client_api() attributes {
// CHECK: spirv.target_env = #spirv.target_env<
// CHECK-SAME: #spirv.vce<v1.0, [], []>,
// CHECK-SAME: api=Metal,
// CHECK-SAME: #spirv.resource_limits<>>
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, api=Metal, #spirv.resource_limits<>>
} { return }
// -----
func.func @target_env_client_api() attributes {
// CHECK: spirv.target_env = #spirv.target_env
// CHECK-NOT: api=
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, api=Unknown, #spirv.resource_limits<>>
} { return }
// -----
func.func @target_env_vendor_id() attributes {
// CHECK: spirv.target_env = #spirv.target_env<
// CHECK-SAME: #spirv.vce<v1.0, [], []>,
@ -148,6 +166,17 @@ func.func @target_env_vendor_id_device_type_device_id() attributes {
// -----
func.func @target_env_client_api_vendor_id_device_type_device_id() attributes {
// CHECK: spirv.target_env = #spirv.target_env<
// CHECK-SAME: #spirv.vce<v1.0, [], []>,
// CHECK-SAME: api=Vulkan,
// CHECK-SAME: Qualcomm:IntegratedGPU:100925441,
// CHECK-SAME: #spirv.resource_limits<>>
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, api=Vulkan, Qualcomm:IntegratedGPU:0x6040001, #spirv.resource_limits<>>
} { return }
// -----
func.func @target_env_extra_fields() attributes {
// expected-error @+3 {{expected '>'}}
spirv.target_env = #spirv.target_env<