From 67e8690e53c341ba433f9c2de3f5a16b8beb7f0b Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Wed, 18 Mar 2020 09:55:27 -0400 Subject: [PATCH] [mlir][spirv] Let SPIRVConversionTarget consider type availability Previously we only consider the version/extension/capability requirement on the op itself. This commit updates SPIRVConversionTarget to also take into consideration the values' types when deciding op legality. Differential Revision: https://reviews.llvm.org/D75876 --- mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp | 134 +++++++++++++----- mlir/test/Conversion/GPUToSPIRV/if.mlir | 8 +- .../Conversion/GPUToSPIRV/load-store.mlir | 8 +- mlir/test/Conversion/GPUToSPIRV/loop.mlir | 8 +- .../StandardToSPIRV/std-to-spirv.mlir | 9 ++ .../StandardToSPIRV/subview-to-spirv.mlir | 9 ++ 6 files changed, 141 insertions(+), 35 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp index e9250c56a1d2..6d73432fead4 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -1,4 +1,4 @@ -//===- SPIRVLowering.cpp - Standard to SPIR-V dialect conversion--===// +//===- SPIRVLowering.cpp - SPIR-V lowering utilities ----------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -15,6 +15,7 @@ #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "llvm/ADT/Sequence.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/Support/Debug.h" #include @@ -443,6 +444,66 @@ spirv::SPIRVConversionTarget::SPIRVConversionTarget( } } +/// Checks that `candidates` extension requirements are possible to be satisfied +/// with the given `allowedExtensions`. +/// +/// `candidates` is a vector of vector for extension requirements following +/// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D)) +/// convention. 
+static LogicalResult checkExtensionRequirements(
+    Operation *op,
+    const llvm::SmallSet<spirv::Extension, 4> &allowedExtensions,
+    const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
+  for (const auto &ors : candidates) {
+    auto chosen = llvm::find_if(ors, [&](spirv::Extension ext) {
+      return allowedExtensions.count(ext);
+    });
+
+    if (chosen == ors.end()) {
+      SmallVector<StringRef, 4> extStrings;
+      for (spirv::Extension ext : ors)
+        extStrings.push_back(spirv::stringifyExtension(ext));
+
+      LLVM_DEBUG(llvm::dbgs()
+                 << op->getName()
+                 << " illegal: requires at least one extension in ["
+                 << llvm::join(extStrings, ", ")
+                 << "] but none allowed in target environment\n");
+      return failure();
+    }
+  }
+  return success();
+}
+
+/// Checks that `candidates` capability requirements are possible to be
+/// satisfied with the given `allowedCapabilities`.
+///
+/// `candidates` is a vector of vector for capability requirements following
+/// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
+/// convention.
+static LogicalResult checkCapabilityRequirements(
+    Operation *op,
+    const llvm::SmallSet<spirv::Capability, 8> &allowedCapabilities,
+    const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
+  for (const auto &ors : candidates) {
+    auto chosen = llvm::find_if(ors, [&](spirv::Capability cap) {
+      return allowedCapabilities.count(cap);
+    });
+
+    if (chosen == ors.end()) {
+      SmallVector<StringRef, 8> capStrings;
+      for (spirv::Capability cap : ors)
+        capStrings.push_back(spirv::stringifyCapability(cap));
+
+      LLVM_DEBUG(llvm::dbgs()
+                 << op->getName()
+                 << " illegal: requires at least one capability in ["
+                 << llvm::join(capStrings, ", ")
+                 << "] but none allowed in target environment\n");
+      return failure();
+    }
+  }
+  return success();
+}
+
 bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
   // Make sure this op is available at the given version. 
Ops not implementing
   // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
@@ -464,38 +525,47 @@ bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
       return false;
     }
 
-  // Make sure this op's required extensions are allowed to use. For each op,
-  // we return a vector of vector for its extension requirements following
-  // ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
-  // convention. Ops not implementing QueryExtensionInterface do not require
-  // extensions to be available.
-  if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op)) {
-    auto exts = extensions.getExtensions();
-    for (const auto &ors : exts)
-      if (llvm::all_of(ors, [this](spirv::Extension ext) {
-            return this->givenExtensions.count(ext) == 0;
-          })) {
-        LLVM_DEBUG(llvm::dbgs() << op->getName()
-                                << " illegal: missing required extension\n");
-        return false;
-      }
-  }
+  // Make sure this op's required extensions are allowed to use. Ops not
+  // implementing QueryExtensionInterface do not require extensions to be
+  // available.
+  if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
+    if (failed(checkExtensionRequirements(op, this->givenExtensions,
+                                          extensions.getExtensions())))
+      return false;
 
-  // Make sure this op's required extensions are allowed to use. For each op,
-  // we return a vector of vector for its capability requirements following
-  // ((Capability::A OR Extension::B) AND (Capability::C OR Capability::D))
-  // convention. Ops not implementing QueryExtensionInterface do not require
-  // extensions to be available.
-  if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op)) {
-    auto caps = capabilities.getCapabilities();
-    for (const auto &ors : caps)
-      if (llvm::all_of(ors, [this](spirv::Capability cap) {
-            return this->givenCapabilities.count(cap) == 0;
-          })) {
-        LLVM_DEBUG(llvm::dbgs() << op->getName()
-                                << " illegal: missing required capability\n");
-        return false;
-      }
+  // Make sure this op's required capabilities are allowed to use. 
Ops not + // implementing QueryCapabilityInterface do not require capabilities to be + // available. + if (auto capabilities = dyn_cast(op)) + if (failed(checkCapabilityRequirements(op, this->givenCapabilities, + capabilities.getCapabilities()))) + return false; + + SmallVector valueTypes; + valueTypes.append(op->operand_type_begin(), op->operand_type_end()); + valueTypes.append(op->result_type_begin(), op->result_type_end()); + + // Special treatment for global variables, whose type requirements are + // conveyed by type attributes. + if (auto globalVar = dyn_cast(op)) + valueTypes.push_back(globalVar.type()); + + // Make sure the op's operands/results use types that are allowed by the + // target environment. + SmallVector, 4> typeExtensions; + SmallVector, 8> typeCapabilities; + for (Type valueType : valueTypes) { + typeExtensions.clear(); + valueType.cast().getExtensions(typeExtensions); + if (failed(checkExtensionRequirements(op, this->givenExtensions, + typeExtensions))) + return false; + + typeCapabilities.clear(); + valueType.cast().getCapabilities(typeCapabilities); + if (failed(checkCapabilityRequirements(op, this->givenCapabilities, + typeCapabilities))) + return false; } return true; diff --git a/mlir/test/Conversion/GPUToSPIRV/if.mlir b/mlir/test/Conversion/GPUToSPIRV/if.mlir index 1585c53116c5..8a8aa1c88813 100644 --- a/mlir/test/Conversion/GPUToSPIRV/if.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/if.mlir @@ -1,6 +1,12 @@ // RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s -module attributes {gpu.container_module} { +module attributes { + gpu.container_module, + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { func @main(%arg0 : memref<10xf32>, %arg1 : i1) { %c0 = constant 1 : index "gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0, %arg0, %arg1) { kernel = "kernel_simple_selection", kernel_module = @kernels} : (index, 
index, index, index, index, index, memref<10xf32>, i1) -> () diff --git a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir index d0224fd16e02..05c9d90c498c 100644 --- a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir @@ -1,6 +1,12 @@ // RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s -module attributes {gpu.container_module} { +module attributes { + gpu.container_module, + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { func @load_store(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>) { %c0 = constant 0 : index %c12 = constant 12 : index diff --git a/mlir/test/Conversion/GPUToSPIRV/loop.mlir b/mlir/test/Conversion/GPUToSPIRV/loop.mlir index 7044d5474d3c..8adc5e355f08 100644 --- a/mlir/test/Conversion/GPUToSPIRV/loop.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/loop.mlir @@ -1,6 +1,12 @@ // RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s -module attributes {gpu.container_module} { +module attributes { + gpu.container_module, + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { func @loop(%arg0 : memref<10xf32>, %arg1 : memref<10xf32>) { %c0 = constant 1 : index "gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0, %arg0, %arg1) { kernel = "loop_kernel", kernel_module = @kernels} : (index, index, index, index, index, index, memref<10xf32>, memref<10xf32>) -> () diff --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir index 9b8d695af422..26e2ea42d3a2 100644 --- a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir @@ -1,5 +1,12 @@ // RUN: mlir-opt 
-convert-std-to-spirv %s -o - | FileCheck %s +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + //===----------------------------------------------------------------------===// // std binary arithmetic ops //===----------------------------------------------------------------------===// @@ -366,3 +373,5 @@ func @load_store_zero_rank_int(%arg0: memref, %arg1: memref) { store %0, %arg1[] : memref return } + +} // end module diff --git a/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir index c9d1195bc056..cc94c089dfb2 100644 --- a/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir @@ -4,6 +4,13 @@ // the desired output. Adding all of patterns within a single pass does // not seem to work. +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + //===----------------------------------------------------------------------===// // std.subview //===----------------------------------------------------------------------===// @@ -51,3 +58,5 @@ func @fold_static_stride_subview_with_store(%arg0 : memref<12x32xf32>, %arg1 : i store %arg5, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]> return } + +} // end module