forked from OSchip/llvm-project
LLVM dialect conversion and target: support indirect calls
Add support for converting MLIR `call_indirect` instructions to the LLVM IR dialect. In LLVM IR, the same instruction is used for direct and indirect calls. In the dialect, we have `llvm.call` and `llvm.call0` to work around the absence of the void type in MLIR. For direct calls, the callee is stored as instruction attribute. Use the same pair of instructions for indirect calls by omitting the callee attribute. In the MLIR to LLVM IR translator, check the presence of attribute to decide whether to construct a direct or an indirect call using different LLVM IR Builder functions. Add support for converting constants of function type to the LLVM IR dialect and for translating them to the LLVM IR proper. The `llvm.constant` operation works similarly to other types: its attribute has MLIR function type but the value it produces has LLVM IR function type wrapped in the dialect type. While lowering, look up the pointer to the converted function in the corresponding mapping. PiperOrigin-RevId: 234132351
This commit is contained in:
parent
d7aa700ccb
commit
ffc9043604
|
@ -427,6 +427,9 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
|
|||
? voidType
|
||||
: TypeConverter::pack(getTypes(op->getResults()),
|
||||
this->dialect.getLLVMModule(), *mlirContext);
|
||||
assert(
|
||||
packedType &&
|
||||
"type conversion failed, such operation should not have been matched");
|
||||
|
||||
auto newOp = rewriter.create<TargetOp>(op->getLoc(), packedType, operands,
|
||||
op->getAttrs());
|
||||
|
@ -497,27 +500,49 @@ struct SelectOpLowering
|
|||
// Refine the matcher for call operations that return one result or more.
|
||||
// Since tablegen'ed MLIR Ops cannot have variadic results, we separate calls
|
||||
// that have 0 or 1 result (LLVM calls cannot have more than 1).
|
||||
struct CallOpLowering : public OneToOneLLVMOpLowering<CallOp, LLVM::CallOp> {
|
||||
using Super::Super;
|
||||
template <typename SourceOp>
|
||||
struct NonZeroResultCallLowering
|
||||
: public OneToOneLLVMOpLowering<SourceOp, LLVM::CallOp> {
|
||||
using OneToOneLLVMOpLowering<SourceOp, LLVM::CallOp>::OneToOneLLVMOpLowering;
|
||||
using Super = NonZeroResultCallLowering<SourceOp>;
|
||||
|
||||
PatternMatchResult match(Instruction *op) const override {
|
||||
if (op->getNumResults() > 0)
|
||||
return Super::match(op);
|
||||
return matchFailure();
|
||||
return OneToOneLLVMOpLowering<SourceOp, LLVM::CallOp>::match(op);
|
||||
return this->matchFailure();
|
||||
}
|
||||
};
|
||||
|
||||
// Refine the matcher for call operations that return zero results.
|
||||
// Since tablegen'ed MLIR Ops cannot have variadic results, we separate calls
|
||||
// that have 0 or 1 result (LLVM calls cannot have more than 1).
|
||||
struct Call0OpLowering : public OneToOneLLVMOpLowering<CallOp, LLVM::Call0Op> {
|
||||
using Super::Super;
|
||||
template <typename SourceOp>
|
||||
struct ZeroResultCallLowering
|
||||
: public OneToOneLLVMOpLowering<SourceOp, LLVM::Call0Op> {
|
||||
using OneToOneLLVMOpLowering<SourceOp, LLVM::Call0Op>::OneToOneLLVMOpLowering;
|
||||
using Super = ZeroResultCallLowering<SourceOp>;
|
||||
|
||||
PatternMatchResult match(Instruction *op) const override {
|
||||
if (op->getNumResults() == 0)
|
||||
return Super::match(op);
|
||||
return matchFailure();
|
||||
return OneToOneLLVMOpLowering<SourceOp, LLVM::Call0Op>::match(op);
|
||||
return this->matchFailure();
|
||||
}
|
||||
};
|
||||
|
||||
struct Call0OpLowering : public ZeroResultCallLowering<CallOp> {
|
||||
using Super::Super;
|
||||
};
|
||||
struct CallOpLowering : public NonZeroResultCallLowering<CallOp> {
|
||||
using Super::Super;
|
||||
};
|
||||
struct CallIndirect0OpLowering : public ZeroResultCallLowering<CallIndirectOp> {
|
||||
using Super::Super;
|
||||
};
|
||||
struct CallIndirectOpLowering
|
||||
: public NonZeroResultCallLowering<CallIndirectOp> {
|
||||
using Super::Super;
|
||||
};
|
||||
|
||||
struct ConstLLVMOpLowering
|
||||
: public OneToOneLLVMOpLowering<ConstantOp, LLVM::ConstantOp> {
|
||||
using Super::Super;
|
||||
|
@ -1021,7 +1046,8 @@ protected:
|
|||
// FIXME: this should be tablegen'ed
|
||||
return ConversionListBuilder<
|
||||
AddFOpLowering, AddIOpLowering, AllocOpLowering, BranchOpLowering,
|
||||
Call0OpLowering, CallOpLowering, CmpIOpLowering, CondBranchOpLowering,
|
||||
Call0OpLowering, CallIndirect0OpLowering, CallIndirectOpLowering,
|
||||
CallOpLowering, CmpIOpLowering, CondBranchOpLowering,
|
||||
ConstLLVMOpLowering, DeallocOpLowering, DimOpLowering, DivISOpLowering,
|
||||
DivIUOpLowering, LoadOpLowering, MemRefCastOpLowering, MulFOpLowering,
|
||||
MulIOpLowering, RemISOpLowering, RemIUOpLowering, ReturnOpLowering,
|
||||
|
|
|
@ -61,6 +61,9 @@ private:
|
|||
bool convertBlock(const Block &bb, bool ignoreArguments);
|
||||
bool convertInstruction(const Instruction &inst, llvm::IRBuilder<> &builder);
|
||||
|
||||
llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr,
|
||||
Location loc);
|
||||
|
||||
// Original and translated module.
|
||||
const Module &mlirModule;
|
||||
std::unique_ptr<llvm::Module> llvmModule;
|
||||
|
@ -112,16 +115,19 @@ static llvm::FunctionType *convertFunctionType(llvm::LLVMContext &llvmContext,
|
|||
// This currently supports integer, floating point, splat and dense element
|
||||
// attributes and combinations thereof. In case of error, report it to `loc`
|
||||
// and return nullptr.
|
||||
static llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr,
|
||||
MLIRContext &context, Location loc) {
|
||||
llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType,
|
||||
Attribute attr,
|
||||
Location loc) {
|
||||
if (auto intAttr = attr.dyn_cast<IntegerAttr>())
|
||||
return llvm::ConstantInt::get(llvmType, intAttr.getValue());
|
||||
if (auto floatAttr = attr.dyn_cast<FloatAttr>())
|
||||
return llvm::ConstantFP::get(llvmType, floatAttr.getValue());
|
||||
if (auto funcAttr = attr.dyn_cast<FunctionAttr>())
|
||||
return functionMapping.lookup(funcAttr.getValue());
|
||||
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
|
||||
auto *vectorType = cast<llvm::VectorType>(llvmType);
|
||||
auto *child = getLLVMConstant(vectorType->getElementType(),
|
||||
splatAttr.getValue(), context, loc);
|
||||
splatAttr.getValue(), loc);
|
||||
return llvm::ConstantVector::getSplat(vectorType->getNumElements(), child);
|
||||
}
|
||||
if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>()) {
|
||||
|
@ -133,13 +139,13 @@ static llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr,
|
|||
denseAttr.getValues(nested);
|
||||
for (auto n : nested) {
|
||||
constants.push_back(
|
||||
getLLVMConstant(vectorType->getElementType(), n, context, loc));
|
||||
getLLVMConstant(vectorType->getElementType(), n, loc));
|
||||
if (!constants.back())
|
||||
return nullptr;
|
||||
}
|
||||
return llvm::ConstantVector::get(constants);
|
||||
}
|
||||
context.emitError(loc, "unsupported constant value");
|
||||
mlirModule.getContext()->emitError(loc, "unsupported constant value");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -222,8 +228,8 @@ bool ModuleTranslation::convertInstruction(const Instruction &inst,
|
|||
if (auto op = inst.dyn_cast<LLVM::ConstantOp>()) {
|
||||
Attribute attr = op->getAttr("value");
|
||||
auto type = op->getResult()->getType().cast<LLVM::LLVMType>();
|
||||
valueMapping[op->getResult()] = getLLVMConstant(
|
||||
type.getUnderlyingType(), attr, *inst.getContext(), inst.getLoc());
|
||||
valueMapping[op->getResult()] =
|
||||
getLLVMConstant(type.getUnderlyingType(), attr, inst.getLoc());
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -239,19 +245,31 @@ bool ModuleTranslation::convertInstruction(const Instruction &inst,
|
|||
return remapped;
|
||||
};
|
||||
|
||||
// Emit function calls. In addition to operands, we also need to look up the
|
||||
// remapped function itself.
|
||||
// Emit function calls. If the "callee" attribute is present, this is a
|
||||
// direct function call and we also need to look up the remapped function
|
||||
// itself. Otherwise, this is an indirect call and the callee is the first
|
||||
// operand, look it up as a normal value. Return the llvm::Value representing
|
||||
// the function result, which may be of llvm::VoidTy type.
|
||||
auto convertCall = [this, lookupValues,
|
||||
&builder](const Instruction &inst) -> llvm::Value * {
|
||||
auto operands = lookupValues(inst.getOperands());
|
||||
ArrayRef<llvm::Value *> operandsRef(operands);
|
||||
if (auto attr = inst.getAttrOfType<FunctionAttr>("callee")) {
|
||||
return builder.CreateCall(functionMapping.lookup(attr.getValue()),
|
||||
operandsRef);
|
||||
} else {
|
||||
return builder.CreateCall(operandsRef.front(), operandsRef.drop_front());
|
||||
}
|
||||
};
|
||||
|
||||
// Emit calls. If the called function has a result, remap the corresponding
|
||||
// value.
|
||||
if (auto op = inst.dyn_cast<LLVM::CallOp>()) {
|
||||
auto attr = op->getAttrOfType<FunctionAttr>("callee");
|
||||
valueMapping[op->getResult()] =
|
||||
builder.CreateCall(functionMapping.lookup(attr.getValue()),
|
||||
lookupValues(op->getOperands()));
|
||||
valueMapping[op->getResult()] = convertCall(inst);
|
||||
return false;
|
||||
}
|
||||
if (auto op = inst.dyn_cast<LLVM::Call0Op>()) {
|
||||
auto attr = op->getAttrOfType<FunctionAttr>("callee");
|
||||
builder.CreateCall(functionMapping.lookup(attr.getValue()),
|
||||
lookupValues(op->getOperands()));
|
||||
if (inst.isa<LLVM::Call0Op>()) {
|
||||
convertCall(inst);
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
@ -28,3 +28,25 @@ func @pass_through(%arg0: () -> ()) -> (() -> ()) {
|
|||
//CHECK-NEXT: "llvm.return"(%0) : (!llvm<"void ()*">) -> ()
|
||||
return %bbarg : () -> ()
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @body(!llvm<"i32">)
|
||||
func @body(i32)
|
||||
|
||||
// CHECK-LABEL: func @indirect_const_call(%arg0: !llvm<"i32">) {
|
||||
func @indirect_const_call(%arg0: i32) {
|
||||
// CHECK-NEXT: %0 = "llvm.constant"() {value: @body : (!llvm<"i32">) -> ()} : () -> !llvm<"void (i32)*">
|
||||
%0 = constant @body : (i32) -> ()
|
||||
// CHECK-NEXT: "llvm.call0"(%0, %arg0) : (!llvm<"void (i32)*">, !llvm<"i32">) -> ()
|
||||
call_indirect %0(%arg0) : (i32) -> ()
|
||||
// CHECK-NEXT: "llvm.return"() : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @indirect_call(%arg0: !llvm<"i32 (float)*">, %arg1: !llvm<"float">) -> !llvm<"i32"> {
|
||||
func @indirect_call(%arg0: (f32) -> i32, %arg1: f32) -> i32 {
|
||||
// CHECK-NEXT: %0 = "llvm.call"(%arg0, %arg1) : (!llvm<"i32 (float)*">, !llvm<"float">) -> !llvm<"i32">
|
||||
%0 = call_indirect %arg0(%arg1) : (f32) -> i32
|
||||
// CHECK-NEXT: "llvm.return"(%0) : (!llvm<"i32">) -> ()
|
||||
return %0 : i32
|
||||
}
|
||||
|
||||
|
|
|
@ -735,3 +735,25 @@ func @ops(%arg0: !llvm<"float">, %arg1: !llvm<"float">, %arg2: !llvm<"i32">, %ar
|
|||
%10 = "llvm.insertvalue"(%9, %3) {position: [1]} : (!llvm<"{ float, i32 }">, !llvm<"i32">) -> !llvm<"{ float, i32 }">
|
||||
"llvm.return"(%10) : (!llvm<"{ float, i32 }">) -> ()
|
||||
}
|
||||
|
||||
//
|
||||
// Indirect function calls
|
||||
//
|
||||
|
||||
// CHECK-LABEL: define void @indirect_const_call(i64) {
|
||||
func @indirect_const_call(%arg0: !llvm<"i64">) {
|
||||
// CHECK-NEXT: call void @body(i64 %0)
|
||||
%0 = "llvm.constant"() {value: @body : (!llvm<"i64">) -> ()} : () -> !llvm<"void (i64)*">
|
||||
"llvm.call0"(%0, %arg0) : (!llvm<"void (i64)*">, !llvm<"i64">) -> ()
|
||||
// CHECK-NEXT: ret void
|
||||
"llvm.return"() : () -> ()
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define i32 @indirect_call(i32 (float)*, float) {
|
||||
func @indirect_call(%arg0: !llvm<"i32 (float)*">, %arg1: !llvm<"float">) -> !llvm<"i32"> {
|
||||
// CHECK-NEXT: %3 = call i32 %0(float %1)
|
||||
%0 = "llvm.call"(%arg0, %arg1) : (!llvm<"i32 (float)*">, !llvm<"float">) -> !llvm<"i32">
|
||||
// CHECK-NEXT: ret i32 %3
|
||||
"llvm.return"(%0) : (!llvm<"i32">) -> ()
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue