From 15135553c4cf34d3915e45b55e915154b33ab67b Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Tue, 9 Aug 2022 14:32:22 -0400 Subject: [PATCH] [mlir][spirv] Use functors for default memory space mappings This makes it easier to use as a utility function to query the mappings, including the reverse. This commit also drops some storage classes that aren't needed for now. Reviewed By: kuhar Differential Revision: https://reviews.llvm.org/D131411 --- .../Conversion/MemRefToSPIRV/MemRefToSPIRV.h | 14 ++-- .../Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp | 2 +- .../MapMemRefStorageClassPass.cpp | 66 ++++++++++--------- 3 files changed, 46 insertions(+), 36 deletions(-) diff --git a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h index 60730246a7c6..a8bb28bdd1aa 100644 --- a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h +++ b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h @@ -22,9 +22,15 @@ class SPIRVTypeConverter; namespace spirv { /// Mapping from numeric MemRef memory spaces into SPIR-V symbolic ones. -using MemorySpaceToStorageClassMap = DenseMap; -/// Returns the default map for targeting Vulkan-flavored SPIR-V. -MemorySpaceToStorageClassMap getDefaultVulkanStorageClassMap(); +using MemorySpaceToStorageClassMap = + std::function(unsigned)>; + +/// Maps MemRef memory spaces to storage classes for Vulkan-flavored SPIR-V +/// using the default rule. Returns None if the memory space is unknown. +Optional mapMemorySpaceToVulkanStorageClass(unsigned); +/// Maps storage classes for Vulkan-flavored SPIR-V to MemRef memory spaces +/// using the default rule. Returns None if the storage class is unsupported. +Optional mapVulkanStorageClassToMemorySpace(spirv::StorageClass); /// Type converter for converting numeric MemRef memory spaces into SPIR-V /// symbolic ones. @@ -34,7 +40,7 @@ public: const MemorySpaceToStorageClassMap &memorySpaceMap); private: - const MemorySpaceToStorageClassMap &memorySpaceMap; + MemorySpaceToStorageClassMap memorySpaceMap; }; /// Creates the target that populates legality of ops with MemRef types. diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp index 6d20e989a71a..fb5c89244b54 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp @@ -64,7 +64,7 @@ void GPUToSPIRVPass::runOnOperation() { std::unique_ptr target = spirv::getMemorySpaceToStorageClassTarget(*context); spirv::MemorySpaceToStorageClassMap memorySpaceMap = - spirv::getDefaultVulkanStorageClassMap(); + spirv::mapMemorySpaceToVulkanStorageClass; spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap); RewritePatternSet patterns(context); diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp index 535613714d53..e11e4ef085f7 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp @@ -16,6 +16,7 @@ #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h" #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/FunctionInterfaces.h" #include "mlir/Transforms/DialectConversion.h" @@ -30,7 +31,6 @@ using namespace mlir; // Mappings //===----------------------------------------------------------------------===// -spirv::MemorySpaceToStorageClassMap spirv::getDefaultVulkanStorageClassMap() { /// Mapping between SPIR-V storage classes to memref memory spaces. /// /// Note: memref does not have a defined semantics for each memory space; it @@ -47,29 +47,42 @@ spirv::MemorySpaceToStorageClassMap spirv::getDefaultVulkanStorageClassMap() { MAP_FN(spirv::StorageClass::PushConstant, 7) \ MAP_FN(spirv::StorageClass::UniformConstant, 8) \ MAP_FN(spirv::StorageClass::Input, 9) \ - MAP_FN(spirv::StorageClass::Output, 10) \ - MAP_FN(spirv::StorageClass::CrossWorkgroup, 11) \ - MAP_FN(spirv::StorageClass::AtomicCounter, 12) \ - MAP_FN(spirv::StorageClass::Image, 13) \ - MAP_FN(spirv::StorageClass::CallableDataKHR, 14) \ - MAP_FN(spirv::StorageClass::IncomingCallableDataKHR, 15) \ - MAP_FN(spirv::StorageClass::RayPayloadKHR, 16) \ - MAP_FN(spirv::StorageClass::HitAttributeKHR, 17) \ - MAP_FN(spirv::StorageClass::IncomingRayPayloadKHR, 18) \ - MAP_FN(spirv::StorageClass::ShaderRecordBufferKHR, 19) \ - MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20) \ - MAP_FN(spirv::StorageClass::CodeSectionINTEL, 21) \ - MAP_FN(spirv::StorageClass::DeviceOnlyINTEL, 22) \ - MAP_FN(spirv::StorageClass::HostOnlyINTEL, 23) + MAP_FN(spirv::StorageClass::Output, 10) -#define STORAGE_SPACE_MAP_FN(storage, space) {space, storage}, +Optional +spirv::mapMemorySpaceToVulkanStorageClass(unsigned memorySpace) { +#define STORAGE_SPACE_MAP_FN(storage, space) \ + case space: \ + return storage; - return {STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)}; + switch (memorySpace) { + STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) + default: + break; + } + return llvm::None; #undef STORAGE_SPACE_MAP_FN -#undef STORAGE_SPACE_MAP_LIST } +Optional +spirv::mapVulkanStorageClassToMemorySpace(spirv::StorageClass storageClass) { +#define STORAGE_SPACE_MAP_FN(storage, space) \ + case storage: \ + return space; + + switch (storageClass) { + STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) + default: + break; + } + return llvm::None; + +#undef STORAGE_SPACE_MAP_FN +} + +#undef STORAGE_SPACE_MAP_LIST + //===----------------------------------------------------------------------===// // Type Converter //===----------------------------------------------------------------------===// @@ -91,8 +104,8 @@ spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter( } unsigned space = memRefType.getMemorySpaceAsInt(); - auto it = this->memorySpaceMap.find(space); - if (it == this->memorySpaceMap.end()) { + auto storage = this->memorySpaceMap(space); + if (!storage) { LLVM_DEBUG(llvm::dbgs() << "cannot convert " << memRefType << " due to being unable to find memory space in map\n"); @@ -100,7 +113,7 @@ spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter( } auto storageAttr = - spirv::StorageClassAttr::get(memRefType.getContext(), it->second); + spirv::StorageClassAttr::get(memRefType.getContext(), *storage); if (auto rankedType = memRefType.dyn_cast()) { return MemRefType::get(memRefType.getShape(), memRefType.getElementType(), rankedType.getLayout(), storageAttr); @@ -231,16 +244,7 @@ class MapMemRefStorageClassPass final : public MapMemRefStorageClassBase { public: explicit MapMemRefStorageClassPass() { - memorySpaceMap = spirv::getDefaultVulkanStorageClassMap(); - - LLVM_DEBUG({ - llvm::dbgs() << "memory space to storage class mapping:\n"; - if (memorySpaceMap.empty()) - llvm::dbgs() << " [empty]\n"; - for (auto kv : memorySpaceMap) - llvm::dbgs() << " " << kv.first << " -> " - << spirv::stringifyStorageClass(kv.second) << "\n"; - }); + memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass; } explicit MapMemRefStorageClassPass( const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)