[mlir][spirv] Handle dynamic/static cases differntly for kernel capability

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D134908
This commit is contained in:
Nirvedh Meshram 2022-09-29 11:14:18 -07:00
parent 682c95672b
commit d6de6dde82
5 changed files with 71 additions and 36 deletions

View File

@ -335,7 +335,9 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
.getPointeeType();
Type dstType;
if (typeConverter.allows(spirv::Capability::Kernel)) {
// For OpenCL Kernel, pointer will be directly pointing to the element.
if (auto arrayType = pointeeType.dyn_cast<spirv::ArrayType>())
dstType = arrayType.getElementType();
else
dstType = pointeeType;
} else {
// For Vulkan we need to extract element from wrapping struct and array.
@ -464,7 +466,9 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
.getPointeeType();
Type dstType;
if (typeConverter.allows(spirv::Capability::Kernel)) {
// For OpenCL Kernel, pointer will be directly pointing to the element.
if (auto arrayType = pointeeType.dyn_cast<spirv::ArrayType>())
dstType = arrayType.getElementType();
else
dstType = pointeeType;
} else {
// For Vulkan we need to extract element from wrapping struct and array.

View File

@ -338,15 +338,16 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
return nullptr;
}
// For OpenCL Kernel we can just emit a pointer pointing to the element.
if (!type.hasStaticShape()) {
// For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
// to the element.
if (targetEnv.allows(spirv::Capability::Kernel))
return spirv::PointerType::get(arrayElemType, storageClass);
// For Vulkan we need extra wrapping struct and array to satisfy interface
// needs.
if (!type.hasStaticShape()) {
int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
// For Vulkan we need extra wrapping struct and array to satisfy interface
// needs.
return wrapInStructAndGetPointer(arrayType, storageClass);
}
@ -354,7 +355,8 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
auto arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
if (targetEnv.allows(spirv::Capability::Kernel))
return spirv::PointerType::get(arrayType, storageClass);
return wrapInStructAndGetPointer(arrayType, storageClass);
}
@ -403,15 +405,16 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
return nullptr;
}
// For OpenCL Kernel we can just emit a pointer pointing to the element.
if (!type.hasStaticShape()) {
// For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
// to the element.
if (targetEnv.allows(spirv::Capability::Kernel))
return spirv::PointerType::get(arrayElemType, storageClass);
// For Vulkan we need extra wrapping struct and array to satisfy interface
// needs.
if (!type.hasStaticShape()) {
int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
// For Vulkan we need extra wrapping struct and array to satisfy interface
// needs.
return wrapInStructAndGetPointer(arrayType, storageClass);
}
@ -425,7 +428,8 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
auto arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
if (targetEnv.allows(spirv::Capability::Kernel))
return spirv::PointerType::get(arrayType, storageClass);
return wrapInStructAndGetPointer(arrayType, storageClass);
}
@ -776,15 +780,20 @@ Value mlir::spirv::getOpenCLElementPtr(SPIRVTypeConverter &typeConverter,
auto indexType = typeConverter.getIndexType();
SmallVector<Value, 2> linearizedIndices;
auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
Value linearIndex;
if (baseType.getRank() == 0) {
linearIndex = zero;
linearIndex = spirv::ConstantOp::getZero(indexType, loc, builder);
} else {
linearIndex =
linearizeIndex(indices, strides, offset, indexType, loc, builder);
}
Type pointeeType =
basePtr.getType().cast<spirv::PointerType>().getPointeeType();
if (pointeeType.isa<spirv::ArrayType>()) {
linearizedIndices.push_back(linearIndex);
return builder.create<spirv::AccessChainOp>(loc, basePtr,
linearizedIndices);
}
return builder.create<spirv::PtrAccessChainOp>(loc, basePtr, linearIndex,
linearizedIndices);
}

View File

@ -9,7 +9,7 @@ module attributes {
// CHECK: spirv.func
// CHECK-SAME: {{%.*}}: f32
// CHECK-NOT: spirv.interface_var_abi
// CHECK-SAME: {{%.*}}: !spirv.ptr<f32, CrossWorkgroup>
// CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.array<12 x f32>, CrossWorkgroup>
// CHECK-NOT: spirv.interface_var_abi
// CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]> : vector<3xi32>>
gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32, #spirv.storage_class<CrossWorkgroup>>) kernel

View File

@ -155,3 +155,27 @@ module attributes {
return
}
}
// -----
module attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [Kernel], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
}
{
func.func @alloc_dealloc_workgroup_mem(%arg0 : index, %arg1 : index) {
%0 = memref.alloc() : memref<4x5xf32, #spirv.storage_class<Workgroup>>
%1 = memref.load %0[%arg0, %arg1] : memref<4x5xf32, #spirv.storage_class<Workgroup>>
memref.store %1, %0[%arg0, %arg1] : memref<4x5xf32, #spirv.storage_class<Workgroup>>
memref.dealloc %0 : memref<4x5xf32, #spirv.storage_class<Workgroup>>
return
}
}
// CHECK: spirv.GlobalVariable @[[VAR:.+]] : !spirv.ptr<!spirv.array<20 x f32>, Workgroup>
// CHECK: func @alloc_dealloc_workgroup_mem
// CHECK-NOT: memref.alloc
// CHECK: %[[PTR:.+]] = spirv.mlir.addressof @[[VAR]]
// CHECK: %[[LOADPTR:.+]] = spirv.AccessChain %[[PTR]]
// CHECK: %[[VAL:.+]] = spirv.Load "Workgroup" %[[LOADPTR]] : f32
// CHECK: %[[STOREPTR:.+]] = spirv.AccessChain %[[PTR]]
// CHECK: spirv.Store "Workgroup" %[[STOREPTR]], %[[VAL]] : f32
// CHECK-NOT: memref.dealloc

View File

@ -121,16 +121,16 @@ module attributes {
// CHECK-LABEL: @load_store_zero_rank_float
func.func @load_store_zero_rank_float(%arg0: memref<f32, #spirv.storage_class<CrossWorkgroup>>, %arg1: memref<f32, #spirv.storage_class<CrossWorkgroup>>) {
// CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<f32, CrossWorkgroup>
// CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<f32, CrossWorkgroup>
// CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<1 x f32>, CrossWorkgroup>
// CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<1 x f32>, CrossWorkgroup>
// CHECK: [[ZERO1:%.*]] = spirv.Constant 0 : i32
// CHECK: spirv.PtrAccessChain [[ARG0]][
// CHECK: spirv.AccessChain [[ARG0]][
// CHECK-SAME: [[ZERO1]]
// CHECK-SAME: ] :
// CHECK: spirv.Load "CrossWorkgroup" %{{.*}} : f32
%0 = memref.load %arg0[] : memref<f32, #spirv.storage_class<CrossWorkgroup>>
// CHECK: [[ZERO2:%.*]] = spirv.Constant 0 : i32
// CHECK: spirv.PtrAccessChain [[ARG1]][
// CHECK: spirv.AccessChain [[ARG1]][
// CHECK-SAME: [[ZERO2]]
// CHECK-SAME: ] :
// CHECK: spirv.Store "CrossWorkgroup" %{{.*}} : f32
@ -140,16 +140,16 @@ func.func @load_store_zero_rank_float(%arg0: memref<f32, #spirv.storage_class<Cr
// CHECK-LABEL: @load_store_zero_rank_int
func.func @load_store_zero_rank_int(%arg0: memref<i32, #spirv.storage_class<CrossWorkgroup>>, %arg1: memref<i32, #spirv.storage_class<CrossWorkgroup>>) {
// CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<i32, CrossWorkgroup>
// CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<i32, CrossWorkgroup>
// CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<1 x i32>, CrossWorkgroup>
// CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<1 x i32>, CrossWorkgroup>
// CHECK: [[ZERO1:%.*]] = spirv.Constant 0 : i32
// CHECK: spirv.PtrAccessChain [[ARG0]][
// CHECK: spirv.AccessChain [[ARG0]][
// CHECK-SAME: [[ZERO1]]
// CHECK-SAME: ] :
// CHECK: spirv.Load "CrossWorkgroup" %{{.*}} : i32
%0 = memref.load %arg0[] : memref<i32, #spirv.storage_class<CrossWorkgroup>>
// CHECK: [[ZERO2:%.*]] = spirv.Constant 0 : i32
// CHECK: spirv.PtrAccessChain [[ARG1]][
// CHECK: spirv.AccessChain [[ARG1]][
// CHECK-SAME: [[ZERO2]]
// CHECK-SAME: ] :
// CHECK: spirv.Store "CrossWorkgroup" %{{.*}} : i32
@ -173,14 +173,13 @@ func.func @load_store_unknown_dim(%i: index, %source: memref<?xi32, #spirv.stora
// CHECK-LABEL: func @load_i1
// CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %[[IDX:.+]]: index)
func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %i : index) -> i1 {
// CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<i8, CrossWorkgroup>
// CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<4 x i8>, CrossWorkgroup>
// CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
// CHECK: %[[ZERO_0:.+]] = spirv.Constant 0 : i32
// CHECK: %[[ZERO_1:.+]] = spirv.Constant 0 : i32
// CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
// CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[IDX_CAST]] : i32
// CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO_1]], %[[MUL]] : i32
// CHECK: %[[ADDR:.+]] = spirv.PtrAccessChain %[[SRC_CAST]][%[[ADD]]]
// CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO_0]], %[[MUL]] : i32
// CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ADD]]]
// CHECK: %[[VAL:.+]] = spirv.Load "CrossWorkgroup" %[[ADDR]] : i8
// CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8
// CHECK: %[[BOOL:.+]] = spirv.IEqual %[[VAL]], %[[ONE_I8]] : i8
@ -194,14 +193,13 @@ func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %i
// CHECK-SAME: %[[IDX:.+]]: index
func.func @store_i1(%dst: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %i: index) {
%true = arith.constant true
// CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<i8, CrossWorkgroup>
// CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<4 x i8>, CrossWorkgroup>
// CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
// CHECK: %[[ZERO_0:.+]] = spirv.Constant 0 : i32
// CHECK: %[[ZERO_1:.+]] = spirv.Constant 0 : i32
// CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
// CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[IDX_CAST]] : i32
// CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO_1]], %[[MUL]] : i32
// CHECK: %[[ADDR:.+]] = spirv.PtrAccessChain %[[DST_CAST]][%[[ADD]]]
// CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO_0]], %[[MUL]] : i32
// CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[DST_CAST]][%[[ADD]]]
// CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8
// CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8
// CHECK: %[[RES:.+]] = spirv.Select %{{.+}}, %[[ONE_I8]], %[[ZERO_I8]] : i1, i8