[mlir][spirv] Lower memref with dynamic dimensions to runtime arrays

memref types with dynamic dimensions do not have a compile-time
known size. They should be mapped to SPIR-V runtime array types.

Differential Revision: https://reviews.llvm.org/D78197
This commit is contained in:
Lei Zhang 2020-04-15 08:42:28 -04:00
parent a54e18df0a
commit ba49096817
3 changed files with 29 additions and 18 deletions

View File

@ -331,10 +331,11 @@ static Optional<Type> convertTensorType(const spirv::TargetEnv &targetEnv,
static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
MemRefType type) {
// TODO(ravishankarm) : Handle dynamic shapes.
if (!type.hasStaticShape()) {
Optional<spirv::StorageClass> storageClass =
SPIRVTypeConverter::getStorageClassForMemorySpace(type.getMemorySpace());
if (!storageClass) {
LLVM_DEBUG(llvm::dbgs()
<< type << " illegal: dynamic shape unimplemented\n");
<< type << " illegal: cannot convert memory space\n");
return llvm::None;
}
@ -345,9 +346,26 @@ static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
return llvm::None;
}
auto arrayElemType = convertScalarType(targetEnv, scalarType, storageClass);
if (!arrayElemType)
return llvm::None;
Optional<int64_t> scalarSize = getTypeNumBytes(scalarType);
if (!scalarSize) {
LLVM_DEBUG(llvm::dbgs()
<< type << " illegal: cannot deduce element size\n");
return llvm::None;
}
if (!type.hasStaticShape()) {
auto arrayType = spirv::RuntimeArrayType::get(*arrayElemType, *scalarSize);
// Wrap in a struct to satisfy Vulkan interface requirements.
auto structType = spirv::StructType::get(arrayType, 0);
return spirv::PointerType::get(structType, *storageClass);
}
Optional<int64_t> memrefSize = getTypeNumBytes(type);
if (!scalarSize || !memrefSize) {
if (!memrefSize) {
LLVM_DEBUG(llvm::dbgs()
<< type << " illegal: cannot deduce element count\n");
return llvm::None;
@ -355,17 +373,6 @@ static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
auto arrayElemCount = *memrefSize / *scalarSize;
auto storageClass =
SPIRVTypeConverter::getStorageClassForMemorySpace(type.getMemorySpace());
if (!storageClass) {
LLVM_DEBUG(llvm::dbgs()
<< type << " illegal: cannot convert memory space\n");
return llvm::None;
}
auto arrayElemType = convertScalarType(targetEnv, scalarType, storageClass);
if (!arrayElemType)
return llvm::None;
Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
if (!arrayElemSize) {
LLVM_DEBUG(llvm::dbgs()

View File

@ -486,7 +486,7 @@ func @memref_offset_strides(
// -----
// Check that dynamic shapes are not supported.
// Dynamic shapes
module attributes {
spv.target_env = #spv.target_env<
#spv.vce<v1.0, [], []>,
@ -494,13 +494,17 @@ module attributes {
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
} {
// Check that unranked shapes are not supported.
// CHECK-LABEL: func @unranked_memref
// CHECK-SAME: memref<*xi32>
func @unranked_memref(%arg0: memref<*xi32>) { return }
// CHECK-LABEL: func @dynamic_dim_memref
// CHECK-SAME: memref<8x?xi32>
func @dynamic_dim_memref(%arg0: memref<8x?xi32>) { return }
// CHECK-SAME: !spv.ptr<!spv.struct<!spv.rtarray<i32, stride=4> [0]>, StorageBuffer>
// CHECK-SAME: !spv.ptr<!spv.struct<!spv.rtarray<f32, stride=4> [0]>, StorageBuffer>
func @dynamic_dim_memref(%arg0: memref<8x?xi32>,
%arg1: memref<?x?xf32>)
{ return }
} // end module