forked from OSchip/llvm-project
[mlir][spirv] Only attach struct offset for required storage classes
Per the SPIR-V spec "2.16.2. Validation Rules for Shader Capabilities": Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and PushConstant Storage Classes must be explicitly laid out. For other cases we don't need to attach the struct offsets. Reviewed By: hanchung Differential Revision: https://reviews.llvm.org/D100386
This commit is contained in:
parent
6ddd8c28b7
commit
5b15fe9334
|
@ -84,6 +84,30 @@ static LogicalResult checkCapabilityRequirements(
|
|||
return success();
|
||||
}
|
||||
|
||||
/// Returns true if the given `storageClass` needs explicit layout when used in
|
||||
/// Shader environments.
|
||||
static bool needsExplicitLayout(spirv::StorageClass storageClass) {
|
||||
switch (storageClass) {
|
||||
case spirv::StorageClass::PhysicalStorageBuffer:
|
||||
case spirv::StorageClass::PushConstant:
|
||||
case spirv::StorageClass::StorageBuffer:
|
||||
case spirv::StorageClass::Uniform:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/// Wraps the given `elementType` in a struct and gets the pointer to the
|
||||
/// struct. This is used to satisfy Vulkan interface requirements.
|
||||
static spirv::PointerType
|
||||
wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
|
||||
auto structType = needsExplicitLayout(storageClass)
|
||||
? spirv::StructType::get(elementType, /*offsetInfo=*/0)
|
||||
: spirv::StructType::get(elementType);
|
||||
return spirv::PointerType::get(structType, storageClass);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type Conversion
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -392,12 +416,7 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
|
|||
auto arrayType =
|
||||
spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
|
||||
|
||||
// Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with
|
||||
// workgroup storage class do not need the struct to be laid out explicitly.
|
||||
auto structType = *storageClass == spirv::StorageClass::Workgroup
|
||||
? spirv::StructType::get(arrayType)
|
||||
: spirv::StructType::get(arrayType, 0);
|
||||
return spirv::PointerType::get(structType, *storageClass);
|
||||
return wrapInStructAndGetPointer(arrayType, *storageClass);
|
||||
}
|
||||
|
||||
static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
|
||||
|
@ -452,9 +471,7 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
|
|||
if (!type.hasStaticShape()) {
|
||||
auto arrayType =
|
||||
spirv::RuntimeArrayType::get(arrayElemType, *arrayElemSize);
|
||||
// Wrap in a struct to satisfy Vulkan interface requirements.
|
||||
auto structType = spirv::StructType::get(arrayType, 0);
|
||||
return spirv::PointerType::get(structType, *storageClass);
|
||||
return wrapInStructAndGetPointer(arrayType, *storageClass);
|
||||
}
|
||||
|
||||
Optional<int64_t> memrefSize = getTypeNumBytes(options, type);
|
||||
|
@ -470,12 +487,7 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
|
|||
auto arrayType =
|
||||
spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
|
||||
|
||||
// Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with
|
||||
// workgroup storage class do not need the struct to be laid out explicitly.
|
||||
auto structType = *storageClass == spirv::StorageClass::Workgroup
|
||||
? spirv::StructType::get(arrayType)
|
||||
: spirv::StructType::get(arrayType, 0);
|
||||
return spirv::PointerType::get(structType, *storageClass);
|
||||
return wrapInStructAndGetPointer(arrayType, *storageClass);
|
||||
}
|
||||
|
||||
SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
|
||||
|
|
|
@ -9,7 +9,7 @@ module attributes {
|
|||
// CHECK: spv.func
|
||||
// CHECK-SAME: {{%.*}}: f32
|
||||
// CHECK-NOT: spv.interface_var_abi
|
||||
// CHECK-SAME: {{%.*}}: !spv.ptr<!spv.struct<(!spv.array<12 x f32, stride=4> [0])>, CrossWorkgroup>
|
||||
// CHECK-SAME: {{%.*}}: !spv.ptr<!spv.struct<(!spv.array<12 x f32, stride=4>)>, CrossWorkgroup>
|
||||
// CHECK-NOT: spv.interface_var_abi
|
||||
// CHECK-SAME: spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}
|
||||
gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32, 11>) kernel
|
||||
|
|
|
@ -337,13 +337,13 @@ func @memref_16bit_Uniform(%arg0: memref<16xsi16, 4>) { return }
|
|||
func @memref_16bit_PushConstant(%arg0: memref<16xui16, 7>) { return }
|
||||
|
||||
// CHECK-LABEL: spv.func @memref_16bit_Input
|
||||
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, Input>
|
||||
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4>)>, Input>
|
||||
// NOEMU-LABEL: func @memref_16bit_Input
|
||||
// NOEMU-SAME: memref<16xf16, 9>
|
||||
func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return }
|
||||
|
||||
// CHECK-LABEL: spv.func @memref_16bit_Output
|
||||
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, Output>
|
||||
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4>)>, Output>
|
||||
// NOEMU-LABEL: func @memref_16bit_Output
|
||||
// NOEMU-SAME: memref<16xf16, 10>
|
||||
func @memref_16bit_Output(%arg4: memref<16xf16, 10>) { return }
|
||||
|
@ -451,15 +451,15 @@ module attributes {
|
|||
} {
|
||||
|
||||
// CHECK-LABEL: spv.func @memref_16bit_Input
|
||||
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2> [0])>, Input>
|
||||
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2>)>, Input>
|
||||
// NOEMU-LABEL: spv.func @memref_16bit_Input
|
||||
// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2> [0])>, Input>
|
||||
// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2>)>, Input>
|
||||
func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return }
|
||||
|
||||
// CHECK-LABEL: spv.func @memref_16bit_Output
|
||||
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2> [0])>, Output>
|
||||
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2>)>, Output>
|
||||
// NOEMU-LABEL: spv.func @memref_16bit_Output
|
||||
// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2> [0])>, Output>
|
||||
// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2>)>, Output>
|
||||
func @memref_16bit_Output(%arg4: memref<16xi16, 10>) { return }
|
||||
|
||||
} // end module
|
||||
|
@ -563,13 +563,13 @@ func @memref_16bit_Uniform(%arg0: memref<?xsi16, 4>) { return }
|
|||
func @memref_16bit_PushConstant(%arg0: memref<?xui16, 7>) { return }
|
||||
|
||||
// CHECK-LABEL: spv.func @memref_16bit_Input
|
||||
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, Input>
|
||||
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4>)>, Input>
|
||||
// NOEMU-LABEL: func @memref_16bit_Input
|
||||
// NOEMU-SAME: memref<?xf16, 9>
|
||||
func @memref_16bit_Input(%arg3: memref<?xf16, 9>) { return }
|
||||
|
||||
// CHECK-LABEL: spv.func @memref_16bit_Output
|
||||
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, Output>
|
||||
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4>)>, Output>
|
||||
// NOEMU-LABEL: func @memref_16bit_Output
|
||||
// NOEMU-SAME: memref<?xf16, 10>
|
||||
func @memref_16bit_Output(%arg4: memref<?xf16, 10>) { return }
|
||||
|
|
Loading…
Reference in New Issue