forked from OSchip/llvm-project
[mlir][spirv] Handle dynamic/static cases differently for kernel capability
Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D134908
This commit is contained in:
parent
682c95672b
commit
d6de6dde82
|
@ -335,8 +335,10 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
|
|||
.getPointeeType();
|
||||
Type dstType;
|
||||
if (typeConverter.allows(spirv::Capability::Kernel)) {
|
||||
// For OpenCL Kernel, pointer will be directly pointing to the element.
|
||||
dstType = pointeeType;
|
||||
if (auto arrayType = pointeeType.dyn_cast<spirv::ArrayType>())
|
||||
dstType = arrayType.getElementType();
|
||||
else
|
||||
dstType = pointeeType;
|
||||
} else {
|
||||
// For Vulkan we need to extract element from wrapping struct and array.
|
||||
Type structElemType =
|
||||
|
@ -464,8 +466,10 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
|
|||
.getPointeeType();
|
||||
Type dstType;
|
||||
if (typeConverter.allows(spirv::Capability::Kernel)) {
|
||||
// For OpenCL Kernel, pointer will be directly pointing to the element.
|
||||
dstType = pointeeType;
|
||||
if (auto arrayType = pointeeType.dyn_cast<spirv::ArrayType>())
|
||||
dstType = arrayType.getElementType();
|
||||
else
|
||||
dstType = pointeeType;
|
||||
} else {
|
||||
// For Vulkan we need to extract element from wrapping struct and array.
|
||||
Type structElemType =
|
||||
|
|
|
@ -338,15 +338,16 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
// For OpenCL Kernel we can just emit a pointer pointing to the element.
|
||||
if (targetEnv.allows(spirv::Capability::Kernel))
|
||||
return spirv::PointerType::get(arrayElemType, storageClass);
|
||||
|
||||
// For Vulkan we need extra wrapping struct and array to satisfy interface
|
||||
// needs.
|
||||
if (!type.hasStaticShape()) {
|
||||
// For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
|
||||
// to the element.
|
||||
if (targetEnv.allows(spirv::Capability::Kernel))
|
||||
return spirv::PointerType::get(arrayElemType, storageClass);
|
||||
int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
|
||||
auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
|
||||
// For Vulkan we need extra wrapping struct and array to satisfy interface
|
||||
// needs.
|
||||
return wrapInStructAndGetPointer(arrayType, storageClass);
|
||||
}
|
||||
|
||||
|
@ -354,7 +355,8 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
|
|||
auto arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
|
||||
int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
|
||||
auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
|
||||
|
||||
if (targetEnv.allows(spirv::Capability::Kernel))
|
||||
return spirv::PointerType::get(arrayType, storageClass);
|
||||
return wrapInStructAndGetPointer(arrayType, storageClass);
|
||||
}
|
||||
|
||||
|
@ -403,15 +405,16 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
// For OpenCL Kernel we can just emit a pointer pointing to the element.
|
||||
if (targetEnv.allows(spirv::Capability::Kernel))
|
||||
return spirv::PointerType::get(arrayElemType, storageClass);
|
||||
|
||||
// For Vulkan we need extra wrapping struct and array to satisfy interface
|
||||
// needs.
|
||||
if (!type.hasStaticShape()) {
|
||||
// For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
|
||||
// to the element.
|
||||
if (targetEnv.allows(spirv::Capability::Kernel))
|
||||
return spirv::PointerType::get(arrayElemType, storageClass);
|
||||
int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
|
||||
auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
|
||||
// For Vulkan we need extra wrapping struct and array to satisfy interface
|
||||
// needs.
|
||||
return wrapInStructAndGetPointer(arrayType, storageClass);
|
||||
}
|
||||
|
||||
|
@ -425,7 +428,8 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
|
|||
auto arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
|
||||
int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
|
||||
auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
|
||||
|
||||
if (targetEnv.allows(spirv::Capability::Kernel))
|
||||
return spirv::PointerType::get(arrayType, storageClass);
|
||||
return wrapInStructAndGetPointer(arrayType, storageClass);
|
||||
}
|
||||
|
||||
|
@ -776,15 +780,20 @@ Value mlir::spirv::getOpenCLElementPtr(SPIRVTypeConverter &typeConverter,
|
|||
auto indexType = typeConverter.getIndexType();
|
||||
|
||||
SmallVector<Value, 2> linearizedIndices;
|
||||
auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
|
||||
|
||||
Value linearIndex;
|
||||
if (baseType.getRank() == 0) {
|
||||
linearIndex = zero;
|
||||
linearIndex = spirv::ConstantOp::getZero(indexType, loc, builder);
|
||||
} else {
|
||||
linearIndex =
|
||||
linearizeIndex(indices, strides, offset, indexType, loc, builder);
|
||||
}
|
||||
Type pointeeType =
|
||||
basePtr.getType().cast<spirv::PointerType>().getPointeeType();
|
||||
if (pointeeType.isa<spirv::ArrayType>()) {
|
||||
linearizedIndices.push_back(linearIndex);
|
||||
return builder.create<spirv::AccessChainOp>(loc, basePtr,
|
||||
linearizedIndices);
|
||||
}
|
||||
return builder.create<spirv::PtrAccessChainOp>(loc, basePtr, linearIndex,
|
||||
linearizedIndices);
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@ module attributes {
|
|||
// CHECK: spirv.func
|
||||
// CHECK-SAME: {{%.*}}: f32
|
||||
// CHECK-NOT: spirv.interface_var_abi
|
||||
// CHECK-SAME: {{%.*}}: !spirv.ptr<f32, CrossWorkgroup>
|
||||
// CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.array<12 x f32>, CrossWorkgroup>
|
||||
// CHECK-NOT: spirv.interface_var_abi
|
||||
// CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]> : vector<3xi32>>
|
||||
gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32, #spirv.storage_class<CrossWorkgroup>>) kernel
|
||||
|
|
|
@ -155,3 +155,27 @@ module attributes {
|
|||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
module attributes {
|
||||
spirv.target_env = #spirv.target_env<
|
||||
#spirv.vce<v1.0, [Kernel], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
|
||||
}
|
||||
{
|
||||
func.func @alloc_dealloc_workgroup_mem(%arg0 : index, %arg1 : index) {
|
||||
%0 = memref.alloc() : memref<4x5xf32, #spirv.storage_class<Workgroup>>
|
||||
%1 = memref.load %0[%arg0, %arg1] : memref<4x5xf32, #spirv.storage_class<Workgroup>>
|
||||
memref.store %1, %0[%arg0, %arg1] : memref<4x5xf32, #spirv.storage_class<Workgroup>>
|
||||
memref.dealloc %0 : memref<4x5xf32, #spirv.storage_class<Workgroup>>
|
||||
return
|
||||
}
|
||||
}
|
||||
// CHECK: spirv.GlobalVariable @[[VAR:.+]] : !spirv.ptr<!spirv.array<20 x f32>, Workgroup>
|
||||
// CHECK: func @alloc_dealloc_workgroup_mem
|
||||
// CHECK-NOT: memref.alloc
|
||||
// CHECK: %[[PTR:.+]] = spirv.mlir.addressof @[[VAR]]
|
||||
// CHECK: %[[LOADPTR:.+]] = spirv.AccessChain %[[PTR]]
|
||||
// CHECK: %[[VAL:.+]] = spirv.Load "Workgroup" %[[LOADPTR]] : f32
|
||||
// CHECK: %[[STOREPTR:.+]] = spirv.AccessChain %[[PTR]]
|
||||
// CHECK: spirv.Store "Workgroup" %[[STOREPTR]], %[[VAL]] : f32
|
||||
// CHECK-NOT: memref.dealloc
|
||||
|
|
|
@ -121,16 +121,16 @@ module attributes {
|
|||
|
||||
// CHECK-LABEL: @load_store_zero_rank_float
|
||||
func.func @load_store_zero_rank_float(%arg0: memref<f32, #spirv.storage_class<CrossWorkgroup>>, %arg1: memref<f32, #spirv.storage_class<CrossWorkgroup>>) {
|
||||
// CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<f32, CrossWorkgroup>
|
||||
// CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<f32, CrossWorkgroup>
|
||||
// CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<1 x f32>, CrossWorkgroup>
|
||||
// CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<1 x f32>, CrossWorkgroup>
|
||||
// CHECK: [[ZERO1:%.*]] = spirv.Constant 0 : i32
|
||||
// CHECK: spirv.PtrAccessChain [[ARG0]][
|
||||
// CHECK: spirv.AccessChain [[ARG0]][
|
||||
// CHECK-SAME: [[ZERO1]]
|
||||
// CHECK-SAME: ] :
|
||||
// CHECK: spirv.Load "CrossWorkgroup" %{{.*}} : f32
|
||||
%0 = memref.load %arg0[] : memref<f32, #spirv.storage_class<CrossWorkgroup>>
|
||||
// CHECK: [[ZERO2:%.*]] = spirv.Constant 0 : i32
|
||||
// CHECK: spirv.PtrAccessChain [[ARG1]][
|
||||
// CHECK: spirv.AccessChain [[ARG1]][
|
||||
// CHECK-SAME: [[ZERO2]]
|
||||
// CHECK-SAME: ] :
|
||||
// CHECK: spirv.Store "CrossWorkgroup" %{{.*}} : f32
|
||||
|
@ -140,16 +140,16 @@ func.func @load_store_zero_rank_float(%arg0: memref<f32, #spirv.storage_class<Cr
|
|||
|
||||
// CHECK-LABEL: @load_store_zero_rank_int
|
||||
func.func @load_store_zero_rank_int(%arg0: memref<i32, #spirv.storage_class<CrossWorkgroup>>, %arg1: memref<i32, #spirv.storage_class<CrossWorkgroup>>) {
|
||||
// CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<i32, CrossWorkgroup>
|
||||
// CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<i32, CrossWorkgroup>
|
||||
// CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<1 x i32>, CrossWorkgroup>
|
||||
// CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<1 x i32>, CrossWorkgroup>
|
||||
// CHECK: [[ZERO1:%.*]] = spirv.Constant 0 : i32
|
||||
// CHECK: spirv.PtrAccessChain [[ARG0]][
|
||||
// CHECK: spirv.AccessChain [[ARG0]][
|
||||
// CHECK-SAME: [[ZERO1]]
|
||||
// CHECK-SAME: ] :
|
||||
// CHECK: spirv.Load "CrossWorkgroup" %{{.*}} : i32
|
||||
%0 = memref.load %arg0[] : memref<i32, #spirv.storage_class<CrossWorkgroup>>
|
||||
// CHECK: [[ZERO2:%.*]] = spirv.Constant 0 : i32
|
||||
// CHECK: spirv.PtrAccessChain [[ARG1]][
|
||||
// CHECK: spirv.AccessChain [[ARG1]][
|
||||
// CHECK-SAME: [[ZERO2]]
|
||||
// CHECK-SAME: ] :
|
||||
// CHECK: spirv.Store "CrossWorkgroup" %{{.*}} : i32
|
||||
|
@ -173,14 +173,13 @@ func.func @load_store_unknown_dim(%i: index, %source: memref<?xi32, #spirv.stora
|
|||
// CHECK-LABEL: func @load_i1
|
||||
// CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %[[IDX:.+]]: index)
|
||||
func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %i : index) -> i1 {
|
||||
// CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<i8, CrossWorkgroup>
|
||||
// CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<4 x i8>, CrossWorkgroup>
|
||||
// CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
|
||||
// CHECK: %[[ZERO_0:.+]] = spirv.Constant 0 : i32
|
||||
// CHECK: %[[ZERO_1:.+]] = spirv.Constant 0 : i32
|
||||
// CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
|
||||
// CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[IDX_CAST]] : i32
|
||||
// CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO_1]], %[[MUL]] : i32
|
||||
// CHECK: %[[ADDR:.+]] = spirv.PtrAccessChain %[[SRC_CAST]][%[[ADD]]]
|
||||
// CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO_0]], %[[MUL]] : i32
|
||||
// CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ADD]]]
|
||||
// CHECK: %[[VAL:.+]] = spirv.Load "CrossWorkgroup" %[[ADDR]] : i8
|
||||
// CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8
|
||||
// CHECK: %[[BOOL:.+]] = spirv.IEqual %[[VAL]], %[[ONE_I8]] : i8
|
||||
|
@ -194,14 +193,13 @@ func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %i
|
|||
// CHECK-SAME: %[[IDX:.+]]: index
|
||||
func.func @store_i1(%dst: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %i: index) {
|
||||
%true = arith.constant true
|
||||
// CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<i8, CrossWorkgroup>
|
||||
// CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<4 x i8>, CrossWorkgroup>
|
||||
// CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
|
||||
// CHECK: %[[ZERO_0:.+]] = spirv.Constant 0 : i32
|
||||
// CHECK: %[[ZERO_1:.+]] = spirv.Constant 0 : i32
|
||||
// CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
|
||||
// CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[IDX_CAST]] : i32
|
||||
// CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO_1]], %[[MUL]] : i32
|
||||
// CHECK: %[[ADDR:.+]] = spirv.PtrAccessChain %[[DST_CAST]][%[[ADD]]]
|
||||
// CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO_0]], %[[MUL]] : i32
|
||||
// CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[DST_CAST]][%[[ADD]]]
|
||||
// CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8
|
||||
// CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8
|
||||
// CHECK: %[[RES:.+]] = spirv.Select %{{.+}}, %[[ONE_I8]], %[[ZERO_I8]] : i1, i8
|
||||
|
|
Loading…
Reference in New Issue