forked from OSchip/llvm-project
[mlir][spirv] Use functors for default memory space mappings
This makes it easier to use as a utility function to query the mappings, including the reverse. This commit also drops some storage classes that aren't needed for now. Reviewed By: kuhar Differential Revision: https://reviews.llvm.org/D131411
This commit is contained in:
parent
89b595e141
commit
15135553c4
|
@ -22,9 +22,15 @@ class SPIRVTypeConverter;
|
|||
|
||||
namespace spirv {
|
||||
/// Mapping from numeric MemRef memory spaces into SPIR-V symbolic ones.
|
||||
using MemorySpaceToStorageClassMap = DenseMap<unsigned, spirv::StorageClass>;
|
||||
/// Returns the default map for targeting Vulkan-flavored SPIR-V.
|
||||
MemorySpaceToStorageClassMap getDefaultVulkanStorageClassMap();
|
||||
using MemorySpaceToStorageClassMap =
|
||||
std::function<Optional<spirv::StorageClass>(unsigned)>;
|
||||
|
||||
/// Maps MemRef memory spaces to storage classes for Vulkan-flavored SPIR-V
|
||||
/// using the default rule. Returns None if the memory space is unknown.
|
||||
Optional<spirv::StorageClass> mapMemorySpaceToVulkanStorageClass(unsigned);
|
||||
/// Maps storage classes for Vulkan-flavored SPIR-V to MemRef memory spaces
|
||||
/// using the default rule. Returns None if the storage class is unsupported.
|
||||
Optional<unsigned> mapVulkanStorageClassToMemorySpace(spirv::StorageClass);
|
||||
|
||||
/// Type converter for converting numeric MemRef memory spaces into SPIR-V
|
||||
/// symbolic ones.
|
||||
|
@ -34,7 +40,7 @@ public:
|
|||
const MemorySpaceToStorageClassMap &memorySpaceMap);
|
||||
|
||||
private:
|
||||
const MemorySpaceToStorageClassMap &memorySpaceMap;
|
||||
MemorySpaceToStorageClassMap memorySpaceMap;
|
||||
};
|
||||
|
||||
/// Creates the target that populates legality of ops with MemRef types.
|
||||
|
|
|
@ -64,7 +64,7 @@ void GPUToSPIRVPass::runOnOperation() {
|
|||
std::unique_ptr<ConversionTarget> target =
|
||||
spirv::getMemorySpaceToStorageClassTarget(*context);
|
||||
spirv::MemorySpaceToStorageClassMap memorySpaceMap =
|
||||
spirv::getDefaultVulkanStorageClassMap();
|
||||
spirv::mapMemorySpaceToVulkanStorageClass;
|
||||
spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/FunctionInterfaces.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
@ -30,7 +31,6 @@ using namespace mlir;
|
|||
// Mappings
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
spirv::MemorySpaceToStorageClassMap spirv::getDefaultVulkanStorageClassMap() {
|
||||
/// Mapping between SPIR-V storage classes to memref memory spaces.
|
||||
///
|
||||
/// Note: memref does not have a defined semantics for each memory space; it
|
||||
|
@ -47,29 +47,42 @@ spirv::MemorySpaceToStorageClassMap spirv::getDefaultVulkanStorageClassMap() {
|
|||
MAP_FN(spirv::StorageClass::PushConstant, 7) \
|
||||
MAP_FN(spirv::StorageClass::UniformConstant, 8) \
|
||||
MAP_FN(spirv::StorageClass::Input, 9) \
|
||||
MAP_FN(spirv::StorageClass::Output, 10) \
|
||||
MAP_FN(spirv::StorageClass::CrossWorkgroup, 11) \
|
||||
MAP_FN(spirv::StorageClass::AtomicCounter, 12) \
|
||||
MAP_FN(spirv::StorageClass::Image, 13) \
|
||||
MAP_FN(spirv::StorageClass::CallableDataKHR, 14) \
|
||||
MAP_FN(spirv::StorageClass::IncomingCallableDataKHR, 15) \
|
||||
MAP_FN(spirv::StorageClass::RayPayloadKHR, 16) \
|
||||
MAP_FN(spirv::StorageClass::HitAttributeKHR, 17) \
|
||||
MAP_FN(spirv::StorageClass::IncomingRayPayloadKHR, 18) \
|
||||
MAP_FN(spirv::StorageClass::ShaderRecordBufferKHR, 19) \
|
||||
MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20) \
|
||||
MAP_FN(spirv::StorageClass::CodeSectionINTEL, 21) \
|
||||
MAP_FN(spirv::StorageClass::DeviceOnlyINTEL, 22) \
|
||||
MAP_FN(spirv::StorageClass::HostOnlyINTEL, 23)
|
||||
MAP_FN(spirv::StorageClass::Output, 10)
|
||||
|
||||
#define STORAGE_SPACE_MAP_FN(storage, space) {space, storage},
|
||||
Optional<spirv::StorageClass>
|
||||
spirv::mapMemorySpaceToVulkanStorageClass(unsigned memorySpace) {
|
||||
#define STORAGE_SPACE_MAP_FN(storage, space) \
|
||||
case space: \
|
||||
return storage;
|
||||
|
||||
return {STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)};
|
||||
switch (memorySpace) {
|
||||
STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return llvm::None;
|
||||
|
||||
#undef STORAGE_SPACE_MAP_FN
|
||||
#undef STORAGE_SPACE_MAP_LIST
|
||||
}
|
||||
|
||||
Optional<unsigned>
|
||||
spirv::mapVulkanStorageClassToMemorySpace(spirv::StorageClass storageClass) {
|
||||
#define STORAGE_SPACE_MAP_FN(storage, space) \
|
||||
case storage: \
|
||||
return space;
|
||||
|
||||
switch (storageClass) {
|
||||
STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return llvm::None;
|
||||
|
||||
#undef STORAGE_SPACE_MAP_FN
|
||||
}
|
||||
|
||||
#undef STORAGE_SPACE_MAP_LIST
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type Converter
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -91,8 +104,8 @@ spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
|
|||
}
|
||||
|
||||
unsigned space = memRefType.getMemorySpaceAsInt();
|
||||
auto it = this->memorySpaceMap.find(space);
|
||||
if (it == this->memorySpaceMap.end()) {
|
||||
auto storage = this->memorySpaceMap(space);
|
||||
if (!storage) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "cannot convert " << memRefType
|
||||
<< " due to being unable to find memory space in map\n");
|
||||
|
@ -100,7 +113,7 @@ spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
|
|||
}
|
||||
|
||||
auto storageAttr =
|
||||
spirv::StorageClassAttr::get(memRefType.getContext(), it->second);
|
||||
spirv::StorageClassAttr::get(memRefType.getContext(), *storage);
|
||||
if (auto rankedType = memRefType.dyn_cast<MemRefType>()) {
|
||||
return MemRefType::get(memRefType.getShape(), memRefType.getElementType(),
|
||||
rankedType.getLayout(), storageAttr);
|
||||
|
@ -231,16 +244,7 @@ class MapMemRefStorageClassPass final
|
|||
: public MapMemRefStorageClassBase<MapMemRefStorageClassPass> {
|
||||
public:
|
||||
explicit MapMemRefStorageClassPass() {
|
||||
memorySpaceMap = spirv::getDefaultVulkanStorageClassMap();
|
||||
|
||||
LLVM_DEBUG({
|
||||
llvm::dbgs() << "memory space to storage class mapping:\n";
|
||||
if (memorySpaceMap.empty())
|
||||
llvm::dbgs() << " [empty]\n";
|
||||
for (auto kv : memorySpaceMap)
|
||||
llvm::dbgs() << " " << kv.first << " -> "
|
||||
<< spirv::stringifyStorageClass(kv.second) << "\n";
|
||||
});
|
||||
memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
|
||||
}
|
||||
explicit MapMemRefStorageClassPass(
|
||||
const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
|
||||
|
|
Loading…
Reference in New Issue