[spirv] Fix bitwidth emulation for Workgroup storage class

If Int16 is not available, 16-bit integers inside Workgroup storage
class should be emulated via 32-bit integers. This was previously
broken because the capability querying logic was incorrectly
intercepting all storage classes where it meant to only handle
interface storage classes. Adjusted where we return to fix this.

Differential Revision: https://reviews.llvm.org/D85308
This commit is contained in:
Lei Zhang 2020-08-05 10:06:00 -04:00
parent b1dac0cfcd
commit 48378a32af
2 changed files with 35 additions and 22 deletions

View File

@ -772,8 +772,12 @@ void ScalarType::getCapabilities(
ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps)); \
capabilities.push_back(ref); \
} \
} break
/* No requirements for other bitwidths */ \
return; \
}
// This part only handles the cases where special bitwidths appearing in
// interface storage classes.
if (storage) {
switch (*storage) {
STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
@ -782,17 +786,17 @@ void ScalarType::getCapabilities(
STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
StorageUniform16);
case StorageClass::Input:
case StorageClass::Output:
case StorageClass::Output: {
if (bitwidth == 16) {
static const Capability caps[] = {Capability::StorageInputOutput16};
ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
capabilities.push_back(ref);
}
break;
return;
}
default:
break;
}
return;
}
#undef STORAGE_CASE

View File

@ -32,25 +32,34 @@ module attributes {
// -----
// TODO: Uncomment this test when the extension handling correctly
// converts an i16 type to i32 type and handles the load/stores
// correctly.
module attributes {
spv.target_env = #spv.target_env<
#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
}
{
func @alloc_dealloc_workgroup_mem(%arg0 : index, %arg1 : index) {
%0 = alloc() : memref<4x5xi16, 3>
%1 = load %0[%arg0, %arg1] : memref<4x5xi16, 3>
store %1, %0[%arg0, %arg1] : memref<4x5xi16, 3>
dealloc %0 : memref<4x5xi16, 3>
return
}
}
// CHECK: spv.globalVariable @__workgroup_mem__{{[0-9]+}}
// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<20 x i32, stride=4>>, Workgroup>
// CHECK_LABEL: spv.func @alloc_dealloc_workgroup_mem
// CHECK: %[[VAR:.+]] = spv._address_of @__workgroup_mem__0
// CHECK: %[[LOC:.+]] = spv.SDiv
// CHECK: %[[PTR:.+]] = spv.AccessChain %[[VAR]][%{{.+}}, %[[LOC]]]
// CHECK: %{{.+}} = spv.Load "Workgroup" %[[PTR]] : i32
// CHECK: %[[LOC:.+]] = spv.SDiv
// CHECK: %[[PTR:.+]] = spv.AccessChain %[[VAR]][%{{.+}}, %[[LOC]]]
// CHECK: %{{.+}} = spv.AtomicAnd "Workgroup" "AcquireRelease" %[[PTR]], %{{.+}} : !spv.ptr<i32, Workgroup>
// CHECK: %{{.+}} = spv.AtomicOr "Workgroup" "AcquireRelease" %[[PTR]], %{{.+}} : !spv.ptr<i32, Workgroup>
// module attributes {
// spv.target_env = #spv.target_env<
// #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
// {max_compute_workgroup_invocations = 128 : i32,
// max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
// }
// {
// func @alloc_dealloc_workgroup_mem(%arg0 : index, %arg1 : index) {
// %0 = alloc() : memref<4x5xi16, 3>
// %1 = load %0[%arg0, %arg1] : memref<4x5xi16, 3>
// store %1, %0[%arg0, %arg1] : memref<4x5xi16, 3>
// dealloc %0 : memref<4x5xi16, 3>
// return
// }
// }
// -----