forked from OSchip/llvm-project
[mlir][spirv] Lower memref with dynamic dimensions to runtime arrays
memref types with dynamic dimensions do not have a compile-time known size. They should be mapped to SPIR-V runtime array types. Differential Revision: https://reviews.llvm.org/D78197
This commit is contained in:
parent
a54e18df0a
commit
ba49096817
|
@ -331,10 +331,11 @@ static Optional<Type> convertTensorType(const spirv::TargetEnv &targetEnv,
|
|||
|
||||
static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
|
||||
MemRefType type) {
|
||||
// TODO(ravishankarm) : Handle dynamic shapes.
|
||||
if (!type.hasStaticShape()) {
|
||||
Optional<spirv::StorageClass> storageClass =
|
||||
SPIRVTypeConverter::getStorageClassForMemorySpace(type.getMemorySpace());
|
||||
if (!storageClass) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< type << " illegal: dynamic shape unimplemented\n");
|
||||
<< type << " illegal: cannot convert memory space\n");
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
|
@ -345,9 +346,26 @@ static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
|
|||
return llvm::None;
|
||||
}
|
||||
|
||||
auto arrayElemType = convertScalarType(targetEnv, scalarType, storageClass);
|
||||
if (!arrayElemType)
|
||||
return llvm::None;
|
||||
|
||||
Optional<int64_t> scalarSize = getTypeNumBytes(scalarType);
|
||||
if (!scalarSize) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< type << " illegal: cannot deduce element size\n");
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
if (!type.hasStaticShape()) {
|
||||
auto arrayType = spirv::RuntimeArrayType::get(*arrayElemType, *scalarSize);
|
||||
// Wrap in a struct to satisfy Vulkan interface requirements.
|
||||
auto structType = spirv::StructType::get(arrayType, 0);
|
||||
return spirv::PointerType::get(structType, *storageClass);
|
||||
}
|
||||
|
||||
Optional<int64_t> memrefSize = getTypeNumBytes(type);
|
||||
if (!scalarSize || !memrefSize) {
|
||||
if (!memrefSize) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< type << " illegal: cannot deduce element count\n");
|
||||
return llvm::None;
|
||||
|
@ -355,17 +373,6 @@ static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
|
|||
|
||||
auto arrayElemCount = *memrefSize / *scalarSize;
|
||||
|
||||
auto storageClass =
|
||||
SPIRVTypeConverter::getStorageClassForMemorySpace(type.getMemorySpace());
|
||||
if (!storageClass) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< type << " illegal: cannot convert memory space\n");
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
auto arrayElemType = convertScalarType(targetEnv, scalarType, storageClass);
|
||||
if (!arrayElemType)
|
||||
return llvm::None;
|
||||
Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
|
||||
if (!arrayElemSize) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
|
|
|
@ -486,7 +486,7 @@ func @memref_offset_strides(
|
|||
|
||||
// -----
|
||||
|
||||
// Check that dynamic shapes are not supported.
|
||||
// Dynamic shapes
|
||||
module attributes {
|
||||
spv.target_env = #spv.target_env<
|
||||
#spv.vce<v1.0, [], []>,
|
||||
|
@ -494,13 +494,17 @@ module attributes {
|
|||
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
|
||||
} {
|
||||
|
||||
// Check that unranked shapes are not supported.
|
||||
// CHECK-LABEL: func @unranked_memref
|
||||
// CHECK-SAME: memref<*xi32>
|
||||
func @unranked_memref(%arg0: memref<*xi32>) { return }
|
||||
|
||||
// CHECK-LABEL: func @dynamic_dim_memref
|
||||
// CHECK-SAME: memref<8x?xi32>
|
||||
func @dynamic_dim_memref(%arg0: memref<8x?xi32>) { return }
|
||||
// CHECK-SAME: !spv.ptr<!spv.struct<!spv.rtarray<i32, stride=4> [0]>, StorageBuffer>
|
||||
// CHECK-SAME: !spv.ptr<!spv.struct<!spv.rtarray<f32, stride=4> [0]>, StorageBuffer>
|
||||
func @dynamic_dim_memref(%arg0: memref<8x?xi32>,
|
||||
%arg1: memref<?x?xf32>)
|
||||
{ return }
|
||||
|
||||
} // end module
|
||||
|
||||
|
|
Loading…
Reference in New Issue