[mlir][spirv] Expose more query APIs directly on TargetEnv

This allows us to omit one level of indirection when querying
the information from the underlying attribute.

Reviewed By: hanchung, ThomasRaoux

Differential Revision: https://reviews.llvm.org/D91080
This commit is contained in:
Lei Zhang 2020-11-09 17:57:20 -05:00
parent dbfa69c502
commit 21eb8127f4
4 changed files with 48 additions and 16 deletions

View File

@ -139,10 +139,10 @@ public:
static StringRef getKindName();
/// Returns the (version, capabilities, extensions) triple attribute.
VerCapExtAttr getTripleAttr();
VerCapExtAttr getTripleAttr() const;
/// Returns the target version.
Version getVersion();
Version getVersion() const;
/// Returns the target extensions.
VerCapExtAttr::ext_range getExtensions();
@ -155,16 +155,16 @@ public:
ArrayAttr getCapabilitiesAttr();
/// Returns the vendor ID.
Vendor getVendorID();
Vendor getVendorID() const;
/// Returns the device type.
DeviceType getDeviceType();
DeviceType getDeviceType() const;
/// Returns the device ID.
uint32_t getDeviceID();
uint32_t getDeviceID() const;
/// Returns the target resource limits.
ResourceLimitsAttr getResourceLimits();
ResourceLimitsAttr getResourceLimits() const;
static LogicalResult
verifyConstructionInvariants(Location loc, VerCapExtAttr triple,

View File

@ -29,7 +29,7 @@ class TargetEnv {
public:
explicit TargetEnv(TargetEnvAttr targetAttr);
Version getVersion();
Version getVersion() const;
/// Returns true if the given capability is allowed.
bool allows(Capability) const;
@ -43,9 +43,23 @@ public:
/// Returns llvm::None otherwise.
Optional<Extension> allows(ArrayRef<Extension>) const;
/// Returns the vendor ID.
Vendor getVendorID() const;
/// Returns the device type.
DeviceType getDeviceType() const;
/// Returns the device ID.
uint32_t getDeviceID() const;
/// Returns the MLIRContext.
MLIRContext *getContext() const;
/// Returns the target resource limits.
ResourceLimitsAttr getResourceLimits() const;
TargetEnvAttr getAttr() const { return targetAttr; }
/// Allows implicity converting to the underlying spirv::TargetEnvAttr.
operator TargetEnvAttr() const { return targetAttr; }

View File

@ -288,11 +288,11 @@ spirv::TargetEnvAttr spirv::TargetEnvAttr::get(spirv::VerCapExtAttr triple,
StringRef spirv::TargetEnvAttr::getKindName() { return "target_env"; }
spirv::VerCapExtAttr spirv::TargetEnvAttr::getTripleAttr() {
spirv::VerCapExtAttr spirv::TargetEnvAttr::getTripleAttr() const {
return getImpl()->triple.cast<spirv::VerCapExtAttr>();
}
spirv::Version spirv::TargetEnvAttr::getVersion() {
spirv::Version spirv::TargetEnvAttr::getVersion() const {
return getTripleAttr().getVersion();
}
@ -312,17 +312,19 @@ ArrayAttr spirv::TargetEnvAttr::getCapabilitiesAttr() {
return getTripleAttr().getCapabilitiesAttr();
}
spirv::Vendor spirv::TargetEnvAttr::getVendorID() {
spirv::Vendor spirv::TargetEnvAttr::getVendorID() const {
return getImpl()->vendorID;
}
spirv::DeviceType spirv::TargetEnvAttr::getDeviceType() {
spirv::DeviceType spirv::TargetEnvAttr::getDeviceType() const {
return getImpl()->deviceType;
}
uint32_t spirv::TargetEnvAttr::getDeviceID() { return getImpl()->deviceID; }
uint32_t spirv::TargetEnvAttr::getDeviceID() const {
return getImpl()->deviceID;
}
spirv::ResourceLimitsAttr spirv::TargetEnvAttr::getResourceLimits() {
spirv::ResourceLimitsAttr spirv::TargetEnvAttr::getResourceLimits() const {
return getImpl()->limits.cast<spirv::ResourceLimitsAttr>();
}

View File

@ -38,7 +38,7 @@ spirv::TargetEnv::TargetEnv(spirv::TargetEnvAttr targetAttr)
}
}
spirv::Version spirv::TargetEnv::getVersion() {
spirv::Version spirv::TargetEnv::getVersion() const {
return targetAttr.getVersion();
}
@ -48,7 +48,7 @@ bool spirv::TargetEnv::allows(spirv::Capability capability) const {
Optional<spirv::Capability>
spirv::TargetEnv::allows(ArrayRef<spirv::Capability> caps) const {
auto chosen = llvm::find_if(caps, [this](spirv::Capability cap) {
const auto *chosen = llvm::find_if(caps, [this](spirv::Capability cap) {
return givenCapabilities.count(cap);
});
if (chosen != caps.end())
@ -62,7 +62,7 @@ bool spirv::TargetEnv::allows(spirv::Extension extension) const {
Optional<spirv::Extension>
spirv::TargetEnv::allows(ArrayRef<spirv::Extension> exts) const {
auto chosen = llvm::find_if(exts, [this](spirv::Extension ext) {
const auto *chosen = llvm::find_if(exts, [this](spirv::Extension ext) {
return givenExtensions.count(ext);
});
if (chosen != exts.end())
@ -70,6 +70,22 @@ spirv::TargetEnv::allows(ArrayRef<spirv::Extension> exts) const {
return llvm::None;
}
spirv::Vendor spirv::TargetEnv::getVendorID() const {
return targetAttr.getVendorID();
}
spirv::DeviceType spirv::TargetEnv::getDeviceType() const {
return targetAttr.getDeviceType();
}
uint32_t spirv::TargetEnv::getDeviceID() const {
return targetAttr.getDeviceID();
}
spirv::ResourceLimitsAttr spirv::TargetEnv::getResourceLimits() const {
return targetAttr.getResourceLimits();
}
MLIRContext *spirv::TargetEnv::getContext() const {
return targetAttr.getContext();
}