diff --git a/mlir/test/Transforms/test-convert-call-op.mlir b/mlir/test/Transforms/test-convert-call-op.mlir new file mode 100644 index 000000000000..d1a1d5c35812 --- /dev/null +++ b/mlir/test/Transforms/test-convert-call-op.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt %s -test-convert-call-op | FileCheck %s + +// CHECK-LABEL: llvm.func @callee(!llvm.ptr) -> !llvm.i32 +func @callee(!test.test_type) -> i32 + +// CHECK-NEXT: llvm.func @caller() -> !llvm.i32 +func @caller() -> i32 { + %arg = "test.type_producer"() : () -> !test.test_type + %out = call @callee(%arg) : (!test.test_type) -> i32 + return %out : i32 +} +// CHECK-NEXT: [[ARG:%.*]] = llvm.mlir.null : !llvm.ptr +// CHECK-NEXT: [[OUT:%.*]] = llvm.call @callee([[ARG]]) +// CHECK-SAME: : (!llvm.ptr) -> !llvm.i32 diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt index c3318316c508..de894467d63d 100644 --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -5,6 +5,7 @@ add_mlir_library(MLIRTestTransforms TestExpandTanh.cpp TestCallGraph.cpp TestConstantFold.cpp + TestConvertCallOp.cpp TestConvertGPUKernelToCubin.cpp TestConvertGPUKernelToHsaco.cpp TestDominance.cpp diff --git a/mlir/test/lib/Transforms/TestConvertCallOp.cpp b/mlir/test/lib/Transforms/TestConvertCallOp.cpp new file mode 100644 index 000000000000..6cb596bfc71a --- /dev/null +++ b/mlir/test/lib/Transforms/TestConvertCallOp.cpp @@ -0,0 +1,72 @@ +//===- TestConvertCallOp.cpp - Test LLVM Convesion of Standard CallOp -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestDialect.h" +#include "TestTypes.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { + +class TestTypeProducerOpConverter + : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, getVoidPtrType()); + return success(); + } +}; + +class TestConvertCallOp + : public PassWrapper> { +public: + void runOnOperation() override { + ModuleOp m = getOperation(); + + // Populate type conversions. + LLVMTypeConverter type_converter(m.getContext()); + type_converter.addConversion([&](TestType type) { + return LLVM::LLVMType::getInt8PtrTy(m.getContext()); + }); + + // Populate patterns. + OwningRewritePatternList patterns; + populateStdToLLVMConversionPatterns(type_converter, patterns); + patterns.insert(type_converter); + + // Set target. + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addIllegalDialect(); + target.addIllegalDialect(); + + if (failed(applyPartialConversion(m, target, patterns))) { + signalPassFailure(); + } + } +}; + +} // namespace + +namespace mlir { +void registerConvertCallOpPass() { + PassRegistration( + "test-convert-call-op", + "Tests conversion of `std.call` to `llvm.call` in " + "presence of custom types"); +} +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 3be470d4e3de..efcb32856607 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -29,6 +29,7 @@ using namespace mlir; namespace mlir { // Defined in the test directory, no public header. +void registerConvertCallOpPass(); void registerConvertToTargetEnvPass(); void registerInliner(); void registerMemRefBoundCheck(); @@ -102,6 +103,7 @@ static cl::opt allowUnregisteredDialects( #ifdef MLIR_INCLUDE_TESTS void registerTestPasses() { + registerConvertCallOpPass(); registerConvertToTargetEnvPass(); registerInliner(); registerMemRefBoundCheck();