forked from OSchip/llvm-project
[mlir][vulkan-runner] Add support for 3D memrefs.
Summary: Add support for 3D memrefs in mlir-vulkan-runner and a simple test. Differential Revision: https://reviews.llvm.org/D77157
This commit is contained in:
parent
6aecf0cfef
commit
0718e3ae31
|
@ -60,7 +60,7 @@ private:
|
||||||
// TODO(denis0x0D): Handle other types.
|
// TODO(denis0x0D): Handle other types.
|
||||||
if (auto memRefType = type.dyn_cast_or_null<MemRefType>())
|
if (auto memRefType = type.dyn_cast_or_null<MemRefType>())
|
||||||
return memRefType.hasRank() &&
|
return memRefType.hasRank() &&
|
||||||
(memRefType.getRank() == 1 || memRefType.getRank() == 2);
|
(memRefType.getRank() >= 1 && memRefType.getRank() <= 3);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -30,6 +30,7 @@ using namespace mlir;
|
||||||
|
|
||||||
static constexpr const char *kBindMemRef1DFloat = "bindMemRef1DFloat";
|
static constexpr const char *kBindMemRef1DFloat = "bindMemRef1DFloat";
|
||||||
static constexpr const char *kBindMemRef2DFloat = "bindMemRef2DFloat";
|
static constexpr const char *kBindMemRef2DFloat = "bindMemRef2DFloat";
|
||||||
|
static constexpr const char *kBindMemRef3DFloat = "bindMemRef3DFloat";
|
||||||
static constexpr const char *kCInterfaceVulkanLaunch =
|
static constexpr const char *kCInterfaceVulkanLaunch =
|
||||||
"_mlir_ciface_vulkanLaunch";
|
"_mlir_ciface_vulkanLaunch";
|
||||||
static constexpr const char *kDeinitVulkan = "deinitVulkan";
|
static constexpr const char *kDeinitVulkan = "deinitVulkan";
|
||||||
|
@ -76,10 +77,12 @@ private:
|
||||||
llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
|
llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
|
||||||
llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
|
llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
|
||||||
llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||||
initializeMemRefTypes();
|
llvmMemRef1DFloat = getMemRefType(1);
|
||||||
|
llvmMemRef2DFloat = getMemRefType(2);
|
||||||
|
llvmMemRef3DFloat = getMemRefType(3);
|
||||||
}
|
}
|
||||||
|
|
||||||
void initializeMemRefTypes() {
|
LLVM::LLVMType getMemRefType(uint32_t rank) {
|
||||||
// According to the MLIR doc memref argument is converted into a
|
// According to the MLIR doc memref argument is converted into a
|
||||||
// pointer-to-struct argument of type:
|
// pointer-to-struct argument of type:
|
||||||
// template <typename Elem, size_t Rank>
|
// template <typename Elem, size_t Rank>
|
||||||
|
@ -91,22 +94,15 @@ private:
|
||||||
// int64_t strides[Rank]; // omitted when rank == 0
|
// int64_t strides[Rank]; // omitted when rank == 0
|
||||||
// };
|
// };
|
||||||
auto llvmPtrToFloatType = getFloatType().getPointerTo();
|
auto llvmPtrToFloatType = getFloatType().getPointerTo();
|
||||||
auto llvmArrayOneElementSizeType =
|
auto llvmArrayRankElementSizeType =
|
||||||
LLVM::LLVMType::getArrayTy(getInt64Type(), 1);
|
LLVM::LLVMType::getArrayTy(getInt64Type(), rank);
|
||||||
auto llvmArrayTwoElementSizeType =
|
|
||||||
LLVM::LLVMType::getArrayTy(getInt64Type(), 2);
|
|
||||||
|
|
||||||
// Create a type `!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64]}">`.
|
// Create a type
|
||||||
llvmMemRef1DFloat = LLVM::LLVMType::getStructTy(
|
// `!llvm<"{ float*, float*, i64, [`rank` x i64], [`rank` x i64]}">`.
|
||||||
|
return LLVM::LLVMType::getStructTy(
|
||||||
llvmDialect,
|
llvmDialect,
|
||||||
{llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(),
|
{llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(),
|
||||||
llvmArrayOneElementSizeType, llvmArrayOneElementSizeType});
|
llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
|
||||||
|
|
||||||
// Create a type `!llvm<"{ float*, float*, i64, [2 x i64], [2 x i64]}">`.
|
|
||||||
llvmMemRef2DFloat = LLVM::LLVMType::getStructTy(
|
|
||||||
llvmDialect,
|
|
||||||
{llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(),
|
|
||||||
llvmArrayTwoElementSizeType, llvmArrayTwoElementSizeType});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
LLVM::LLVMType getFloatType() { return llvmFloatType; }
|
LLVM::LLVMType getFloatType() { return llvmFloatType; }
|
||||||
|
@ -116,6 +112,7 @@ private:
|
||||||
LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
|
LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
|
||||||
LLVM::LLVMType getMemRef1DFloat() { return llvmMemRef1DFloat; }
|
LLVM::LLVMType getMemRef1DFloat() { return llvmMemRef1DFloat; }
|
||||||
LLVM::LLVMType getMemRef2DFloat() { return llvmMemRef2DFloat; }
|
LLVM::LLVMType getMemRef2DFloat() { return llvmMemRef2DFloat; }
|
||||||
|
LLVM::LLVMType getMemRef3DFloat() { return llvmMemRef3DFloat; }
|
||||||
|
|
||||||
/// Creates a LLVM global for the given `name`.
|
/// Creates a LLVM global for the given `name`.
|
||||||
Value createEntryPointNameConstant(StringRef name, Location loc,
|
Value createEntryPointNameConstant(StringRef name, Location loc,
|
||||||
|
@ -164,6 +161,7 @@ private:
|
||||||
LLVM::LLVMType llvmInt64Type;
|
LLVM::LLVMType llvmInt64Type;
|
||||||
LLVM::LLVMType llvmMemRef1DFloat;
|
LLVM::LLVMType llvmMemRef1DFloat;
|
||||||
LLVM::LLVMType llvmMemRef2DFloat;
|
LLVM::LLVMType llvmMemRef2DFloat;
|
||||||
|
LLVM::LLVMType llvmMemRef3DFloat;
|
||||||
|
|
||||||
// TODO: Use an associative array to support multiple vulkan launch calls.
|
// TODO: Use an associative array to support multiple vulkan launch calls.
|
||||||
std::pair<StringAttr, StringAttr> spirvAttributes;
|
std::pair<StringAttr, StringAttr> spirvAttributes;
|
||||||
|
@ -335,6 +333,16 @@ void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
|
||||||
/*isVarArg=*/false));
|
/*isVarArg=*/false));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!module.lookupSymbol(kBindMemRef3DFloat)) {
|
||||||
|
builder.create<LLVM::LLVMFuncOp>(
|
||||||
|
loc, kBindMemRef3DFloat,
|
||||||
|
LLVM::LLVMType::getFunctionTy(getVoidType(),
|
||||||
|
{getPointerType(), getInt32Type(),
|
||||||
|
getInt32Type(),
|
||||||
|
getMemRef3DFloat().getPointerTo()},
|
||||||
|
/*isVarArg=*/false));
|
||||||
|
}
|
||||||
|
|
||||||
if (!module.lookupSymbol(kInitVulkan)) {
|
if (!module.lookupSymbol(kInitVulkan)) {
|
||||||
builder.create<LLVM::LLVMFuncOp>(
|
builder.create<LLVM::LLVMFuncOp>(
|
||||||
loc, kInitVulkan,
|
loc, kInitVulkan,
|
||||||
|
|
|
@ -0,0 +1,54 @@
|
||||||
|
// RUN: mlir-vulkan-runner %s --shared-libs=%vulkan_wrapper_library_dir/libvulkan-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-COUNT-32: [2.2, 2.2, 2.2, 2.2]
|
||||||
|
module attributes {
|
||||||
|
gpu.container_module,
|
||||||
|
spv.target_env = #spv.target_env<
|
||||||
|
#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
|
||||||
|
{max_compute_workgroup_invocations = 128 : i32,
|
||||||
|
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
|
||||||
|
} {
|
||||||
|
gpu.module @kernels {
|
||||||
|
gpu.func @kernel_sub(%arg0 : memref<8x4x4xf32>, %arg1 : memref<4x4xf32>, %arg2 : memref<8x4x4xf32>)
|
||||||
|
attributes {gpu.kernel, spv.entry_point_abi = {local_size = dense<[1, 1, 1]>: vector<3xi32>}} {
|
||||||
|
%x = "gpu.block_id"() {dimension = "x"} : () -> index
|
||||||
|
%y = "gpu.block_id"() {dimension = "y"} : () -> index
|
||||||
|
%z = "gpu.block_id"() {dimension = "z"} : () -> index
|
||||||
|
%1 = load %arg0[%x, %y, %z] : memref<8x4x4xf32>
|
||||||
|
%2 = load %arg1[%y, %z] : memref<4x4xf32>
|
||||||
|
%3 = subf %1, %2 : f32
|
||||||
|
store %3, %arg2[%x, %y, %z] : memref<8x4x4xf32>
|
||||||
|
gpu.return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func @main() {
|
||||||
|
%arg0 = alloc() : memref<8x4x4xf32>
|
||||||
|
%arg1 = alloc() : memref<4x4xf32>
|
||||||
|
%arg2 = alloc() : memref<8x4x4xf32>
|
||||||
|
%0 = constant 0 : i32
|
||||||
|
%1 = constant 1 : i32
|
||||||
|
%2 = constant 2 : i32
|
||||||
|
%value0 = constant 0.0 : f32
|
||||||
|
%value1 = constant 3.3 : f32
|
||||||
|
%value2 = constant 1.1 : f32
|
||||||
|
%arg3 = memref_cast %arg0 : memref<8x4x4xf32> to memref<?x?x?xf32>
|
||||||
|
%arg4 = memref_cast %arg1 : memref<4x4xf32> to memref<?x?xf32>
|
||||||
|
%arg5 = memref_cast %arg2 : memref<8x4x4xf32> to memref<?x?x?xf32>
|
||||||
|
call @fillResource3DFloat(%arg3, %value1) : (memref<?x?x?xf32>, f32) -> ()
|
||||||
|
call @fillResource2DFloat(%arg4, %value2) : (memref<?x?xf32>, f32) -> ()
|
||||||
|
call @fillResource3DFloat(%arg5, %value0) : (memref<?x?x?xf32>, f32) -> ()
|
||||||
|
|
||||||
|
%cst1 = constant 1 : index
|
||||||
|
%cst4 = constant 4 : index
|
||||||
|
%cst8 = constant 8 : index
|
||||||
|
"gpu.launch_func"(%cst8, %cst4, %cst4, %cst1, %cst1, %cst1, %arg0, %arg1, %arg2) { kernel = "kernel_sub", kernel_module = @kernels }
|
||||||
|
: (index, index, index, index, index, index, memref<8x4x4xf32>, memref<4x4xf32>, memref<8x4x4xf32>) -> ()
|
||||||
|
%arg6 = memref_cast %arg5 : memref<?x?x?xf32> to memref<*xf32>
|
||||||
|
call @print_memref_f32(%arg6) : (memref<*xf32>) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
func @fillResource2DFloat(%0 : memref<?x?xf32>, %1 : f32)
|
||||||
|
func @fillResource3DFloat(%0 : memref<?x?x?xf32>, %1 : f32)
|
||||||
|
func @print_memref_f32(%ptr : memref<*xf32>)
|
||||||
|
}
|
|
@ -123,6 +123,18 @@ void bindMemRef2DFloat(void *vkRuntimeManager, DescriptorSetIndex setIndex,
|
||||||
->setResourceData(setIndex, bindIndex, memBuffer);
|
->setResourceData(setIndex, bindIndex, memBuffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Binds the given 3D float memref to the given descriptor set and descriptor
|
||||||
|
/// index.
|
||||||
|
void bindMemRef3DFloat(void *vkRuntimeManager, DescriptorSetIndex setIndex,
|
||||||
|
BindingIndex bindIndex,
|
||||||
|
MemRefDescriptor<float, 3> *ptr) {
|
||||||
|
VulkanHostMemoryBuffer memBuffer{
|
||||||
|
ptr->allocated, static_cast<uint32_t>(ptr->sizes[0] * ptr->sizes[1] *
|
||||||
|
ptr->sizes[2] * sizeof(float))};
|
||||||
|
reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
|
||||||
|
->setResourceData(setIndex, bindIndex, memBuffer);
|
||||||
|
}
|
||||||
|
|
||||||
/// Fills the given 1D float memref with the given float value.
|
/// Fills the given 1D float memref with the given float value.
|
||||||
void _mlir_ciface_fillResource1DFloat(MemRefDescriptor<float, 1> *ptr, // NOLINT
|
void _mlir_ciface_fillResource1DFloat(MemRefDescriptor<float, 1> *ptr, // NOLINT
|
||||||
float value) {
|
float value) {
|
||||||
|
@ -134,4 +146,11 @@ void _mlir_ciface_fillResource2DFloat(MemRefDescriptor<float, 2> *ptr, // NOLINT
|
||||||
float value) {
|
float value) {
|
||||||
std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value);
|
std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Fills the given 3D float memref with the given float value.
|
||||||
|
void _mlir_ciface_fillResource3DFloat(MemRefDescriptor<float, 3> *ptr, // NOLINT
|
||||||
|
float value) {
|
||||||
|
std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
|
||||||
|
value);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue