Move LLVM::FMFAttr definition to TableGen (NFC)

This is using the new Attribute storage generation support in
TableGen to define the LLVM FastMathFlags.

Differential Revision: https://reviews.llvm.org/D98007
This commit is contained in:
Mehdi Amini 2021-03-05 06:38:59 +00:00
parent 1b0819e325
commit 038f2a337d
8 changed files with 58 additions and 56 deletions

View File

@ -1,5 +1,10 @@
add_subdirectory(Transforms)
set(LLVM_TARGET_DEFINITIONS LLVMAttrDefs.td)
mlir_tablegen(LLVMOpsAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(LLVMOpsAttrDefs.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(MLIRLLVMAttrsIncGen)
set(LLVM_TARGET_DEFINITIONS LLVMOps.td)
mlir_tablegen(LLVMOps.h.inc -gen-op-decls)
mlir_tablegen(LLVMOps.cpp.inc -gen-op-defs)

View File

@ -0,0 +1,29 @@
//===-- LLVMAttrDefs.td - LLVM Attributes definition file --*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef LLVMIR_ATTRDEFS
#define LLVMIR_ATTRDEFS
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
// All of the attributes will extend this class.
class LLVM_Attr<string name> : AttrDef<LLVM_Dialect, name>;
// The "FastMath" flags associated with floating point LLVM instructions.
def FastmathFlagsAttr : LLVM_Attr<"FMF"> {
let mnemonic = "fastmath";
// List of type parameters.
let parameters = (
ins
"FastmathFlags":$flags
);
}
#endif // LLVMIR_ATTRDEFS

View File

@ -30,6 +30,8 @@
#include "mlir/Dialect/LLVMIR/LLVMOpsEnums.h.inc"
#include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.h.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.h.inc"
namespace llvm {
class Type;
@ -47,24 +49,9 @@ class LLVMDialect;
namespace detail {
struct LLVMTypeStorage;
struct LLVMDialectImpl;
struct BitmaskEnumStorage;
struct LoopOptionAttrStorage;
} // namespace detail
/// An attribute that specifies LLVM instruction fastmath flags.
class FMFAttr : public Attribute::AttrBase<FMFAttr, Attribute,
detail::BitmaskEnumStorage> {
public:
using Base::Base;
static FMFAttr get(FastmathFlags flags, MLIRContext *context);
FastmathFlags getFlags() const;
void print(DialectAsmPrinter &p) const;
static Attribute parse(DialectAsmParser &parser);
};
/// An attribute that specifies LLVM loop codegen options.
class LoopOptionAttr
: public Attribute::AttrBase<LoopOptionAttr, Attribute,

View File

@ -44,7 +44,7 @@ def LLVM_FMFAttr : DialectAttr<
let returnType = "::mlir::LLVM::FastmathFlags";
let convertFromStorage = "$_self.getFlags()";
let constBuilderCall =
"::mlir::LLVM::FMFAttr::get($0, $_builder.getContext())";
"::mlir::LLVM::FMFAttr::get($_builder.getContext(), $0)";
}
def LOptDisableUnroll : I32EnumAttrCase<"disable_unroll", 1>;
@ -249,7 +249,7 @@ def LLVM_FCmpOp : LLVM_Op<"fcmp", [
[{
build($_builder, $_state, IntegerType::get(lhs.getType().getContext(), 1),
$_builder.getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs,
::mlir::LLVM::FMFAttr::get(fmf, $_builder.getContext()));
::mlir::LLVM::FMFAttr::get($_builder.getContext(), fmf));
}]>];
let parser = [{ return parseCmpOp<FCmpPredicate>(parser, result); }];
let printer = [{ printFCmpOp(p, *this); }];

View File

@ -30,7 +30,7 @@ struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
Value real = complexStruct.real(rewriter, op.getLoc());
Value imag = complexStruct.imaginary(rewriter, op.getLoc());
auto fmf = LLVM::FMFAttr::get({}, op.getContext());
auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
Value sqNorm = rewriter.create<LLVM::FAddOp>(
loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf),
rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
@ -133,7 +133,7 @@ struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
// Emit IR to add complex numbers.
auto fmf = LLVM::FMFAttr::get({}, op.getContext());
auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
Value real =
rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
Value imag =
@ -161,7 +161,7 @@ struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
// Emit IR to add complex numbers.
auto fmf = LLVM::FMFAttr::get({}, op.getContext());
auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
Value rhsRe = arg.rhs.real();
Value rhsIm = arg.rhs.imag();
Value lhsRe = arg.lhs.real();
@ -206,7 +206,7 @@ struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
// Emit IR to add complex numbers.
auto fmf = LLVM::FMFAttr::get({}, op.getContext());
auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
Value rhsRe = arg.rhs.real();
Value rhsIm = arg.rhs.imag();
Value lhsRe = arg.lhs.real();
@ -243,7 +243,7 @@ struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
// Emit IR to substract complex numbers.
auto fmf = LLVM::FMFAttr::get({}, op.getContext());
auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
Value real =
rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
Value imag =

View File

@ -829,7 +829,7 @@ public:
operation, dstType,
rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)),
operation.operand1(), operation.operand2(),
LLVM::FMFAttr::get({}, operation.getContext()));
LLVM::FMFAttr::get(operation.getContext(), {}));
return success();
}
};

View File

@ -2981,7 +2981,7 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<CmpFOp> {
ConversionPatternRewriter &rewriter) const override {
CmpFOpAdaptor transformed(operands);
auto fmf = LLVM::FMFAttr::get({}, cmpfOp.getContext());
auto fmf = LLVM::FMFAttr::get(cmpfOp.getContext(), {});
rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
cmpfOp, typeConverter->convertType(cmpfOp.getResult().getType()),
rewriter.getI64IntegerAttr(static_cast<int64_t>(

View File

@ -20,6 +20,7 @@
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Bitcode/BitcodeWriter.h"
@ -37,25 +38,12 @@ static constexpr const char kNonTemporalAttrName[] = "nontemporal";
#include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc"
#include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc"
namespace mlir {
namespace LLVM {
namespace detail {
struct BitmaskEnumStorage : public AttributeStorage {
using KeyTy = uint64_t;
BitmaskEnumStorage(KeyTy val) : value(val) {}
bool operator==(const KeyTy &key) const { return value == key; }
static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
const KeyTy &key) {
return new (allocator.allocate<BitmaskEnumStorage>())
BitmaskEnumStorage(key);
}
KeyTy value = 0;
};
struct LoopOptionAttrStorage : public AttributeStorage {
using KeyTy = std::pair<uint64_t, int32_t>;
@ -84,7 +72,7 @@ static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
SmallVector<NamedAttribute, 8> filteredAttrs(
llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
if (attr.first == "fastmathFlags") {
auto defAttr = FMFAttr::get({}, attr.second.getContext());
auto defAttr = FMFAttr::get(attr.second.getContext(), {});
return defAttr != attr.second;
}
return true;
@ -2387,14 +2375,6 @@ bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
op->hasTrait<OpTrait::IsIsolatedFromAbove>();
}
FMFAttr FMFAttr::get(FastmathFlags flags, MLIRContext *context) {
return Base::get(context, static_cast<uint64_t>(flags));
}
FastmathFlags FMFAttr::getFlags() const {
return static_cast<FastmathFlags>(getImpl()->value);
}
static constexpr const FastmathFlags FastmathFlagsList[] = {
// clang-format off
FastmathFlags::nnan,
@ -2418,7 +2398,8 @@ void FMFAttr::print(DialectAsmPrinter &printer) const {
printer << ">";
}
Attribute FMFAttr::parse(DialectAsmParser &parser) {
Attribute FMFAttr::parse(MLIRContext *context, DialectAsmParser &parser,
Type type) {
if (failed(parser.parseLess()))
return {};
@ -2443,7 +2424,7 @@ Attribute FMFAttr::parse(DialectAsmParser &parser) {
return {};
}
return FMFAttr::get(flags, parser.getBuilder().getContext());
return FMFAttr::get(parser.getBuilder().getContext(), flags);
}
LoopOptionAttr LoopOptionAttr::getDisableUnroll(MLIRContext *context,
@ -2558,9 +2539,9 @@ Attribute LLVMDialect::parseAttribute(DialectAsmParser &parser,
StringRef attrKind;
if (parser.parseKeyword(&attrKind))
return {};
if (attrKind == "fastmath")
return FMFAttr::parse(parser);
if (auto attr =
generatedAttributeParser(getContext(), parser, attrKind, type))
return attr;
if (attrKind == "loopopt")
return LoopOptionAttr::parse(parser);
@ -2570,9 +2551,9 @@ Attribute LLVMDialect::parseAttribute(DialectAsmParser &parser,
}
void LLVMDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
if (auto fmf = attr.dyn_cast<FMFAttr>())
fmf.print(os);
else if (auto lopt = attr.dyn_cast<LoopOptionAttr>())
if (succeeded(generatedAttributePrinter(attr, os)))
return;
if (auto lopt = attr.dyn_cast<LoopOptionAttr>())
lopt.print(os);
else
llvm_unreachable("Unknown attribute type");