forked from OSchip/llvm-project
[mlir][spirv] Let SPIRVConversionTarget consider type availability
Previously we only consider the version/extension/capability requirement on the op itself. This commit updates SPIRVConversionTarget to also take into consideration the values' types when deciding op legality. Differential Revision: https://reviews.llvm.org/D75876
This commit is contained in:
parent
3b35f9d8b5
commit
67e8690e53
|
@ -1,4 +1,4 @@
|
|||
//===- SPIRVLowering.cpp - Standard to SPIR-V dialect conversion--===//
|
||||
//===- SPIRVLowering.cpp - SPIR-V lowering utilities ----------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
@ -15,6 +15,7 @@
|
|||
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#include <functional>
|
||||
|
@ -443,6 +444,66 @@ spirv::SPIRVConversionTarget::SPIRVConversionTarget(
|
|||
}
|
||||
}
|
||||
|
||||
/// Checks that `candidates` extension requirements are possible to be satisfied
|
||||
/// with the given `allowedExtensions`.
|
||||
///
|
||||
/// `candidates` is a vector of vector for extension requirements following
|
||||
/// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
|
||||
/// convention.
|
||||
static LogicalResult checkExtensionRequirements(
|
||||
Operation *op, const llvm::SmallSet<spirv::Extension, 4> &allowedExtensions,
|
||||
const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
|
||||
for (const auto &ors : candidates) {
|
||||
auto chosen = llvm::find_if(ors, [&](spirv::Extension ext) {
|
||||
return allowedExtensions.count(ext);
|
||||
});
|
||||
|
||||
if (chosen == ors.end()) {
|
||||
SmallVector<StringRef, 4> extStrings;
|
||||
for (spirv::Extension ext : ors)
|
||||
extStrings.push_back(spirv::stringifyExtension(ext));
|
||||
|
||||
LLVM_DEBUG(llvm::dbgs() << op->getName()
|
||||
<< "illegal: requires at least one extension in ["
|
||||
<< llvm::join(extStrings, ", ")
|
||||
<< "] but none allowed in target environment\n");
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Checks that `candidates`capability requirements are possible to be satisfied
|
||||
/// with the given `allowedCapabilities`.
|
||||
///
|
||||
/// `candidates` is a vector of vector for capability requirements following
|
||||
/// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
|
||||
/// convention.
|
||||
static LogicalResult checkCapabilityRequirements(
|
||||
Operation *op,
|
||||
const llvm::SmallSet<spirv::Capability, 8> &allowedCapabilities,
|
||||
const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
|
||||
for (const auto &ors : candidates) {
|
||||
auto chosen = llvm::find_if(ors, [&](spirv::Capability cap) {
|
||||
return allowedCapabilities.count(cap);
|
||||
});
|
||||
|
||||
if (chosen == ors.end()) {
|
||||
SmallVector<StringRef, 4> capStrings;
|
||||
for (spirv::Capability cap : ors)
|
||||
capStrings.push_back(spirv::stringifyCapability(cap));
|
||||
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< op->getName()
|
||||
<< "illegal: requires at least one capability in ["
|
||||
<< llvm::join(capStrings, ", ")
|
||||
<< "] but none allowed in target environment\n");
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
|
||||
// Make sure this op is available at the given version. Ops not implementing
|
||||
// QueryMinVersionInterface/QueryMaxVersionInterface are available to all
|
||||
|
@ -464,38 +525,47 @@ bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
|
|||
return false;
|
||||
}
|
||||
|
||||
// Make sure this op's required extensions are allowed to use. For each op,
|
||||
// we return a vector of vector for its extension requirements following
|
||||
// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
|
||||
// convention. Ops not implementing QueryExtensionInterface do not require
|
||||
// extensions to be available.
|
||||
if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op)) {
|
||||
auto exts = extensions.getExtensions();
|
||||
for (const auto &ors : exts)
|
||||
if (llvm::all_of(ors, [this](spirv::Extension ext) {
|
||||
return this->givenExtensions.count(ext) == 0;
|
||||
})) {
|
||||
LLVM_DEBUG(llvm::dbgs() << op->getName()
|
||||
<< " illegal: missing required extension\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// Make sure this op's required extensions are allowed to use. Ops not
|
||||
// implementing QueryExtensionInterface do not require extensions to be
|
||||
// available.
|
||||
if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
|
||||
if (failed(checkExtensionRequirements(op, this->givenExtensions,
|
||||
extensions.getExtensions())))
|
||||
return false;
|
||||
|
||||
// Make sure this op's required extensions are allowed to use. For each op,
|
||||
// we return a vector of vector for its capability requirements following
|
||||
// ((Capability::A OR Extension::B) AND (Capability::C OR Capability::D))
|
||||
// convention. Ops not implementing QueryExtensionInterface do not require
|
||||
// extensions to be available.
|
||||
if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op)) {
|
||||
auto caps = capabilities.getCapabilities();
|
||||
for (const auto &ors : caps)
|
||||
if (llvm::all_of(ors, [this](spirv::Capability cap) {
|
||||
return this->givenCapabilities.count(cap) == 0;
|
||||
})) {
|
||||
LLVM_DEBUG(llvm::dbgs() << op->getName()
|
||||
<< " illegal: missing required capability\n");
|
||||
return false;
|
||||
}
|
||||
// Make sure this op's required extensions are allowed to use. Ops not
|
||||
// implementing QueryCapabilityInterface do not require capabilities to be
|
||||
// available.
|
||||
if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
|
||||
if (failed(checkCapabilityRequirements(op, this->givenCapabilities,
|
||||
capabilities.getCapabilities())))
|
||||
return false;
|
||||
|
||||
SmallVector<Type, 4> valueTypes;
|
||||
valueTypes.append(op->operand_type_begin(), op->operand_type_end());
|
||||
valueTypes.append(op->result_type_begin(), op->result_type_end());
|
||||
|
||||
// Special treatment for global variables, whose type requirements are
|
||||
// conveyed by type attributes.
|
||||
if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
|
||||
valueTypes.push_back(globalVar.type());
|
||||
|
||||
// Make sure the op's operands/results use types that are allowed by the
|
||||
// target environment.
|
||||
SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
|
||||
SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
|
||||
for (Type valueType : valueTypes) {
|
||||
typeExtensions.clear();
|
||||
valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
|
||||
if (failed(checkExtensionRequirements(op, this->givenExtensions,
|
||||
typeExtensions)))
|
||||
return false;
|
||||
|
||||
typeCapabilities.clear();
|
||||
valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
|
||||
if (failed(checkCapabilityRequirements(op, this->givenCapabilities,
|
||||
typeCapabilities)))
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
|
|
|
@ -1,6 +1,12 @@
|
|||
// RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s
|
||||
|
||||
module attributes {gpu.container_module} {
|
||||
module attributes {
|
||||
gpu.container_module,
|
||||
spv.target_env = #spv.target_env<
|
||||
#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
|
||||
{max_compute_workgroup_invocations = 128 : i32,
|
||||
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
|
||||
} {
|
||||
func @main(%arg0 : memref<10xf32>, %arg1 : i1) {
|
||||
%c0 = constant 1 : index
|
||||
"gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0, %arg0, %arg1) { kernel = "kernel_simple_selection", kernel_module = @kernels} : (index, index, index, index, index, index, memref<10xf32>, i1) -> ()
|
||||
|
|
|
@ -1,6 +1,12 @@
|
|||
// RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s
|
||||
|
||||
module attributes {gpu.container_module} {
|
||||
module attributes {
|
||||
gpu.container_module,
|
||||
spv.target_env = #spv.target_env<
|
||||
#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
|
||||
{max_compute_workgroup_invocations = 128 : i32,
|
||||
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
|
||||
} {
|
||||
func @load_store(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>) {
|
||||
%c0 = constant 0 : index
|
||||
%c12 = constant 12 : index
|
||||
|
|
|
@ -1,6 +1,12 @@
|
|||
// RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s
|
||||
|
||||
module attributes {gpu.container_module} {
|
||||
module attributes {
|
||||
gpu.container_module,
|
||||
spv.target_env = #spv.target_env<
|
||||
#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
|
||||
{max_compute_workgroup_invocations = 128 : i32,
|
||||
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
|
||||
} {
|
||||
func @loop(%arg0 : memref<10xf32>, %arg1 : memref<10xf32>) {
|
||||
%c0 = constant 1 : index
|
||||
"gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0, %arg0, %arg1) { kernel = "loop_kernel", kernel_module = @kernels} : (index, index, index, index, index, index, memref<10xf32>, memref<10xf32>) -> ()
|
||||
|
|
|
@ -1,5 +1,12 @@
|
|||
// RUN: mlir-opt -convert-std-to-spirv %s -o - | FileCheck %s
|
||||
|
||||
module attributes {
|
||||
spv.target_env = #spv.target_env<
|
||||
#spv.vce<v1.0, [Shader, Int64, Float64], [SPV_KHR_storage_buffer_storage_class]>,
|
||||
{max_compute_workgroup_invocations = 128 : i32,
|
||||
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
|
||||
} {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// std binary arithmetic ops
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -366,3 +373,5 @@ func @load_store_zero_rank_int(%arg0: memref<i32>, %arg1: memref<i32>) {
|
|||
store %0, %arg1[] : memref<i32>
|
||||
return
|
||||
}
|
||||
|
||||
} // end module
|
||||
|
|
|
@ -4,6 +4,13 @@
|
|||
// the desired output. Adding all of patterns within a single pass does
|
||||
// not seem to work.
|
||||
|
||||
module attributes {
|
||||
spv.target_env = #spv.target_env<
|
||||
#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
|
||||
{max_compute_workgroup_invocations = 128 : i32,
|
||||
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
|
||||
} {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// std.subview
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -51,3 +58,5 @@ func @fold_static_stride_subview_with_store(%arg0 : memref<12x32xf32>, %arg1 : i
|
|||
store %arg5, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]>
|
||||
return
|
||||
}
|
||||
|
||||
} // end module
|
||||
|
|
Loading…
Reference in New Issue