[mlir][GPUToVulkan] Fix signature of bindMemRef function for f16

Binding MemRefs of f16 needs special handling as the type is not supported on
CPU. There was a bug in the type used.

Differential Revision: https://reviews.llvm.org/D86328
This commit is contained in:
Thomas Raoux 2020-08-21 10:34:12 -07:00
parent 08249d7f72
commit 36ee9a322a
2 changed files with 3 additions and 1 deletions

View File

@ -328,7 +328,7 @@ void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
std::string fnName = "bindMemRef" + std::to_string(i) + "D" + std::string fnName = "bindMemRef" + std::to_string(i) + "D" +
std::string(stringifyType(type)); std::string(stringifyType(type));
if (type.isHalfTy()) if (type.isHalfTy())
type = getMemRefType(i, LLVM::LLVMType::getInt16Ty(&getContext())); type = LLVM::LLVMType::getInt16Ty(&getContext());
if (!module.lookupSymbol(fnName)) { if (!module.lookupSymbol(fnName)) {
auto fnType = LLVM::LLVMType::getFunctionTy( auto fnType = LLVM::LLVMType::getFunctionTy(
getVoidType(), getVoidType(),

View File

@ -15,6 +15,8 @@
// CHECK: llvm.call @runOnVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> !llvm.void // CHECK: llvm.call @runOnVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> !llvm.void
// CHECK: llvm.call @deinitVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> !llvm.void // CHECK: llvm.call @deinitVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> !llvm.void
// CHECK: llvm.func @bindMemRef1DHalf(!llvm.ptr<i8>, !llvm.i32, !llvm.i32, !llvm.ptr<struct<(ptr<i16>, ptr<i16>, i64, array<1 x i64>, array<1 x i64>)>>)
module attributes {gpu.container_module} { module attributes {gpu.container_module} {
llvm.func @malloc(!llvm.i64) -> !llvm.ptr<i8> llvm.func @malloc(!llvm.i64) -> !llvm.ptr<i8>
llvm.func @foo() { llvm.func @foo() {