forked from OSchip/llvm-project
[mlir] Generate CmpFPredicate as an EnumAttr in tablegen
Summary: This allows for attaching the attribute to CmpF as a proper argument, and thus enables the removal of a bunch of c++ code. Differential Revision: https://reviews.llvm.org/D75539
This commit is contained in:
parent
5d3a995938
commit
c10896682d
|
@ -40,37 +40,6 @@ public:
|
|||
Location loc) override;
|
||||
};
|
||||
|
||||
/// The predicate indicates the type of the comparison to perform:
|
||||
/// (un)orderedness, (in)equality and less/greater than (or equal to) as
|
||||
/// well as predicates that are always true or false.
|
||||
enum class CmpFPredicate {
|
||||
FirstValidValue,
|
||||
// Always false
|
||||
AlwaysFalse = FirstValidValue,
|
||||
// Ordered comparisons
|
||||
OEQ,
|
||||
OGT,
|
||||
OGE,
|
||||
OLT,
|
||||
OLE,
|
||||
ONE,
|
||||
// Both ordered
|
||||
ORD,
|
||||
// Unordered comparisons
|
||||
UEQ,
|
||||
UGT,
|
||||
UGE,
|
||||
ULT,
|
||||
ULE,
|
||||
UNE,
|
||||
// Any unordered
|
||||
UNO,
|
||||
// Always true
|
||||
AlwaysTrue,
|
||||
// Number of predicates.
|
||||
NumPredicates
|
||||
};
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h.inc"
|
||||
|
||||
|
|
|
@ -433,6 +433,34 @@ def CeilFOp : FloatUnaryOp<"ceilf"> {
|
|||
}];
|
||||
}
|
||||
|
||||
// The predicate indicates the type of the comparison to perform:
|
||||
// (un)orderedness, (in)equality and less/greater than (or equal to) as
|
||||
// well as predicates that are always true or false.
|
||||
def CMPF_P_FALSE : I64EnumAttrCase<"AlwaysFalse", 0, "false">;
|
||||
def CMPF_P_OEQ : I64EnumAttrCase<"OEQ", 1, "oeq">;
|
||||
def CMPF_P_OGT : I64EnumAttrCase<"OGT", 2, "ogt">;
|
||||
def CMPF_P_OGE : I64EnumAttrCase<"OGE", 3, "oge">;
|
||||
def CMPF_P_OLT : I64EnumAttrCase<"OLT", 4, "olt">;
|
||||
def CMPF_P_OLE : I64EnumAttrCase<"OLE", 5, "ole">;
|
||||
def CMPF_P_ONE : I64EnumAttrCase<"ONE", 6, "one">;
|
||||
def CMPF_P_ORD : I64EnumAttrCase<"ORD", 7, "ord">;
|
||||
def CMPF_P_UEQ : I64EnumAttrCase<"UEQ", 8, "ueq">;
|
||||
def CMPF_P_UGT : I64EnumAttrCase<"UGT", 9, "ugt">;
|
||||
def CMPF_P_UGE : I64EnumAttrCase<"UGE", 10, "uge">;
|
||||
def CMPF_P_ULT : I64EnumAttrCase<"ULT", 11, "ult">;
|
||||
def CMPF_P_ULE : I64EnumAttrCase<"ULE", 12, "ule">;
|
||||
def CMPF_P_UNE : I64EnumAttrCase<"UNE", 13, "une">;
|
||||
def CMPF_P_UNO : I64EnumAttrCase<"UNO", 14, "uno">;
|
||||
def CMPF_P_TRUE : I64EnumAttrCase<"AlwaysTrue", 15, "true">;
|
||||
|
||||
def CmpFPredicateAttr : I64EnumAttr<
|
||||
"CmpFPredicate", "",
|
||||
[CMPF_P_FALSE, CMPF_P_OEQ, CMPF_P_OGT, CMPF_P_OGE, CMPF_P_OLT, CMPF_P_OLE,
|
||||
CMPF_P_ONE, CMPF_P_ORD, CMPF_P_UEQ, CMPF_P_UGT, CMPF_P_UGE, CMPF_P_ULT,
|
||||
CMPF_P_ULE, CMPF_P_UNE, CMPF_P_UNO, CMPF_P_TRUE]> {
|
||||
let cppNamespace = "::mlir";
|
||||
}
|
||||
|
||||
def CmpFOp : Std_Op<"cmpf",
|
||||
[NoSideEffect, SameTypeOperands, SameOperandsAndResultShape,
|
||||
TypesMatchWith<
|
||||
|
@ -461,7 +489,11 @@ def CmpFOp : Std_Op<"cmpf",
|
|||
%r3 = "std.cmpf"(%0, %1) {predicate: 0} : (f8, f8) -> i1
|
||||
}];
|
||||
|
||||
let arguments = (ins FloatLike:$lhs, FloatLike:$rhs);
|
||||
let arguments = (ins
|
||||
CmpFPredicateAttr:$predicate,
|
||||
FloatLike:$lhs,
|
||||
FloatLike:$rhs
|
||||
);
|
||||
let results = (outs BoolLike:$result);
|
||||
|
||||
let builders = [OpBuilder<
|
||||
|
@ -480,7 +512,11 @@ def CmpFOp : Std_Op<"cmpf",
|
|||
}
|
||||
}];
|
||||
|
||||
let verifier = [{ return success(); }];
|
||||
|
||||
let hasFolder = 1;
|
||||
|
||||
let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)";
|
||||
}
|
||||
|
||||
def CMPI_P_EQ : I64EnumAttrCase<"eq", 0>;
|
||||
|
|
|
@ -580,55 +580,6 @@ OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
|
|||
// CmpFOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Returns an array of mnemonics for CmpFPredicates indexed by values thereof.
|
||||
static inline const char *const *getCmpFPredicateNames() {
|
||||
static const char *predicateNames[] = {
|
||||
/*AlwaysFalse*/ "false",
|
||||
/*OEQ*/ "oeq",
|
||||
/*OGT*/ "ogt",
|
||||
/*OGE*/ "oge",
|
||||
/*OLT*/ "olt",
|
||||
/*OLE*/ "ole",
|
||||
/*ONE*/ "one",
|
||||
/*ORD*/ "ord",
|
||||
/*UEQ*/ "ueq",
|
||||
/*UGT*/ "ugt",
|
||||
/*UGE*/ "uge",
|
||||
/*ULT*/ "ult",
|
||||
/*ULE*/ "ule",
|
||||
/*UNE*/ "une",
|
||||
/*UNO*/ "uno",
|
||||
/*AlwaysTrue*/ "true",
|
||||
};
|
||||
static_assert(std::extent<decltype(predicateNames)>::value ==
|
||||
(size_t)CmpFPredicate::NumPredicates,
|
||||
"wrong number of predicate names");
|
||||
return predicateNames;
|
||||
}
|
||||
|
||||
// Returns a value of the predicate corresponding to the given mnemonic.
|
||||
// Returns NumPredicates (one-past-end) if there is no such mnemonic.
|
||||
CmpFPredicate CmpFOp::getPredicateByName(StringRef name) {
|
||||
return llvm::StringSwitch<CmpFPredicate>(name)
|
||||
.Case("false", CmpFPredicate::AlwaysFalse)
|
||||
.Case("oeq", CmpFPredicate::OEQ)
|
||||
.Case("ogt", CmpFPredicate::OGT)
|
||||
.Case("oge", CmpFPredicate::OGE)
|
||||
.Case("olt", CmpFPredicate::OLT)
|
||||
.Case("ole", CmpFPredicate::OLE)
|
||||
.Case("one", CmpFPredicate::ONE)
|
||||
.Case("ord", CmpFPredicate::ORD)
|
||||
.Case("ueq", CmpFPredicate::UEQ)
|
||||
.Case("ugt", CmpFPredicate::UGT)
|
||||
.Case("uge", CmpFPredicate::UGE)
|
||||
.Case("ult", CmpFPredicate::ULT)
|
||||
.Case("ule", CmpFPredicate::ULE)
|
||||
.Case("une", CmpFPredicate::UNE)
|
||||
.Case("uno", CmpFPredicate::UNO)
|
||||
.Case("true", CmpFPredicate::AlwaysTrue)
|
||||
.Default(CmpFPredicate::NumPredicates);
|
||||
}
|
||||
|
||||
static void buildCmpFOp(Builder *build, OperationState &result,
|
||||
CmpFPredicate predicate, Value lhs, Value rhs) {
|
||||
result.addOperands({lhs, rhs});
|
||||
|
@ -638,73 +589,8 @@ static void buildCmpFOp(Builder *build, OperationState &result,
|
|||
build->getI64IntegerAttr(static_cast<int64_t>(predicate)));
|
||||
}
|
||||
|
||||
static ParseResult parseCmpFOp(OpAsmParser &parser, OperationState &result) {
|
||||
SmallVector<OpAsmParser::OperandType, 2> ops;
|
||||
SmallVector<NamedAttribute, 4> attrs;
|
||||
Attribute predicateNameAttr;
|
||||
Type type;
|
||||
if (parser.parseAttribute(predicateNameAttr, CmpFOp::getPredicateAttrName(),
|
||||
attrs) ||
|
||||
parser.parseComma() || parser.parseOperandList(ops, 2) ||
|
||||
parser.parseOptionalAttrDict(attrs) || parser.parseColonType(type) ||
|
||||
parser.resolveOperands(ops, type, result.operands))
|
||||
return failure();
|
||||
|
||||
if (!predicateNameAttr.isa<StringAttr>())
|
||||
return parser.emitError(parser.getNameLoc(),
|
||||
"expected string comparison predicate attribute");
|
||||
|
||||
// Rewrite string attribute to an enum value.
|
||||
StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue();
|
||||
auto predicate = CmpFOp::getPredicateByName(predicateName);
|
||||
if (predicate == CmpFPredicate::NumPredicates)
|
||||
return parser.emitError(parser.getNameLoc(),
|
||||
"unknown comparison predicate \"" + predicateName +
|
||||
"\"");
|
||||
|
||||
auto builder = parser.getBuilder();
|
||||
Type i1Type = getCheckedI1SameShape(type);
|
||||
if (!i1Type)
|
||||
return parser.emitError(parser.getNameLoc(),
|
||||
"expected type with valid i1 shape");
|
||||
|
||||
attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(predicate));
|
||||
result.attributes = attrs;
|
||||
|
||||
result.addTypes({i1Type});
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, CmpFOp op) {
|
||||
p << "cmpf ";
|
||||
|
||||
auto predicateValue =
|
||||
op.getAttrOfType<IntegerAttr>(CmpFOp::getPredicateAttrName()).getInt();
|
||||
assert(predicateValue >= static_cast<int>(CmpFPredicate::FirstValidValue) &&
|
||||
predicateValue < static_cast<int>(CmpFPredicate::NumPredicates) &&
|
||||
"unknown predicate index");
|
||||
p << '"' << getCmpFPredicateNames()[predicateValue] << '"' << ", " << op.lhs()
|
||||
<< ", " << op.rhs();
|
||||
p.printOptionalAttrDict(op.getAttrs(),
|
||||
/*elidedAttrs=*/{CmpFOp::getPredicateAttrName()});
|
||||
p << " : " << op.lhs().getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(CmpFOp op) {
|
||||
auto predicateAttr =
|
||||
op.getAttrOfType<IntegerAttr>(CmpFOp::getPredicateAttrName());
|
||||
if (!predicateAttr)
|
||||
return op.emitOpError("requires an integer attribute named 'predicate'");
|
||||
auto predicate = predicateAttr.getInt();
|
||||
if (predicate < (int64_t)CmpFPredicate::FirstValidValue ||
|
||||
predicate >= (int64_t)CmpFPredicate::NumPredicates)
|
||||
return op.emitOpError("'predicate' attribute value out of range");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
|
||||
// comparison predicates.
|
||||
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
|
||||
/// comparison predicates.
|
||||
static bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
|
||||
const APFloat &rhs) {
|
||||
auto cmpResult = lhs.compare(rhs);
|
||||
|
|
|
@ -346,28 +346,28 @@ func @invalid_cmp_attr(%idx : i32) {
|
|||
// -----
|
||||
|
||||
func @cmpf_generic_invalid_predicate_value(%a : f32) {
|
||||
// expected-error@+1 {{'predicate' attribute value out of range}}
|
||||
// expected-error@+1 {{attribute 'predicate' failed to satisfy constraint: allowed 64-bit integer cases}}
|
||||
%r = "std.cmpf"(%a, %a) {predicate = 42} : (f32, f32) -> i1
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @cmpf_canonical_invalid_predicate_value(%a : f32) {
|
||||
// expected-error@+1 {{unknown comparison predicate "foo"}}
|
||||
// expected-error@+1 {{invalid predicate attribute specification: "foo"}}
|
||||
%r = cmpf "foo", %a, %a : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @cmpf_canonical_invalid_predicate_value_signed(%a : f32) {
|
||||
// expected-error@+1 {{unknown comparison predicate "sge"}}
|
||||
// expected-error@+1 {{invalid predicate attribute specification: "sge"}}
|
||||
%r = cmpf "sge", %a, %a : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @cmpf_canonical_invalid_predicate_value_no_order(%a : f32) {
|
||||
// expected-error@+1 {{unknown comparison predicate "eq"}}
|
||||
// expected-error@+1 {{invalid predicate attribute specification: "eq"}}
|
||||
%r = cmpf "eq", %a, %a : f32
|
||||
}
|
||||
|
||||
|
@ -380,14 +380,14 @@ func @cmpf_canonical_no_predicate_attr(%a : f32, %b : f32) {
|
|||
// -----
|
||||
|
||||
func @cmpf_generic_no_predicate_attr(%a : f32, %b : f32) {
|
||||
// expected-error@+1 {{requires an integer attribute named 'predicate'}}
|
||||
// expected-error@+1 {{requires attribute 'predicate'}}
|
||||
%r = "std.cmpf"(%a, %b) {foo = 1} : (f32, f32) -> i1
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @cmpf_wrong_type(%a : i32, %b : i32) {
|
||||
%r = cmpf "oeq", %a, %b : i32 // expected-error {{operand #0 must be floating-point-like}}
|
||||
%r = cmpf "oeq", %a, %b : i32 // expected-error {{must be floating-point-like}}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
|
Loading…
Reference in New Issue