[mlir][spirv] Use functors for default memory space mappings

This makes it easier to use as a utility function to query the
mappings, including the reverse.

This commit also drops some storage classes that aren't needed
for now.

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D131411
This commit is contained in:
Lei Zhang 2022-08-09 14:32:22 -04:00
parent 89b595e141
commit 15135553c4
3 changed files with 46 additions and 36 deletions

View File

@ -22,9 +22,15 @@ class SPIRVTypeConverter;
namespace spirv {
/// Mapping from numeric MemRef memory spaces into SPIR-V symbolic ones.
using MemorySpaceToStorageClassMap = DenseMap<unsigned, spirv::StorageClass>;
/// Returns the default map for targeting Vulkan-flavored SPIR-V.
MemorySpaceToStorageClassMap getDefaultVulkanStorageClassMap();
using MemorySpaceToStorageClassMap =
std::function<Optional<spirv::StorageClass>(unsigned)>;
/// Maps MemRef memory spaces to storage classes for Vulkan-flavored SPIR-V
/// using the default rule. Returns None if the memory space is unknown.
Optional<spirv::StorageClass> mapMemorySpaceToVulkanStorageClass(unsigned);
/// Maps storage classes for Vulkan-flavored SPIR-V to MemRef memory spaces
/// using the default rule. Returns None if the storage class is unsupported.
Optional<unsigned> mapVulkanStorageClassToMemorySpace(spirv::StorageClass);
/// Type converter for converting numeric MemRef memory spaces into SPIR-V
/// symbolic ones.
@ -34,7 +40,7 @@ public:
const MemorySpaceToStorageClassMap &memorySpaceMap);
private:
const MemorySpaceToStorageClassMap &memorySpaceMap;
MemorySpaceToStorageClassMap memorySpaceMap;
};
/// Creates the target that populates legality of ops with MemRef types.

View File

@ -64,7 +64,7 @@ void GPUToSPIRVPass::runOnOperation() {
std::unique_ptr<ConversionTarget> target =
spirv::getMemorySpaceToStorageClassTarget(*context);
spirv::MemorySpaceToStorageClassMap memorySpaceMap =
spirv::getDefaultVulkanStorageClassMap();
spirv::mapMemorySpaceToVulkanStorageClass;
spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
RewritePatternSet patterns(context);

View File

@ -16,6 +16,7 @@
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/Transforms/DialectConversion.h"
@ -30,7 +31,6 @@ using namespace mlir;
// Mappings
//===----------------------------------------------------------------------===//
spirv::MemorySpaceToStorageClassMap spirv::getDefaultVulkanStorageClassMap() {
/// Mapping between SPIR-V storage classes to memref memory spaces.
///
/// Note: memref does not have a defined semantics for each memory space; it
@ -47,29 +47,42 @@ spirv::MemorySpaceToStorageClassMap spirv::getDefaultVulkanStorageClassMap() {
MAP_FN(spirv::StorageClass::PushConstant, 7) \
MAP_FN(spirv::StorageClass::UniformConstant, 8) \
MAP_FN(spirv::StorageClass::Input, 9) \
MAP_FN(spirv::StorageClass::Output, 10) \
MAP_FN(spirv::StorageClass::CrossWorkgroup, 11) \
MAP_FN(spirv::StorageClass::AtomicCounter, 12) \
MAP_FN(spirv::StorageClass::Image, 13) \
MAP_FN(spirv::StorageClass::CallableDataKHR, 14) \
MAP_FN(spirv::StorageClass::IncomingCallableDataKHR, 15) \
MAP_FN(spirv::StorageClass::RayPayloadKHR, 16) \
MAP_FN(spirv::StorageClass::HitAttributeKHR, 17) \
MAP_FN(spirv::StorageClass::IncomingRayPayloadKHR, 18) \
MAP_FN(spirv::StorageClass::ShaderRecordBufferKHR, 19) \
MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20) \
MAP_FN(spirv::StorageClass::CodeSectionINTEL, 21) \
MAP_FN(spirv::StorageClass::DeviceOnlyINTEL, 22) \
MAP_FN(spirv::StorageClass::HostOnlyINTEL, 23)
MAP_FN(spirv::StorageClass::Output, 10)
#define STORAGE_SPACE_MAP_FN(storage, space) {space, storage},
Optional<spirv::StorageClass>
spirv::mapMemorySpaceToVulkanStorageClass(unsigned memorySpace) {
#define STORAGE_SPACE_MAP_FN(storage, space) \
case space: \
return storage;
return {STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)};
switch (memorySpace) {
STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
default:
break;
}
return llvm::None;
#undef STORAGE_SPACE_MAP_FN
#undef STORAGE_SPACE_MAP_LIST
}
Optional<unsigned>
spirv::mapVulkanStorageClassToMemorySpace(spirv::StorageClass storageClass) {
#define STORAGE_SPACE_MAP_FN(storage, space) \
case storage: \
return space;
switch (storageClass) {
STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
default:
break;
}
return llvm::None;
#undef STORAGE_SPACE_MAP_FN
}
#undef STORAGE_SPACE_MAP_LIST
//===----------------------------------------------------------------------===//
// Type Converter
//===----------------------------------------------------------------------===//
@ -91,8 +104,8 @@ spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
}
unsigned space = memRefType.getMemorySpaceAsInt();
auto it = this->memorySpaceMap.find(space);
if (it == this->memorySpaceMap.end()) {
auto storage = this->memorySpaceMap(space);
if (!storage) {
LLVM_DEBUG(llvm::dbgs()
<< "cannot convert " << memRefType
<< " due to being unable to find memory space in map\n");
@ -100,7 +113,7 @@ spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
}
auto storageAttr =
spirv::StorageClassAttr::get(memRefType.getContext(), it->second);
spirv::StorageClassAttr::get(memRefType.getContext(), *storage);
if (auto rankedType = memRefType.dyn_cast<MemRefType>()) {
return MemRefType::get(memRefType.getShape(), memRefType.getElementType(),
rankedType.getLayout(), storageAttr);
@ -231,16 +244,7 @@ class MapMemRefStorageClassPass final
: public MapMemRefStorageClassBase<MapMemRefStorageClassPass> {
public:
explicit MapMemRefStorageClassPass() {
memorySpaceMap = spirv::getDefaultVulkanStorageClassMap();
LLVM_DEBUG({
llvm::dbgs() << "memory space to storage class mapping:\n";
if (memorySpaceMap.empty())
llvm::dbgs() << " [empty]\n";
for (auto kv : memorySpaceMap)
llvm::dbgs() << " " << kv.first << " -> "
<< spirv::stringifyStorageClass(kv.second) << "\n";
});
memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
}
explicit MapMemRefStorageClassPass(
const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)