LLVM dialect conversion and target: support indirect calls

Add support for converting MLIR `call_indirect` instructions to the LLVM IR
dialect.  In LLVM IR, the same instruction is used for direct and indirect
calls.  In the dialect, we have `llvm.call` and `llvm.call0` to work around the
absence of the void type in MLIR.  For direct calls, the callee is stored as
instruction attribute.  Use the same pair of instructions for indirect calls by
omitting the callee attribute.  In the MLIR to LLVM IR translator, check the
presence of attribute to decide whether to construct a direct or an indirect
call using different LLVM IR Builder functions.

Add support for converting constants of function type to the LLVM IR dialect
and for translating them to the LLVM IR proper.  The `llvm.constant` operation
works similarly to other types: its attribute has MLIR function type but the
value it produces has LLVM IR function type wrapped in the dialect type.  While
lowering, look up the pointer to the converted function in the corresponding
mapping.

PiperOrigin-RevId: 234132351
This commit is contained in:
Alex Zinenko 2019-02-15 06:25:30 -08:00 committed by jpienaar
parent d7aa700ccb
commit ffc9043604
4 changed files with 114 additions and 26 deletions

View File

@ -427,6 +427,9 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
? voidType
: TypeConverter::pack(getTypes(op->getResults()),
this->dialect.getLLVMModule(), *mlirContext);
assert(
packedType &&
"type conversion failed, such operation should not have been matched");
auto newOp = rewriter.create<TargetOp>(op->getLoc(), packedType, operands,
op->getAttrs());
@ -497,27 +500,49 @@ struct SelectOpLowering
// Refine the matcher for call operations that return one result or more.
// Since tablegen'ed MLIR Ops cannot have variadic results, we separate calls
// that have 0 or 1 result (LLVM calls cannot have more than 1).
struct CallOpLowering : public OneToOneLLVMOpLowering<CallOp, LLVM::CallOp> {
using Super::Super;
template <typename SourceOp>
struct NonZeroResultCallLowering
: public OneToOneLLVMOpLowering<SourceOp, LLVM::CallOp> {
using OneToOneLLVMOpLowering<SourceOp, LLVM::CallOp>::OneToOneLLVMOpLowering;
using Super = NonZeroResultCallLowering<SourceOp>;
PatternMatchResult match(Instruction *op) const override {
if (op->getNumResults() > 0)
return Super::match(op);
return matchFailure();
return OneToOneLLVMOpLowering<SourceOp, LLVM::CallOp>::match(op);
return this->matchFailure();
}
};
// Refine the matcher for call operations that return zero results.
// Since tablegen'ed MLIR Ops cannot have variadic results, we separate calls
// that have 0 or 1 result (LLVM calls cannot have more than 1).
struct Call0OpLowering : public OneToOneLLVMOpLowering<CallOp, LLVM::Call0Op> {
using Super::Super;
template <typename SourceOp>
struct ZeroResultCallLowering
: public OneToOneLLVMOpLowering<SourceOp, LLVM::Call0Op> {
using OneToOneLLVMOpLowering<SourceOp, LLVM::Call0Op>::OneToOneLLVMOpLowering;
using Super = ZeroResultCallLowering<SourceOp>;
PatternMatchResult match(Instruction *op) const override {
if (op->getNumResults() == 0)
return Super::match(op);
return matchFailure();
return OneToOneLLVMOpLowering<SourceOp, LLVM::Call0Op>::match(op);
return this->matchFailure();
}
};
struct Call0OpLowering : public ZeroResultCallLowering<CallOp> {
using Super::Super;
};
struct CallOpLowering : public NonZeroResultCallLowering<CallOp> {
using Super::Super;
};
struct CallIndirect0OpLowering : public ZeroResultCallLowering<CallIndirectOp> {
using Super::Super;
};
struct CallIndirectOpLowering
: public NonZeroResultCallLowering<CallIndirectOp> {
using Super::Super;
};
struct ConstLLVMOpLowering
: public OneToOneLLVMOpLowering<ConstantOp, LLVM::ConstantOp> {
using Super::Super;
@ -1021,7 +1046,8 @@ protected:
// FIXME: this should be tablegen'ed
return ConversionListBuilder<
AddFOpLowering, AddIOpLowering, AllocOpLowering, BranchOpLowering,
Call0OpLowering, CallOpLowering, CmpIOpLowering, CondBranchOpLowering,
Call0OpLowering, CallIndirect0OpLowering, CallIndirectOpLowering,
CallOpLowering, CmpIOpLowering, CondBranchOpLowering,
ConstLLVMOpLowering, DeallocOpLowering, DimOpLowering, DivISOpLowering,
DivIUOpLowering, LoadOpLowering, MemRefCastOpLowering, MulFOpLowering,
MulIOpLowering, RemISOpLowering, RemIUOpLowering, ReturnOpLowering,

View File

@ -61,6 +61,9 @@ private:
bool convertBlock(const Block &bb, bool ignoreArguments);
bool convertInstruction(const Instruction &inst, llvm::IRBuilder<> &builder);
llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr,
Location loc);
// Original and translated module.
const Module &mlirModule;
std::unique_ptr<llvm::Module> llvmModule;
@ -112,16 +115,19 @@ static llvm::FunctionType *convertFunctionType(llvm::LLVMContext &llvmContext,
// This currently supports integer, floating point, splat and dense element
// attributes and combinations thereof. In case of error, report it to `loc`
// and return nullptr.
static llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr,
MLIRContext &context, Location loc) {
llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType,
Attribute attr,
Location loc) {
if (auto intAttr = attr.dyn_cast<IntegerAttr>())
return llvm::ConstantInt::get(llvmType, intAttr.getValue());
if (auto floatAttr = attr.dyn_cast<FloatAttr>())
return llvm::ConstantFP::get(llvmType, floatAttr.getValue());
if (auto funcAttr = attr.dyn_cast<FunctionAttr>())
return functionMapping.lookup(funcAttr.getValue());
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
auto *vectorType = cast<llvm::VectorType>(llvmType);
auto *child = getLLVMConstant(vectorType->getElementType(),
splatAttr.getValue(), context, loc);
splatAttr.getValue(), loc);
return llvm::ConstantVector::getSplat(vectorType->getNumElements(), child);
}
if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>()) {
@ -133,13 +139,13 @@ static llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr,
denseAttr.getValues(nested);
for (auto n : nested) {
constants.push_back(
getLLVMConstant(vectorType->getElementType(), n, context, loc));
getLLVMConstant(vectorType->getElementType(), n, loc));
if (!constants.back())
return nullptr;
}
return llvm::ConstantVector::get(constants);
}
context.emitError(loc, "unsupported constant value");
mlirModule.getContext()->emitError(loc, "unsupported constant value");
return nullptr;
}
@ -222,8 +228,8 @@ bool ModuleTranslation::convertInstruction(const Instruction &inst,
if (auto op = inst.dyn_cast<LLVM::ConstantOp>()) {
Attribute attr = op->getAttr("value");
auto type = op->getResult()->getType().cast<LLVM::LLVMType>();
valueMapping[op->getResult()] = getLLVMConstant(
type.getUnderlyingType(), attr, *inst.getContext(), inst.getLoc());
valueMapping[op->getResult()] =
getLLVMConstant(type.getUnderlyingType(), attr, inst.getLoc());
return false;
}
@ -239,19 +245,31 @@ bool ModuleTranslation::convertInstruction(const Instruction &inst,
return remapped;
};
// Emit function calls. In addition to operands, we also need to look up the
// remapped function itself.
// Emit function calls. If the "callee" attribute is present, this is a
// direct function call and we also need to look up the remapped function
// itself. Otherwise, this is an indirect call and the callee is the first
// operand, look it up as a normal value. Return the llvm::Value representing
// the function result, which may be of llvm::VoidTy type.
auto convertCall = [this, lookupValues,
&builder](const Instruction &inst) -> llvm::Value * {
auto operands = lookupValues(inst.getOperands());
ArrayRef<llvm::Value *> operandsRef(operands);
if (auto attr = inst.getAttrOfType<FunctionAttr>("callee")) {
return builder.CreateCall(functionMapping.lookup(attr.getValue()),
operandsRef);
} else {
return builder.CreateCall(operandsRef.front(), operandsRef.drop_front());
}
};
// Emit calls. If the called function has a result, remap the corresponding
// value.
if (auto op = inst.dyn_cast<LLVM::CallOp>()) {
auto attr = op->getAttrOfType<FunctionAttr>("callee");
valueMapping[op->getResult()] =
builder.CreateCall(functionMapping.lookup(attr.getValue()),
lookupValues(op->getOperands()));
valueMapping[op->getResult()] = convertCall(inst);
return false;
}
if (auto op = inst.dyn_cast<LLVM::Call0Op>()) {
auto attr = op->getAttrOfType<FunctionAttr>("callee");
builder.CreateCall(functionMapping.lookup(attr.getValue()),
lookupValues(op->getOperands()));
if (inst.isa<LLVM::Call0Op>()) {
convertCall(inst);
return false;
}

View File

@ -28,3 +28,25 @@ func @pass_through(%arg0: () -> ()) -> (() -> ()) {
//CHECK-NEXT: "llvm.return"(%0) : (!llvm<"void ()*">) -> ()
return %bbarg : () -> ()
}
// CHECK-LABEL: func @body(!llvm<"i32">)
func @body(i32)
// CHECK-LABEL: func @indirect_const_call(%arg0: !llvm<"i32">) {
func @indirect_const_call(%arg0: i32) {
// CHECK-NEXT: %0 = "llvm.constant"() {value: @body : (!llvm<"i32">) -> ()} : () -> !llvm<"void (i32)*">
%0 = constant @body : (i32) -> ()
// CHECK-NEXT: "llvm.call0"(%0, %arg0) : (!llvm<"void (i32)*">, !llvm<"i32">) -> ()
call_indirect %0(%arg0) : (i32) -> ()
// CHECK-NEXT: "llvm.return"() : () -> ()
return
}
// CHECK-LABEL: func @indirect_call(%arg0: !llvm<"i32 (float)*">, %arg1: !llvm<"float">) -> !llvm<"i32"> {
func @indirect_call(%arg0: (f32) -> i32, %arg1: f32) -> i32 {
// CHECK-NEXT: %0 = "llvm.call"(%arg0, %arg1) : (!llvm<"i32 (float)*">, !llvm<"float">) -> !llvm<"i32">
%0 = call_indirect %arg0(%arg1) : (f32) -> i32
// CHECK-NEXT: "llvm.return"(%0) : (!llvm<"i32">) -> ()
return %0 : i32
}

View File

@ -735,3 +735,25 @@ func @ops(%arg0: !llvm<"float">, %arg1: !llvm<"float">, %arg2: !llvm<"i32">, %ar
%10 = "llvm.insertvalue"(%9, %3) {position: [1]} : (!llvm<"{ float, i32 }">, !llvm<"i32">) -> !llvm<"{ float, i32 }">
"llvm.return"(%10) : (!llvm<"{ float, i32 }">) -> ()
}
//
// Indirect function calls
//
// CHECK-LABEL: define void @indirect_const_call(i64) {
func @indirect_const_call(%arg0: !llvm<"i64">) {
// CHECK-NEXT: call void @body(i64 %0)
%0 = "llvm.constant"() {value: @body : (!llvm<"i64">) -> ()} : () -> !llvm<"void (i64)*">
"llvm.call0"(%0, %arg0) : (!llvm<"void (i64)*">, !llvm<"i64">) -> ()
// CHECK-NEXT: ret void
"llvm.return"() : () -> ()
}
// CHECK-LABEL: define i32 @indirect_call(i32 (float)*, float) {
func @indirect_call(%arg0: !llvm<"i32 (float)*">, %arg1: !llvm<"float">) -> !llvm<"i32"> {
// CHECK-NEXT: %3 = call i32 %0(float %1)
%0 = "llvm.call"(%arg0, %arg1) : (!llvm<"i32 (float)*">, !llvm<"float">) -> !llvm<"i32">
// CHECK-NEXT: ret i32 %3
"llvm.return"(%0) : (!llvm<"i32">) -> ()
}