forked from OSchip/llvm-project
[mlir] Add fastmath flags support to some LLVM dialect ops
Add fastmath enum, attributes to some llvm dialect ops, `FastmathFlagsInterface` op interface, and `translateModuleToLLVMIR` support. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D92485
This commit is contained in:
parent
a9a8caf2ce
commit
c1d58c2b00
|
@ -10,6 +10,8 @@ add_public_tablegen_target(MLIRLLVMOpsIncGen)
|
|||
|
||||
add_mlir_doc(LLVMOps -gen-op-doc LLVMOps Dialects/)
|
||||
|
||||
add_mlir_interface(LLVMOpsInterfaces)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS LLVMOps.td)
|
||||
mlir_tablegen(LLVMConversions.inc -gen-llvmir-conversions)
|
||||
mlir_tablegen(LLVMConversionEnumsToLLVM.inc -gen-enum-to-llvmir-conversions)
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
#include "llvm/IR/Type.h"
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/LLVMOpsEnums.h.inc"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.h.inc"
|
||||
|
||||
namespace llvm {
|
||||
class Type;
|
||||
|
@ -46,8 +47,23 @@ class LLVMDialect;
|
|||
namespace detail {
|
||||
struct LLVMTypeStorage;
|
||||
struct LLVMDialectImpl;
|
||||
struct BitmaskEnumStorage;
|
||||
} // namespace detail
|
||||
|
||||
/// An attribute that specifies LLVM instruction fastmath flags.
|
||||
class FMFAttr : public Attribute::AttrBase<FMFAttr, Attribute,
|
||||
detail::BitmaskEnumStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
static FMFAttr get(FastmathFlags flags, MLIRContext *context);
|
||||
|
||||
FastmathFlags getFlags() const;
|
||||
|
||||
void print(DialectAsmPrinter &p) const;
|
||||
static Attribute parse(DialectAsmParser &parser);
|
||||
};
|
||||
|
||||
} // namespace LLVM
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -14,10 +14,39 @@
|
|||
#define LLVMIR_OPS
|
||||
|
||||
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
|
||||
include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td"
|
||||
include "mlir/IR/SymbolInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
|
||||
def FMFnnan : BitEnumAttrCase<"nnan", 0x1>;
|
||||
def FMFninf : BitEnumAttrCase<"ninf", 0x2>;
|
||||
def FMFnsz : BitEnumAttrCase<"nsz", 0x4>;
|
||||
def FMFarcp : BitEnumAttrCase<"arcp", 0x8>;
|
||||
def FMFcontract : BitEnumAttrCase<"contract", 0x10>;
|
||||
def FMFafn : BitEnumAttrCase<"afn", 0x20>;
|
||||
def FMFreassoc : BitEnumAttrCase<"reassoc", 0x40>;
|
||||
def FMFfast : BitEnumAttrCase<"fast", 0x80>;
|
||||
|
||||
def FastmathFlags : BitEnumAttr<
|
||||
"FastmathFlags",
|
||||
"LLVM fastmath flags",
|
||||
[FMFnnan, FMFninf, FMFnsz, FMFarcp, FMFcontract, FMFafn, FMFreassoc, FMFfast
|
||||
]> {
|
||||
let cppNamespace = "::mlir::LLVM";
|
||||
}
|
||||
|
||||
def LLVM_FMFAttr : DialectAttr<
|
||||
LLVM_Dialect,
|
||||
CPred<"$_self.isa<::mlir::LLVM::FMFAttr>()">,
|
||||
"LLVM fastmath flags"> {
|
||||
let storageType = "::mlir::LLVM::FMFAttr";
|
||||
let returnType = "::mlir::LLVM::FastmathFlags";
|
||||
let convertFromStorage = "$_self.getFlags()";
|
||||
let constBuilderCall =
|
||||
"::mlir::LLVM::FMFAttr::get($0, $_builder.getContext())";
|
||||
}
|
||||
|
||||
class LLVM_Builder<string builder> {
|
||||
string llvmBuilder = builder;
|
||||
}
|
||||
|
@ -77,29 +106,35 @@ class LLVM_ArithmeticOpBase<Type type, string mnemonic,
|
|||
LLVM_Op<mnemonic,
|
||||
!listconcat([NoSideEffect, SameOperandsAndResultType], traits)>,
|
||||
LLVM_Builder<"$res = builder." # builderFunc # "($lhs, $rhs);"> {
|
||||
let arguments = (ins LLVM_ScalarOrVectorOf<type>:$lhs,
|
||||
dag commonArgs = (ins LLVM_ScalarOrVectorOf<type>:$lhs,
|
||||
LLVM_ScalarOrVectorOf<type>:$rhs);
|
||||
let results = (outs LLVM_ScalarOrVectorOf<type>:$res);
|
||||
let builders = [LLVM_OneResultOpBuilder];
|
||||
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($res)";
|
||||
let assemblyFormat = "$lhs `,` $rhs custom<LLVMOpAttrs>(attr-dict) `:` type($res)";
|
||||
}
|
||||
class LLVM_IntArithmeticOp<string mnemonic, string builderFunc,
|
||||
list<OpTrait> traits = []> :
|
||||
LLVM_ArithmeticOpBase<LLVM_AnyInteger, mnemonic, builderFunc, traits>;
|
||||
LLVM_ArithmeticOpBase<LLVM_AnyInteger, mnemonic, builderFunc, traits> {
|
||||
let arguments = commonArgs;
|
||||
}
|
||||
class LLVM_FloatArithmeticOp<string mnemonic, string builderFunc,
|
||||
list<OpTrait> traits = []> :
|
||||
LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, builderFunc, traits>;
|
||||
LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, builderFunc,
|
||||
!listconcat([DeclareOpInterfaceMethods<FastmathFlagsInterface>], traits)> {
|
||||
dag fmfArg = (ins DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
|
||||
let arguments = !con(commonArgs, fmfArg);
|
||||
}
|
||||
|
||||
// Class for arithmetic unary operations.
|
||||
class LLVM_UnaryArithmeticOp<Type type, string mnemonic,
|
||||
class LLVM_UnaryFloatArithmeticOp<Type type, string mnemonic,
|
||||
string builderFunc, list<OpTrait> traits = []> :
|
||||
LLVM_Op<mnemonic,
|
||||
!listconcat([NoSideEffect, SameOperandsAndResultType], traits)>,
|
||||
!listconcat([NoSideEffect, SameOperandsAndResultType, DeclareOpInterfaceMethods<FastmathFlagsInterface>], traits)>,
|
||||
LLVM_Builder<"$res = builder." # builderFunc # "($operand);"> {
|
||||
let arguments = (ins type:$operand);
|
||||
let arguments = (ins type:$operand, DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
|
||||
let results = (outs type:$res);
|
||||
let builders = [LLVM_OneResultOpBuilder];
|
||||
let assemblyFormat = "$operand attr-dict `:` type($res)";
|
||||
let assemblyFormat = "$operand custom<LLVMOpAttrs>(attr-dict) `:` type($res)";
|
||||
}
|
||||
|
||||
// Integer binary operations.
|
||||
|
@ -185,20 +220,24 @@ def FCmpPredicate : I64EnumAttr<
|
|||
let cppNamespace = "::mlir::LLVM";
|
||||
}
|
||||
|
||||
// Other integer operations.
|
||||
def LLVM_FCmpOp : LLVM_Op<"fcmp", [NoSideEffect]> {
|
||||
// Other floating-point operations.
|
||||
def LLVM_FCmpOp : LLVM_Op<"fcmp", [
|
||||
NoSideEffect, DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
|
||||
let arguments = (ins FCmpPredicate:$predicate,
|
||||
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$lhs,
|
||||
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$rhs);
|
||||
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$rhs,
|
||||
DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
|
||||
let results = (outs LLVM_ScalarOrVectorOf<LLVM_i1>:$res);
|
||||
let llvmBuilder = [{
|
||||
$res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
|
||||
}];
|
||||
let builders = [
|
||||
OpBuilderDAG<(ins "FCmpPredicate":$predicate, "Value":$lhs, "Value":$rhs),
|
||||
OpBuilderDAG<(ins "FCmpPredicate":$predicate, "Value":$lhs, "Value":$rhs,
|
||||
CArg<"FastmathFlags", "{}">:$fmf),
|
||||
[{
|
||||
build($_builder, $_state, LLVMIntegerType::get(lhs.getType().getContext(), 1),
|
||||
$_builder.getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs);
|
||||
$_builder.getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs,
|
||||
::mlir::LLVM::FMFAttr::get(fmf, $_builder.getContext()));
|
||||
}]>];
|
||||
let parser = [{ return parseCmpOp<FCmpPredicate>(parser, result); }];
|
||||
let printer = [{ printFCmpOp(p, *this); }];
|
||||
|
@ -210,8 +249,8 @@ def LLVM_FSubOp : LLVM_FloatArithmeticOp<"fsub", "CreateFSub">;
|
|||
def LLVM_FMulOp : LLVM_FloatArithmeticOp<"fmul", "CreateFMul">;
|
||||
def LLVM_FDivOp : LLVM_FloatArithmeticOp<"fdiv", "CreateFDiv">;
|
||||
def LLVM_FRemOp : LLVM_FloatArithmeticOp<"frem", "CreateFRem">;
|
||||
def LLVM_FNegOp : LLVM_UnaryArithmeticOp<LLVM_ScalarOrVectorOf<LLVM_AnyFloat>,
|
||||
"fneg", "CreateFNeg">;
|
||||
def LLVM_FNegOp : LLVM_UnaryFloatArithmeticOp<
|
||||
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>, "fneg", "CreateFNeg">;
|
||||
|
||||
// Common code definition that is used to verify and set the alignment attribute
|
||||
// of LLVM ops that accept such an attribute.
|
||||
|
@ -405,7 +444,8 @@ def LLVM_LandingpadOp : LLVM_Op<"landingpad"> {
|
|||
let printer = [{ printLandingpadOp(p, *this); }];
|
||||
}
|
||||
|
||||
def LLVM_CallOp : LLVM_Op<"call"> {
|
||||
def LLVM_CallOp : LLVM_Op<"call",
|
||||
[DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
|
||||
let summary = "Call to an LLVM function.";
|
||||
let description = [{
|
||||
|
||||
|
@ -436,7 +476,8 @@ def LLVM_CallOp : LLVM_Op<"call"> {
|
|||
```
|
||||
}];
|
||||
let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
|
||||
Variadic<LLVM_Type>);
|
||||
Variadic<LLVM_Type>,
|
||||
DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
|
||||
let results = (outs Variadic<LLVM_Type>);
|
||||
let builders = [
|
||||
OpBuilderDAG<(ins "LLVMFuncOp":$func, "ValueRange":$operands,
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
//===-- LLVMOpsInterfaces.td - LLVM op interfaces ----------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This is the LLVM IR interfaces definition file.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef LLVM_OPS_INTERFACES
|
||||
#define LLVM_OPS_INTERFACES
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
|
||||
let description = [{
|
||||
Access to op fastmath flags.
|
||||
}];
|
||||
|
||||
let cppNamespace = "::mlir::LLVM";
|
||||
|
||||
let methods = [
|
||||
InterfaceMethod<"Get fastmath flags", "::mlir::LLVM::FastmathFlags", "fastmathFlags">,
|
||||
];
|
||||
}
|
||||
|
||||
#endif // LLVM_OPS_INTERFACES
|
|
@ -828,7 +828,8 @@ public:
|
|||
rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
|
||||
operation, dstType,
|
||||
rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)),
|
||||
operation.operand1(), operation.operand2());
|
||||
operation.operand1(), operation.operand2(),
|
||||
LLVM::FMFAttr::get({}, operation.getContext()));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -1836,10 +1836,11 @@ struct AddCFOpLowering : public ConvertOpToLLVMPattern<AddCFOp> {
|
|||
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
|
||||
|
||||
// Emit IR to add complex numbers.
|
||||
auto fmf = LLVM::FMFAttr::get({}, op.getContext());
|
||||
Value real =
|
||||
rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real());
|
||||
rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
|
||||
Value imag =
|
||||
rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag());
|
||||
rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
|
||||
result.setReal(rewriter, loc, real);
|
||||
result.setImaginary(rewriter, loc, imag);
|
||||
|
||||
|
@ -1863,10 +1864,11 @@ struct SubCFOpLowering : public ConvertOpToLLVMPattern<SubCFOp> {
|
|||
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
|
||||
|
||||
// Emit IR to substract complex numbers.
|
||||
auto fmf = LLVM::FMFAttr::get({}, op.getContext());
|
||||
Value real =
|
||||
rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real());
|
||||
rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
|
||||
Value imag =
|
||||
rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag());
|
||||
rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
|
||||
result.setReal(rewriter, loc, real);
|
||||
result.setImaginary(rewriter, loc, imag);
|
||||
|
||||
|
@ -3155,11 +3157,12 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<CmpFOp> {
|
|||
ConversionPatternRewriter &rewriter) const override {
|
||||
CmpFOpAdaptor transformed(operands);
|
||||
|
||||
auto fmf = LLVM::FMFAttr::get({}, cmpfOp.getContext());
|
||||
rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
|
||||
cmpfOp, typeConverter->convertType(cmpfOp.getResult().getType()),
|
||||
rewriter.getI64IntegerAttr(static_cast<int64_t>(
|
||||
convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()))),
|
||||
transformed.lhs(), transformed.rhs());
|
||||
transformed.lhs(), transformed.rhs(), fmf);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRLLVMIR
|
|||
|
||||
DEPENDS
|
||||
MLIRLLVMOpsIncGen
|
||||
MLIRLLVMOpsInterfacesIncGen
|
||||
MLIROpenMPOpsIncGen
|
||||
intrinsics_gen
|
||||
|
||||
|
|
|
@ -36,6 +36,51 @@ static constexpr const char kVolatileAttrName[] = "volatile_";
|
|||
static constexpr const char kNonTemporalAttrName[] = "nontemporal";
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc"
|
||||
|
||||
namespace mlir {
|
||||
namespace LLVM {
|
||||
namespace detail {
|
||||
struct BitmaskEnumStorage : public AttributeStorage {
|
||||
using KeyTy = uint64_t;
|
||||
|
||||
BitmaskEnumStorage(KeyTy val) : value(val) {}
|
||||
|
||||
bool operator==(const KeyTy &key) const { return value == key; }
|
||||
|
||||
static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
|
||||
const KeyTy &key) {
|
||||
return new (allocator.allocate<BitmaskEnumStorage>())
|
||||
BitmaskEnumStorage(key);
|
||||
}
|
||||
|
||||
KeyTy value = 0;
|
||||
};
|
||||
} // namespace detail
|
||||
} // namespace LLVM
|
||||
} // namespace mlir
|
||||
|
||||
static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
|
||||
SmallVector<NamedAttribute, 8> filteredAttrs(
|
||||
llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
|
||||
if (attr.first == "fastmathFlags") {
|
||||
auto defAttr = FMFAttr::get({}, attr.second.getContext());
|
||||
return defAttr != attr.second;
|
||||
}
|
||||
return true;
|
||||
}));
|
||||
return filteredAttrs;
|
||||
}
|
||||
|
||||
static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
|
||||
NamedAttrList &result) {
|
||||
return parser.parseOptionalAttrDict(result);
|
||||
}
|
||||
|
||||
static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
|
||||
DictionaryAttr attrs) {
|
||||
printer.printOptionalAttrDict(processFMFAttr(attrs.getValue()));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Printing/parsing for LLVM::CmpOp.
|
||||
|
@ -50,7 +95,7 @@ static void printICmpOp(OpAsmPrinter &p, ICmpOp &op) {
|
|||
static void printFCmpOp(OpAsmPrinter &p, FCmpOp &op) {
|
||||
p << op.getOperationName() << " \"" << stringifyFCmpPredicate(op.predicate())
|
||||
<< "\" " << op.getOperand(0) << ", " << op.getOperand(1);
|
||||
p.printOptionalAttrDict(op.getAttrs(), {"predicate"});
|
||||
p.printOptionalAttrDict(processFMFAttr(op.getAttrs()), {"predicate"});
|
||||
p << " : " << op.lhs().getType();
|
||||
}
|
||||
|
||||
|
@ -771,7 +816,7 @@ static void printCallOp(OpAsmPrinter &p, CallOp &op) {
|
|||
|
||||
auto args = op.getOperands().drop_front(isDirect ? 0 : 1);
|
||||
p << '(' << args << ')';
|
||||
p.printOptionalAttrDict(op.getAttrs(), {"callee"});
|
||||
p.printOptionalAttrDict(processFMFAttr(op.getAttrs()), {"callee"});
|
||||
|
||||
// Reconstruct the function MLIR function type from operand and result types.
|
||||
p << " : "
|
||||
|
@ -2041,6 +2086,8 @@ static LogicalResult verify(FenceOp &op) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void LLVMDialect::initialize() {
|
||||
addAttributes<FMFAttr>();
|
||||
|
||||
// clang-format off
|
||||
addTypes<LLVMVoidType,
|
||||
LLVMHalfType,
|
||||
|
@ -2172,3 +2219,87 @@ bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
|
|||
return op->hasTrait<OpTrait::SymbolTable>() &&
|
||||
op->hasTrait<OpTrait::IsIsolatedFromAbove>();
|
||||
}
|
||||
|
||||
FMFAttr FMFAttr::get(FastmathFlags flags, MLIRContext *context) {
|
||||
return Base::get(context, static_cast<uint64_t>(flags));
|
||||
}
|
||||
|
||||
FastmathFlags FMFAttr::getFlags() const {
|
||||
return static_cast<FastmathFlags>(getImpl()->value);
|
||||
}
|
||||
|
||||
static constexpr const FastmathFlags FastmathFlagsList[] = {
|
||||
// clang-format off
|
||||
FastmathFlags::nnan,
|
||||
FastmathFlags::ninf,
|
||||
FastmathFlags::nsz,
|
||||
FastmathFlags::arcp,
|
||||
FastmathFlags::contract,
|
||||
FastmathFlags::afn,
|
||||
FastmathFlags::reassoc,
|
||||
FastmathFlags::fast,
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
void FMFAttr::print(DialectAsmPrinter &printer) const {
|
||||
printer << "fastmath<";
|
||||
auto flags = llvm::make_filter_range(FastmathFlagsList, [&](auto flag) {
|
||||
return bitEnumContains(getFlags(), flag);
|
||||
});
|
||||
llvm::interleaveComma(flags, printer,
|
||||
[&](auto flag) { printer << stringifyEnum(flag); });
|
||||
printer << ">";
|
||||
}
|
||||
|
||||
Attribute FMFAttr::parse(DialectAsmParser &parser) {
|
||||
if (failed(parser.parseLess()))
|
||||
return {};
|
||||
|
||||
FastmathFlags flags = {};
|
||||
if (failed(parser.parseOptionalGreater())) {
|
||||
do {
|
||||
StringRef elemName;
|
||||
if (failed(parser.parseKeyword(&elemName)))
|
||||
return {};
|
||||
|
||||
auto elem = symbolizeFastmathFlags(elemName);
|
||||
if (!elem) {
|
||||
parser.emitError(parser.getNameLoc(), "Unknown fastmath flag: ")
|
||||
<< elemName;
|
||||
return {};
|
||||
}
|
||||
|
||||
flags = flags | *elem;
|
||||
} while (succeeded(parser.parseOptionalComma()));
|
||||
|
||||
if (failed(parser.parseGreater()))
|
||||
return {};
|
||||
}
|
||||
|
||||
return FMFAttr::get(flags, parser.getBuilder().getContext());
|
||||
}
|
||||
|
||||
Attribute LLVMDialect::parseAttribute(DialectAsmParser &parser,
|
||||
Type type) const {
|
||||
if (type) {
|
||||
parser.emitError(parser.getNameLoc(), "unexpected type");
|
||||
return {};
|
||||
}
|
||||
StringRef attrKind;
|
||||
if (parser.parseKeyword(&attrKind))
|
||||
return {};
|
||||
|
||||
if (attrKind == "fastmath")
|
||||
return FMFAttr::parse(parser);
|
||||
|
||||
parser.emitError(parser.getNameLoc(), "Unknown attrribute type: ")
|
||||
<< attrKind;
|
||||
return {};
|
||||
}
|
||||
|
||||
void LLVMDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
|
||||
if (auto fmf = attr.dyn_cast<FMFAttr>())
|
||||
fmf.print(os);
|
||||
else
|
||||
llvm_unreachable("Unknown attribute type");
|
||||
}
|
||||
|
|
|
@ -666,6 +666,29 @@ ModuleTranslation::convertOmpOperation(Operation &opInst,
|
|||
});
|
||||
}
|
||||
|
||||
static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) {
|
||||
using llvmFMF = llvm::FastMathFlags;
|
||||
using FuncT = void (llvmFMF::*)(bool);
|
||||
const std::pair<FastmathFlags, FuncT> handlers[] = {
|
||||
// clang-format off
|
||||
{FastmathFlags::nnan, &llvmFMF::setNoNaNs},
|
||||
{FastmathFlags::ninf, &llvmFMF::setNoInfs},
|
||||
{FastmathFlags::nsz, &llvmFMF::setNoSignedZeros},
|
||||
{FastmathFlags::arcp, &llvmFMF::setAllowReciprocal},
|
||||
{FastmathFlags::contract, &llvmFMF::setAllowContract},
|
||||
{FastmathFlags::afn, &llvmFMF::setApproxFunc},
|
||||
{FastmathFlags::reassoc, &llvmFMF::setAllowReassoc},
|
||||
{FastmathFlags::fast, &llvmFMF::setFast},
|
||||
// clang-format on
|
||||
};
|
||||
llvm::FastMathFlags ret;
|
||||
auto fmf = op.fastmathFlags();
|
||||
for (auto it : handlers)
|
||||
if (bitEnumContains(fmf, it.first))
|
||||
(ret.*(it.second))(true);
|
||||
return ret;
|
||||
}
|
||||
|
||||
/// Given a single MLIR operation, create the corresponding LLVM IR operation
|
||||
/// using the `builder`. LLVM IR Builder does not have a generic interface so
|
||||
/// this has to be a long chain of `if`s calling different functions with a
|
||||
|
@ -680,6 +703,10 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
|
|||
return position;
|
||||
};
|
||||
|
||||
llvm::IRBuilder<>::FastMathFlagGuard fmfGuard(builder);
|
||||
if (auto fmf = dyn_cast<FastmathFlagsInterface>(opInst))
|
||||
builder.setFastMathFlags(getFastmathFlags(fmf));
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/LLVMConversions.inc"
|
||||
|
||||
// Emit function calls. If the "callee" attribute is present, this is a
|
||||
|
|
|
@ -387,3 +387,35 @@ llvm.func @useInlineAsm(%arg0: !llvm.i32) {
|
|||
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @fastmathFlags
|
||||
func @fastmathFlags(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm.i32) {
|
||||
// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
|
||||
// CHECK: {{.*}} = llvm.fsub %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
|
||||
// CHECK: {{.*}} = llvm.fmul %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
|
||||
// CHECK: {{.*}} = llvm.fdiv %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
|
||||
// CHECK: {{.*}} = llvm.frem %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
|
||||
%0 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
|
||||
%1 = llvm.fsub %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
|
||||
%2 = llvm.fmul %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
|
||||
%3 = llvm.fdiv %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
|
||||
%4 = llvm.frem %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
|
||||
|
||||
// CHECK: {{.*}} = llvm.fcmp "oeq" %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
|
||||
%5 = llvm.fcmp "oeq" %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
|
||||
|
||||
// CHECK: {{.*}} = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
|
||||
%6 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
|
||||
|
||||
// CHECK: {{.*}} = llvm.call @foo(%arg2) {fastmathFlags = #llvm.fastmath<fast>} : (!llvm.i32) -> !llvm.struct<(i32, double, i32)>
|
||||
%7 = llvm.call @foo(%arg2) {fastmathFlags = #llvm.fastmath<fast>} : (!llvm.i32) -> !llvm.struct<(i32, double, i32)>
|
||||
|
||||
// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 : !llvm.float
|
||||
%8 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<>} : !llvm.float
|
||||
// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : !llvm.float
|
||||
%9 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : !llvm.float
|
||||
|
||||
// CHECK: {{.*}} = llvm.fneg %arg0 : !llvm.float
|
||||
%10 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<>} : !llvm.float
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1360,6 +1360,50 @@ llvm.func @useInlineAsm(%arg0: !llvm.i32) {
|
|||
|
||||
// -----
|
||||
|
||||
llvm.func @fastmathFlagsFunc(!llvm.float) -> !llvm.float
|
||||
|
||||
// CHECK-LABEL: @fastmathFlags
|
||||
llvm.func @fastmathFlags(%arg0: !llvm.float) {
|
||||
// CHECK: {{.*}} = fadd nnan ninf float {{.*}}, {{.*}}
|
||||
// CHECK: {{.*}} = fsub nnan ninf float {{.*}}, {{.*}}
|
||||
// CHECK: {{.*}} = fmul nnan ninf float {{.*}}, {{.*}}
|
||||
// CHECK: {{.*}} = fdiv nnan ninf float {{.*}}, {{.*}}
|
||||
// CHECK: {{.*}} = frem nnan ninf float {{.*}}, {{.*}}
|
||||
%0 = llvm.fadd %arg0, %arg0 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : !llvm.float
|
||||
%1 = llvm.fsub %arg0, %arg0 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : !llvm.float
|
||||
%2 = llvm.fmul %arg0, %arg0 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : !llvm.float
|
||||
%3 = llvm.fdiv %arg0, %arg0 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : !llvm.float
|
||||
%4 = llvm.frem %arg0, %arg0 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : !llvm.float
|
||||
|
||||
// CHECK: {{.*}} = fcmp nnan ninf oeq {{.*}}, {{.*}}
|
||||
%5 = llvm.fcmp "oeq" %arg0, %arg0 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : !llvm.float
|
||||
|
||||
// CHECK: {{.*}} = fneg nnan ninf float {{.*}}
|
||||
%6 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : !llvm.float
|
||||
|
||||
// CHECK: {{.*}} = call float @fastmathFlagsFunc({{.*}})
|
||||
// CHECK: {{.*}} = call nnan float @fastmathFlagsFunc({{.*}})
|
||||
// CHECK: {{.*}} = call ninf float @fastmathFlagsFunc({{.*}})
|
||||
// CHECK: {{.*}} = call nsz float @fastmathFlagsFunc({{.*}})
|
||||
// CHECK: {{.*}} = call arcp float @fastmathFlagsFunc({{.*}})
|
||||
// CHECK: {{.*}} = call contract float @fastmathFlagsFunc({{.*}})
|
||||
// CHECK: {{.*}} = call afn float @fastmathFlagsFunc({{.*}})
|
||||
// CHECK: {{.*}} = call reassoc float @fastmathFlagsFunc({{.*}})
|
||||
// CHECK: {{.*}} = call fast float @fastmathFlagsFunc({{.*}})
|
||||
%8 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<>} : (!llvm.float) -> (!llvm.float)
|
||||
%9 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<nnan>} : (!llvm.float) -> (!llvm.float)
|
||||
%10 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<ninf>} : (!llvm.float) -> (!llvm.float)
|
||||
%11 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<nsz>} : (!llvm.float) -> (!llvm.float)
|
||||
%12 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<arcp>} : (!llvm.float) -> (!llvm.float)
|
||||
%13 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<contract>} : (!llvm.float) -> (!llvm.float)
|
||||
%14 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<afn>} : (!llvm.float) -> (!llvm.float)
|
||||
%15 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<reassoc>} : (!llvm.float) -> (!llvm.float)
|
||||
%16 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<fast>} : (!llvm.float) -> (!llvm.float)
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @switch_args
|
||||
llvm.func @switch_args(%arg0: !llvm.i32) {
|
||||
%0 = llvm.mlir.constant(5 : i32) : !llvm.i32
|
||||
|
|
Loading…
Reference in New Issue