[mlir][spirv] Only attach struct offset for required storage classes

Per the SPIR-V spec "2.16.2. Validation Rules for Shader Capabilities":

  Composite objects in the StorageBuffer, PhysicalStorageBuffer,
  Uniform, and PushConstant Storage Classes must be explicitly
  laid out.

For other cases we don't need to attach the struct offsets.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D100386
This commit is contained in:
Lei Zhang 2021-04-13 15:18:32 -04:00
parent 6ddd8c28b7
commit 5b15fe9334
3 changed files with 36 additions and 24 deletions

View File

@ -84,6 +84,30 @@ static LogicalResult checkCapabilityRequirements(
return success();
}
/// Returns true if the given `storageClass` needs explicit layout when used in
/// Shader environments.
static bool needsExplicitLayout(spirv::StorageClass storageClass) {
switch (storageClass) {
case spirv::StorageClass::PhysicalStorageBuffer:
case spirv::StorageClass::PushConstant:
case spirv::StorageClass::StorageBuffer:
case spirv::StorageClass::Uniform:
return true;
default:
return false;
}
}
/// Wraps the given `elementType` in a struct and gets the pointer to the
/// struct. This is used to satisfy Vulkan interface requirements.
static spirv::PointerType
wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
auto structType = needsExplicitLayout(storageClass)
? spirv::StructType::get(elementType, /*offsetInfo=*/0)
: spirv::StructType::get(elementType);
return spirv::PointerType::get(structType, storageClass);
}
//===----------------------------------------------------------------------===//
// Type Conversion
//===----------------------------------------------------------------------===//
@ -392,12 +416,7 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
auto arrayType =
spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
// Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with
// workgroup storage class do not need the struct to be laid out explicitly.
auto structType = *storageClass == spirv::StorageClass::Workgroup
? spirv::StructType::get(arrayType)
: spirv::StructType::get(arrayType, 0);
return spirv::PointerType::get(structType, *storageClass);
return wrapInStructAndGetPointer(arrayType, *storageClass);
}
static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
@ -452,9 +471,7 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
if (!type.hasStaticShape()) {
auto arrayType =
spirv::RuntimeArrayType::get(arrayElemType, *arrayElemSize);
// Wrap in a struct to satisfy Vulkan interface requirements.
auto structType = spirv::StructType::get(arrayType, 0);
return spirv::PointerType::get(structType, *storageClass);
return wrapInStructAndGetPointer(arrayType, *storageClass);
}
Optional<int64_t> memrefSize = getTypeNumBytes(options, type);
@ -470,12 +487,7 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
auto arrayType =
spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
// Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with
// workgroup storage class do not need the struct to be laid out explicitly.
auto structType = *storageClass == spirv::StorageClass::Workgroup
? spirv::StructType::get(arrayType)
: spirv::StructType::get(arrayType, 0);
return spirv::PointerType::get(structType, *storageClass);
return wrapInStructAndGetPointer(arrayType, *storageClass);
}
SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,

View File

@ -9,7 +9,7 @@ module attributes {
// CHECK: spv.func
// CHECK-SAME: {{%.*}}: f32
// CHECK-NOT: spv.interface_var_abi
// CHECK-SAME: {{%.*}}: !spv.ptr<!spv.struct<(!spv.array<12 x f32, stride=4> [0])>, CrossWorkgroup>
// CHECK-SAME: {{%.*}}: !spv.ptr<!spv.struct<(!spv.array<12 x f32, stride=4>)>, CrossWorkgroup>
// CHECK-NOT: spv.interface_var_abi
// CHECK-SAME: spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}
gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32, 11>) kernel

View File

@ -337,13 +337,13 @@ func @memref_16bit_Uniform(%arg0: memref<16xsi16, 4>) { return }
func @memref_16bit_PushConstant(%arg0: memref<16xui16, 7>) { return }
// CHECK-LABEL: spv.func @memref_16bit_Input
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, Input>
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4>)>, Input>
// NOEMU-LABEL: func @memref_16bit_Input
// NOEMU-SAME: memref<16xf16, 9>
func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return }
// CHECK-LABEL: spv.func @memref_16bit_Output
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, Output>
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4>)>, Output>
// NOEMU-LABEL: func @memref_16bit_Output
// NOEMU-SAME: memref<16xf16, 10>
func @memref_16bit_Output(%arg4: memref<16xf16, 10>) { return }
@ -451,15 +451,15 @@ module attributes {
} {
// CHECK-LABEL: spv.func @memref_16bit_Input
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2> [0])>, Input>
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2>)>, Input>
// NOEMU-LABEL: spv.func @memref_16bit_Input
// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2> [0])>, Input>
// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2>)>, Input>
func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return }
// CHECK-LABEL: spv.func @memref_16bit_Output
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2> [0])>, Output>
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2>)>, Output>
// NOEMU-LABEL: spv.func @memref_16bit_Output
// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2> [0])>, Output>
// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2>)>, Output>
func @memref_16bit_Output(%arg4: memref<16xi16, 10>) { return }
} // end module
@ -563,13 +563,13 @@ func @memref_16bit_Uniform(%arg0: memref<?xsi16, 4>) { return }
func @memref_16bit_PushConstant(%arg0: memref<?xui16, 7>) { return }
// CHECK-LABEL: spv.func @memref_16bit_Input
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, Input>
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4>)>, Input>
// NOEMU-LABEL: func @memref_16bit_Input
// NOEMU-SAME: memref<?xf16, 9>
func @memref_16bit_Input(%arg3: memref<?xf16, 9>) { return }
// CHECK-LABEL: spv.func @memref_16bit_Output
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, Output>
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4>)>, Output>
// NOEMU-LABEL: func @memref_16bit_Output
// NOEMU-SAME: memref<?xf16, 10>
func @memref_16bit_Output(%arg4: memref<?xf16, 10>) { return }