diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h index d9b29e0e306c..107701303099 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -22,6 +22,7 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/IR/TypeSupport.h" #include "mlir/IR/Types.h" +#include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 1b547bb89ef1..84f2624f3f5b 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -17,6 +17,7 @@ include "mlir/Dialect/LLVMIR/LLVMOpBase.td" include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td" include "mlir/IR/FunctionInterfaces.td" include "mlir/IR/SymbolInterfaces.td" +include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -580,7 +581,7 @@ def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "CreateFPTrunc", def LLVM_InvokeOp : LLVM_Op<"invoke", [ AttrSizedOperandSegments, DeclareOpInterfaceMethods, - Terminator]> { + DeclareOpInterfaceMethods, Terminator]> { let arguments = (ins OptionalAttr:$callee, Variadic:$callee_operands, Variadic:$normalDestOperands, @@ -616,7 +617,8 @@ def LLVM_LandingpadOp : LLVM_Op<"landingpad"> { } def LLVM_CallOp : LLVM_Op<"call", - [DeclareOpInterfaceMethods]> { + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { let summary = "Call to an LLVM function."; let description = [{ @@ -1322,7 +1324,8 @@ def LLVM_GlobalDtorsOp : LLVM_Op<"mlir.global_dtors", [ } def LLVM_LLVMFuncOp : LLVM_Op<"func", [ - AutomaticAllocationScope, IsolatedFromAbove, FunctionOpInterface, Symbol + AutomaticAllocationScope, IsolatedFromAbove, FunctionOpInterface, + CallableOpInterface, Symbol ]> { let summary = "LLVM dialect function."; @@ -1389,6 +1392,13 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [ /// Returns the result types of this function. ArrayRef getResultTypes() { return getFunctionType().getReturnTypes(); } + + /// Returns the callable region, which is the function body. + Region *getCallableRegion() { return &getBody(); } + + /// Returns the callable result type, which is the function return type. + ArrayRef getCallableResults() { return getFunctionType().getReturnType(); } + }]; let hasCustomAssemblyFormat = 1; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 8bc63fc71fe8..d257244ba2ab 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -898,6 +898,18 @@ SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) { : getUnwindDestOperandsMutable()); } +CallInterfaceCallable InvokeOp::getCallableForCallee() { + // Direct call. + if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) + return calleeAttr; + // Indirect call, callee Value is the first operand. + return getOperand(0); +} + +Operation::operand_range InvokeOp::getArgOperands() { + return getOperands().drop_front(getCallee().hasValue() ? 0 : 1); +} + LogicalResult InvokeOp::verify() { if (getNumResults() > 1) return emitOpError("must have 0 or 1 result"); @@ -1125,6 +1137,18 @@ ParseResult LandingpadOp::parse(OpAsmParser &parser, OperationState &result) { // Verifying/Printing/parsing for LLVM::CallOp. //===----------------------------------------------------------------------===// +CallInterfaceCallable CallOp::getCallableForCallee() { + // Direct call. + if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) + return calleeAttr; + // Indirect call, callee Value is the first operand. + return getOperand(0); +} + +Operation::operand_range CallOp::getArgOperands() { + return getOperands().drop_front(getCallee().hasValue() ? 0 : 1); +} + LogicalResult CallOp::verify() { if (getNumResults() > 1) return emitOpError("must have 0 or 1 result"); diff --git a/mlir/test/Dialect/LLVMIR/callgraph.mlir b/mlir/test/Dialect/LLVMIR/callgraph.mlir new file mode 100644 index 000000000000..268bcdfd053e --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/callgraph.mlir @@ -0,0 +1,93 @@ +// RUN: mlir-opt --test-print-callgraph --split-input-file %s 2>&1 | FileCheck %s + +// CHECK: Testing : "Normal function call" +module attributes {"test.name" = "Normal function call"} { + // CHECK-LABEL: ---- CallGraph ---- + // CHECK: - Node : 'llvm.func' {{.*}} sym_name = "foo" + // CHECK-NEXT: -- Call-Edge : 'llvm.func' {{.*}} sym_name = "bar" + + // CHECK: - Node : 'llvm.func' {{.*}} sym_name = "bar" + // CHECK-NEXT: -- Call-Edge : 'llvm.func' {{.*}} sym_name = "foo" + + // CHECK: - Node : 'llvm.func' {{.*}} sym_name = "entry" + // CHECK-DAG: -- Call-Edge : 'llvm.func' {{.*}} sym_name = "foo" + // CHECK-DAG: -- Call-Edge : 'llvm.func' {{.*}} sym_name = "bar" + + // CHECK-LABEL: -- SCCs -- + // CHECK: - SCC : + // CHECK-DAG: -- Node :'llvm.func' {{.*}} sym_name = "foo" + // CHECK-DAG: -- Node :'llvm.func' {{.*}} sym_name = "bar" + + // CHECK: - SCC : + // CHECK-DAG: -- Node :'llvm.func' {{.*}} sym_name = "entry" + + llvm.func @foo(%arg0: i32) -> i32 { + %0 = llvm.mlir.constant(2 : i32) : i32 + %1 = llvm.sub %arg0, %0 : i32 + %2 = llvm.call @bar(%arg0, %1) : (i32, i32) -> i32 + llvm.return %2 : i32 + } + llvm.func @bar(%arg0: i32, %arg1: i32) -> i32 { + %0 = llvm.add %arg0, %arg1 : i32 + %1 = llvm.call @foo(%0) : (i32) -> i32 + llvm.return %1 : i32 + } + llvm.func @entry(%arg0: i32) -> i32 { + %0 = llvm.mlir.constant(2 : i32) : i32 + %1 = llvm.mlir.constant(0 : i32) : i32 + %2 = llvm.icmp "sgt" %arg0, %1 : i32 + llvm.cond_br %2, ^bb1, ^bb2 + ^bb1: // pred: ^bb0 + %3 = llvm.call @foo(%arg0) : (i32) -> i32 + llvm.br ^bb3(%3 : i32) + ^bb2: // pred: ^bb0 + %4 = llvm.add %arg0, %0 : i32 + %5 = llvm.call @bar(%arg0, %4) : (i32, i32) -> i32 + llvm.br ^bb3(%5 : i32) + ^bb3(%6: i32): // 2 preds: ^bb1, ^bb2 + llvm.return %6 : i32 + } +} + +// ----- + +// CHECK: Testing : "Invoke call" +module attributes {"test.name" = "Invoke call"} { + // CHECK-LABEL: ---- CallGraph ---- + // CHECK: - Node : 'llvm.func' {{.*}} sym_name = "invokeLandingpad" + // CHECK-DAG: -- Call-Edge : 'llvm.func' {{.*}} sym_name = "foo" + // CHECK-DAG: -- Call-Edge : 'llvm.func' {{.*}} sym_name = "bar" + + // CHECK: -- SCCs -- + llvm.mlir.global external constant @_ZTIi() : !llvm.ptr + llvm.func @foo(%arg0: i32) -> !llvm.struct<(i32, f64, i32)> + llvm.func @bar(!llvm.ptr, !llvm.ptr, !llvm.ptr) + llvm.func @__gxx_personality_v0(...) -> i32 + + llvm.func @invokeLandingpad() -> i32 attributes { personality = @__gxx_personality_v0 } { + %0 = llvm.mlir.constant(0 : i32) : i32 + %1 = llvm.mlir.constant(3 : i32) : i32 + %2 = llvm.mlir.constant("\01") : !llvm.array<1 x i8> + %3 = llvm.mlir.null : !llvm.ptr> + %4 = llvm.mlir.null : !llvm.ptr + %5 = llvm.mlir.addressof @_ZTIi : !llvm.ptr> + %6 = llvm.bitcast %5 : !llvm.ptr> to !llvm.ptr + %7 = llvm.mlir.constant(1 : i32) : i32 + %8 = llvm.alloca %7 x i8 : (i32) -> !llvm.ptr + %9 = llvm.invoke @foo(%7) to ^bb2 unwind ^bb1 : (i32) -> !llvm.struct<(i32, f64, i32)> + + ^bb1: + %10 = llvm.landingpad cleanup (catch %3 : !llvm.ptr>) (catch %6 : !llvm.ptr) (filter %2 : !llvm.array<1 x i8>) : !llvm.struct<(ptr, i32)> + %11 = llvm.intr.eh.typeid.for %6 : i32 + llvm.resume %10 : !llvm.struct<(ptr, i32)> + + ^bb2: + llvm.return %7 : i32 + + ^bb3: + llvm.invoke @bar(%8, %6, %4) to ^bb2 unwind ^bb1 : (!llvm.ptr, !llvm.ptr, !llvm.ptr) -> () + + ^bb4: + llvm.return %0 : i32 + } +}