[mlir][spirv] Let SPIRVConversionTarget consider type availability

Previously we only consider the version/extension/capability requirement
on the op itself. This commit updates SPIRVConversionTarget to also
take into consideration the values' types when deciding op legality.

Differential Revision: https://reviews.llvm.org/D75876
This commit is contained in:
Lei Zhang 2020-03-18 09:55:27 -04:00
parent 3b35f9d8b5
commit 67e8690e53
6 changed files with 141 additions and 35 deletions

View File

@ -1,4 +1,4 @@
//===- SPIRVLowering.cpp - Standard to SPIR-V dialect conversion--===//
//===- SPIRVLowering.cpp - SPIR-V lowering utilities ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@ -15,6 +15,7 @@
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
#include <functional>
@ -443,6 +444,66 @@ spirv::SPIRVConversionTarget::SPIRVConversionTarget(
}
}
/// Checks that `candidates` extension requirements are possible to be satisfied
/// with the given `allowedExtensions`.
///
/// `candidates` is a vector of vector for extension requirements following
/// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
/// convention.
static LogicalResult checkExtensionRequirements(
Operation *op, const llvm::SmallSet<spirv::Extension, 4> &allowedExtensions,
const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
for (const auto &ors : candidates) {
auto chosen = llvm::find_if(ors, [&](spirv::Extension ext) {
return allowedExtensions.count(ext);
});
if (chosen == ors.end()) {
SmallVector<StringRef, 4> extStrings;
for (spirv::Extension ext : ors)
extStrings.push_back(spirv::stringifyExtension(ext));
LLVM_DEBUG(llvm::dbgs() << op->getName()
<< "illegal: requires at least one extension in ["
<< llvm::join(extStrings, ", ")
<< "] but none allowed in target environment\n");
return failure();
}
}
return success();
}
/// Checks that `candidates`capability requirements are possible to be satisfied
/// with the given `allowedCapabilities`.
///
/// `candidates` is a vector of vector for capability requirements following
/// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
/// convention.
static LogicalResult checkCapabilityRequirements(
Operation *op,
const llvm::SmallSet<spirv::Capability, 8> &allowedCapabilities,
const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
for (const auto &ors : candidates) {
auto chosen = llvm::find_if(ors, [&](spirv::Capability cap) {
return allowedCapabilities.count(cap);
});
if (chosen == ors.end()) {
SmallVector<StringRef, 4> capStrings;
for (spirv::Capability cap : ors)
capStrings.push_back(spirv::stringifyCapability(cap));
LLVM_DEBUG(llvm::dbgs()
<< op->getName()
<< "illegal: requires at least one capability in ["
<< llvm::join(capStrings, ", ")
<< "] but none allowed in target environment\n");
return failure();
}
}
return success();
}
bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
// Make sure this op is available at the given version. Ops not implementing
// QueryMinVersionInterface/QueryMaxVersionInterface are available to all
@ -464,38 +525,47 @@ bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
return false;
}
// Make sure this op's required extensions are allowed to use. For each op,
// we return a vector of vector for its extension requirements following
// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
// convention. Ops not implementing QueryExtensionInterface do not require
// extensions to be available.
if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op)) {
auto exts = extensions.getExtensions();
for (const auto &ors : exts)
if (llvm::all_of(ors, [this](spirv::Extension ext) {
return this->givenExtensions.count(ext) == 0;
})) {
LLVM_DEBUG(llvm::dbgs() << op->getName()
<< " illegal: missing required extension\n");
return false;
}
}
// Make sure this op's required extensions are allowed to use. Ops not
// implementing QueryExtensionInterface do not require extensions to be
// available.
if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
if (failed(checkExtensionRequirements(op, this->givenExtensions,
extensions.getExtensions())))
return false;
// Make sure this op's required extensions are allowed to use. For each op,
// we return a vector of vector for its capability requirements following
// ((Capability::A OR Extension::B) AND (Capability::C OR Capability::D))
// convention. Ops not implementing QueryExtensionInterface do not require
// extensions to be available.
if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op)) {
auto caps = capabilities.getCapabilities();
for (const auto &ors : caps)
if (llvm::all_of(ors, [this](spirv::Capability cap) {
return this->givenCapabilities.count(cap) == 0;
})) {
LLVM_DEBUG(llvm::dbgs() << op->getName()
<< " illegal: missing required capability\n");
return false;
}
// Make sure this op's required extensions are allowed to use. Ops not
// implementing QueryCapabilityInterface do not require capabilities to be
// available.
if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
if (failed(checkCapabilityRequirements(op, this->givenCapabilities,
capabilities.getCapabilities())))
return false;
SmallVector<Type, 4> valueTypes;
valueTypes.append(op->operand_type_begin(), op->operand_type_end());
valueTypes.append(op->result_type_begin(), op->result_type_end());
// Special treatment for global variables, whose type requirements are
// conveyed by type attributes.
if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
valueTypes.push_back(globalVar.type());
// Make sure the op's operands/results use types that are allowed by the
// target environment.
SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
for (Type valueType : valueTypes) {
typeExtensions.clear();
valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
if (failed(checkExtensionRequirements(op, this->givenExtensions,
typeExtensions)))
return false;
typeCapabilities.clear();
valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
if (failed(checkCapabilityRequirements(op, this->givenCapabilities,
typeCapabilities)))
return false;
}
return true;

View File

@ -1,6 +1,12 @@
// RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s
module attributes {gpu.container_module} {
module attributes {
gpu.container_module,
spv.target_env = #spv.target_env<
#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
} {
func @main(%arg0 : memref<10xf32>, %arg1 : i1) {
%c0 = constant 1 : index
"gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0, %arg0, %arg1) { kernel = "kernel_simple_selection", kernel_module = @kernels} : (index, index, index, index, index, index, memref<10xf32>, i1) -> ()

View File

@ -1,6 +1,12 @@
// RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s
module attributes {gpu.container_module} {
module attributes {
gpu.container_module,
spv.target_env = #spv.target_env<
#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
} {
func @load_store(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>) {
%c0 = constant 0 : index
%c12 = constant 12 : index

View File

@ -1,6 +1,12 @@
// RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s
module attributes {gpu.container_module} {
module attributes {
gpu.container_module,
spv.target_env = #spv.target_env<
#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
} {
func @loop(%arg0 : memref<10xf32>, %arg1 : memref<10xf32>) {
%c0 = constant 1 : index
"gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0, %arg0, %arg1) { kernel = "loop_kernel", kernel_module = @kernels} : (index, index, index, index, index, index, memref<10xf32>, memref<10xf32>) -> ()

View File

@ -1,5 +1,12 @@
// RUN: mlir-opt -convert-std-to-spirv %s -o - | FileCheck %s
module attributes {
spv.target_env = #spv.target_env<
#spv.vce<v1.0, [Shader, Int64, Float64], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
} {
//===----------------------------------------------------------------------===//
// std binary arithmetic ops
//===----------------------------------------------------------------------===//
@ -366,3 +373,5 @@ func @load_store_zero_rank_int(%arg0: memref<i32>, %arg1: memref<i32>) {
store %0, %arg1[] : memref<i32>
return
}
} // end module

View File

@ -4,6 +4,13 @@
// the desired output. Adding all of patterns within a single pass does
// not seem to work.
module attributes {
spv.target_env = #spv.target_env<
#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
} {
//===----------------------------------------------------------------------===//
// std.subview
//===----------------------------------------------------------------------===//
@ -51,3 +58,5 @@ func @fold_static_stride_subview_with_store(%arg0 : memref<12x32xf32>, %arg1 : i
store %arg5, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]>
return
}
} // end module