diff --git a/mlir/bindings/python/pybind.cpp b/mlir/bindings/python/pybind.cpp index 61f42af2251d..a458837f77a3 100644 --- a/mlir/bindings/python/pybind.cpp +++ b/mlir/bindings/python/pybind.cpp @@ -870,37 +870,37 @@ PYBIND11_MODULE(pybind, m) { .def("__lt__", [](PythonValueHandle lhs, PythonValueHandle rhs) -> PythonValueHandle { - return ValueHandle::create(CmpIPredicate::SLT, lhs.value, + return ValueHandle::create(CmpIPredicate::slt, lhs.value, rhs.value); }) .def("__le__", [](PythonValueHandle lhs, PythonValueHandle rhs) -> PythonValueHandle { - return ValueHandle::create(CmpIPredicate::SLE, lhs.value, + return ValueHandle::create(CmpIPredicate::sle, lhs.value, rhs.value); }) .def("__gt__", [](PythonValueHandle lhs, PythonValueHandle rhs) -> PythonValueHandle { - return ValueHandle::create(CmpIPredicate::SGT, lhs.value, + return ValueHandle::create(CmpIPredicate::sgt, lhs.value, rhs.value); }) .def("__ge__", [](PythonValueHandle lhs, PythonValueHandle rhs) -> PythonValueHandle { - return ValueHandle::create(CmpIPredicate::SGE, lhs.value, + return ValueHandle::create(CmpIPredicate::sge, lhs.value, rhs.value); }) .def("__eq__", [](PythonValueHandle lhs, PythonValueHandle rhs) -> PythonValueHandle { - return ValueHandle::create(CmpIPredicate::EQ, lhs.value, + return ValueHandle::create(CmpIPredicate::eq, lhs.value, rhs.value); }) .def("__ne__", [](PythonValueHandle lhs, PythonValueHandle rhs) -> PythonValueHandle { - return ValueHandle::create(CmpIPredicate::NE, lhs.value, + return ValueHandle::create(CmpIPredicate::ne, lhs.value, rhs.value); }) .def("__invert__", diff --git a/mlir/include/mlir/Dialect/StandardOps/CMakeLists.txt b/mlir/include/mlir/Dialect/StandardOps/CMakeLists.txt index 670676f24db5..b6534797a065 100644 --- a/mlir/include/mlir/Dialect/StandardOps/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/StandardOps/CMakeLists.txt @@ -1,4 +1,6 @@ set(LLVM_TARGET_DEFINITIONS Ops.td) mlir_tablegen(Ops.h.inc -gen-op-decls) mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) add_public_tablegen_target(MLIRStandardOpsIncGen) diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.h b/mlir/include/mlir/Dialect/StandardOps/Ops.h index fd69534a14d3..77981629710b 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.h @@ -30,6 +30,9 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/StandardTypes.h" +// Pull in all enum type definitions and utility function declarations. +#include "mlir/Dialect/StandardOps/OpsEnums.h.inc" + namespace mlir { class AffineMap; class Builder; @@ -42,27 +45,6 @@ public: static StringRef getDialectNamespace() { return "std"; } }; -/// The predicate indicates the type of the comparison to perform: -/// (in)equality; (un)signed less/greater than (or equal to). -enum class CmpIPredicate { - FirstValidValue, - // (In)equality comparisons. - EQ = FirstValidValue, - NE, - // Signed comparisons. - SLT, - SLE, - SGT, - SGE, - // Unsigned comparisons. - ULT, - ULE, - UGT, - UGE, - // Number of predicates. - NumPredicates -}; - /// The predicate indicates the type of the comparison to perform: /// (un)orderedness, (in)equality and less/greater than (or equal to) as /// well as predicates that are always true or false. diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td index bfd3916d8708..eb7ebbb8f6e3 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td @@ -345,6 +345,24 @@ def CallIndirectOp : Std_Op<"call_indirect", [CallOpInterface]> { let hasCanonicalizer = 1; } +def CMPI_P_EQ : I64EnumAttrCase<"eq", 0>; +def CMPI_P_NE : I64EnumAttrCase<"ne", 1>; +def CMPI_P_SLT : I64EnumAttrCase<"slt", 2>; +def CMPI_P_SLE : I64EnumAttrCase<"sle", 3>; +def CMPI_P_SGT : I64EnumAttrCase<"sgt", 4>; +def CMPI_P_SGE : I64EnumAttrCase<"sge", 5>; +def CMPI_P_ULT : I64EnumAttrCase<"ult", 6>; +def CMPI_P_ULE : I64EnumAttrCase<"ule", 7>; +def CMPI_P_UGT : I64EnumAttrCase<"ugt", 8>; +def CMPI_P_UGE : I64EnumAttrCase<"uge", 9>; + +def CmpIPredicateAttr : I64EnumAttr< + "CmpIPredicate", "", + [CMPI_P_EQ, CMPI_P_NE, CMPI_P_SLT, CMPI_P_SLE, CMPI_P_SGT, + CMPI_P_SGE, CMPI_P_ULT, CMPI_P_ULE, CMPI_P_UGT, CMPI_P_UGE]> { + let cppNamespace = "::mlir"; +} + def CmpIOp : Std_Op<"cmpi", [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]> { let summary = "integer comparison operation"; @@ -369,7 +387,11 @@ def CmpIOp : Std_Op<"cmpi", %r3 = "std.cmpi"(%0, %1){predicate: 0} : (i8, i8) -> i1 }]; - let arguments = (ins IntegerLike:$lhs, IntegerLike:$rhs); + let arguments = (ins + CmpIPredicateAttr:$predicate, + IntegerLike:$lhs, + IntegerLike:$rhs + ); let results = (outs BoolLike); let builders = [OpBuilder< @@ -388,6 +410,8 @@ def CmpIOp : Std_Op<"cmpi", } }]; + let verifier = [{ return success(); }]; + let hasFolder = 1; } diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp index 98d80ede2384..4935d2da3fb3 100644 --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -97,7 +97,7 @@ public: Value *remainder = builder.create(loc, lhs, rhs); Value *zeroCst = builder.create(loc, 0); Value *isRemainderNegative = - builder.create(loc, CmpIPredicate::SLT, remainder, zeroCst); + builder.create(loc, CmpIPredicate::slt, remainder, zeroCst); Value *correctedRemainder = builder.create(loc, remainder, rhs); Value *result = builder.create(loc, isRemainderNegative, correctedRemainder, remainder); @@ -134,7 +134,7 @@ public: Value *zeroCst = builder.create(loc, 0); Value *noneCst = builder.create(loc, -1); Value *negative = - builder.create(loc, CmpIPredicate::SLT, lhs, zeroCst); + builder.create(loc, CmpIPredicate::slt, lhs, zeroCst); Value *negatedDecremented = builder.create(loc, noneCst, lhs); Value *dividend = builder.create(loc, negative, negatedDecremented, lhs); @@ -173,7 +173,7 @@ public: Value *zeroCst = builder.create(loc, 0); Value *oneCst = builder.create(loc, 1); Value *nonPositive = - builder.create(loc, CmpIPredicate::SLE, lhs, zeroCst); + builder.create(loc, CmpIPredicate::sle, lhs, zeroCst); Value *negated = builder.create(loc, zeroCst, lhs); Value *decremented = builder.create(loc, lhs, oneCst); Value *dividend = @@ -277,7 +277,7 @@ Value *mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) { boundOperands); if (!lbValues) return nullptr; - return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::SGT, *lbValues, + return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::sgt, *lbValues, builder); } @@ -290,7 +290,7 @@ Value *mlir::lowerAffineUpperBound(AffineForOp op, OpBuilder &builder) { boundOperands); if (!ubValues) return nullptr; - return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::SLT, *ubValues, + return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::slt, *ubValues, builder); } @@ -352,7 +352,7 @@ public: operandsRef.drop_front(numDims)); if (!affResult) return matchFailure(); - auto pred = isEquality ? CmpIPredicate::EQ : CmpIPredicate::SGE; + auto pred = isEquality ? CmpIPredicate::eq : CmpIPredicate::sge; Value *cmpVal = rewriter.create(loc, pred, affResult, zeroConstant); cond = diff --git a/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp index e8ab53b0f6fa..08ee320f7d98 100644 --- a/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp +++ b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp @@ -205,7 +205,7 @@ ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const { // With the body block done, we can fill in the condition block. rewriter.setInsertionPointToEnd(conditionBlock); auto comparison = - rewriter.create(loc, CmpIPredicate::SLT, iv, upperBound); + rewriter.create(loc, CmpIPredicate::slt, iv, upperBound); rewriter.create(loc, comparison, firstBodyBlock, ArrayRef(), endBlock, diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp index 2fd6e757ee80..74d1352d19a8 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -90,12 +90,12 @@ public: cmpIOpOperands.rhs()); \ return matchSuccess(); - DISPATCH(CmpIPredicate::EQ, spirv::IEqualOp); - DISPATCH(CmpIPredicate::NE, spirv::INotEqualOp); - DISPATCH(CmpIPredicate::SLT, spirv::SLessThanOp); - DISPATCH(CmpIPredicate::SLE, spirv::SLessThanEqualOp); - DISPATCH(CmpIPredicate::SGT, spirv::SGreaterThanOp); - DISPATCH(CmpIPredicate::SGE, spirv::SGreaterThanEqualOp); + DISPATCH(CmpIPredicate::eq, spirv::IEqualOp); + DISPATCH(CmpIPredicate::ne, spirv::INotEqualOp); + DISPATCH(CmpIPredicate::slt, spirv::SLessThanOp); + DISPATCH(CmpIPredicate::sle, spirv::SLessThanEqualOp); + DISPATCH(CmpIPredicate::sgt, spirv::SGreaterThanOp); + DISPATCH(CmpIPredicate::sge, spirv::SGreaterThanEqualOp); #undef DISPATCH diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index bf0cb75b8bcc..c4abee3858e8 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -35,6 +35,9 @@ #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" +// Pull in all enum type definitions and utility function declarations. +#include "mlir/Dialect/StandardOps/OpsEnums.cpp.inc" + using namespace mlir; //===----------------------------------------------------------------------===// @@ -699,43 +702,6 @@ static Type getI1SameShape(Builder *build, Type type) { // CmpIOp //===----------------------------------------------------------------------===// -// Returns an array of mnemonics for CmpIPredicates indexed by values thereof. -static inline const char *const *getCmpIPredicateNames() { - static const char *predicateNames[]{ - /*EQ*/ "eq", - /*NE*/ "ne", - /*SLT*/ "slt", - /*SLE*/ "sle", - /*SGT*/ "sgt", - /*SGE*/ "sge", - /*ULT*/ "ult", - /*ULE*/ "ule", - /*UGT*/ "ugt", - /*UGE*/ "uge", - }; - static_assert(std::extent::value == - (size_t)CmpIPredicate::NumPredicates, - "wrong number of predicate names"); - return predicateNames; -} - -// Returns a value of the predicate corresponding to the given mnemonic. -// Returns NumPredicates (one-past-end) if there is no such mnemonic. -CmpIPredicate CmpIOp::getPredicateByName(StringRef name) { - return llvm::StringSwitch(name) - .Case("eq", CmpIPredicate::EQ) - .Case("ne", CmpIPredicate::NE) - .Case("slt", CmpIPredicate::SLT) - .Case("sle", CmpIPredicate::SLE) - .Case("sgt", CmpIPredicate::SGT) - .Case("sge", CmpIPredicate::SGE) - .Case("ult", CmpIPredicate::ULT) - .Case("ule", CmpIPredicate::ULE) - .Case("ugt", CmpIPredicate::UGT) - .Case("uge", CmpIPredicate::UGE) - .Default(CmpIPredicate::NumPredicates); -} - static void buildCmpIOp(Builder *build, OperationState &result, CmpIPredicate predicate, Value *lhs, Value *rhs) { result.addOperands({lhs, rhs}); @@ -763,8 +729,8 @@ static ParseResult parseCmpIOp(OpAsmParser &parser, OperationState &result) { // Rewrite string attribute to an enum value. StringRef predicateName = predicateNameAttr.cast().getValue(); - auto predicate = CmpIOp::getPredicateByName(predicateName); - if (predicate == CmpIPredicate::NumPredicates) + Optional predicate = symbolizeCmpIPredicate(predicateName); + if (!predicate.hasValue()) return parser.emitError(parser.getNameLoc()) << "unknown comparison predicate \"" << predicateName << "\""; @@ -774,7 +740,7 @@ static ParseResult parseCmpIOp(OpAsmParser &parser, OperationState &result) { return parser.emitError(parser.getNameLoc(), "expected type with valid i1 shape"); - attrs[0].second = builder.getI64IntegerAttr(static_cast(predicate)); + attrs[0].second = builder.getI64IntegerAttr(static_cast(*predicate)); result.attributes = attrs; result.addTypes({i1Type}); @@ -784,15 +750,11 @@ static ParseResult parseCmpIOp(OpAsmParser &parser, OperationState &result) { static void print(OpAsmPrinter &p, CmpIOp op) { p << "cmpi "; + Builder b(op.getContext()); auto predicateValue = op.getAttrOfType(CmpIOp::getPredicateAttrName()).getInt(); - assert(predicateValue >= static_cast(CmpIPredicate::FirstValidValue) && - predicateValue < static_cast(CmpIPredicate::NumPredicates) && - "unknown predicate index"); - Builder b(op.getContext()); - auto predicateStringAttr = - b.getStringAttr(getCmpIPredicateNames()[predicateValue]); - p.printAttribute(predicateStringAttr); + p << '"' << stringifyCmpIPredicate(static_cast(predicateValue)) + << '"'; p << ", "; p.printOperand(op.lhs()); @@ -803,43 +765,30 @@ static void print(OpAsmPrinter &p, CmpIOp op) { p << " : " << op.lhs()->getType(); } -static LogicalResult verify(CmpIOp op) { - auto predicateAttr = - op.getAttrOfType(CmpIOp::getPredicateAttrName()); - if (!predicateAttr) - return op.emitOpError("requires an integer attribute named 'predicate'"); - auto predicate = predicateAttr.getInt(); - if (predicate < (int64_t)CmpIPredicate::FirstValidValue || - predicate >= (int64_t)CmpIPredicate::NumPredicates) - return op.emitOpError("'predicate' attribute value out of range"); - - return success(); -} - // Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer // comparison predicates. static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs, const APInt &rhs) { switch (predicate) { - case CmpIPredicate::EQ: + case CmpIPredicate::eq: return lhs.eq(rhs); - case CmpIPredicate::NE: + case CmpIPredicate::ne: return lhs.ne(rhs); - case CmpIPredicate::SLT: + case CmpIPredicate::slt: return lhs.slt(rhs); - case CmpIPredicate::SLE: + case CmpIPredicate::sle: return lhs.sle(rhs); - case CmpIPredicate::SGT: + case CmpIPredicate::sgt: return lhs.sgt(rhs); - case CmpIPredicate::SGE: + case CmpIPredicate::sge: return lhs.sge(rhs); - case CmpIPredicate::ULT: + case CmpIPredicate::ult: return lhs.ult(rhs); - case CmpIPredicate::ULE: + case CmpIPredicate::ule: return lhs.ule(rhs); - case CmpIPredicate::UGT: + case CmpIPredicate::ugt: return lhs.ugt(rhs); - case CmpIPredicate::UGE: + case CmpIPredicate::uge: return lhs.uge(rhs); default: llvm_unreachable("unknown comparison predicate"); diff --git a/mlir/lib/EDSC/Builders.cpp b/mlir/lib/EDSC/Builders.cpp index 4046a7c1fc67..9d7ca8ca99b7 100644 --- a/mlir/lib/EDSC/Builders.cpp +++ b/mlir/lib/EDSC/Builders.cpp @@ -460,13 +460,13 @@ ValueHandle mlir::edsc::op::operator==(ValueHandle lhs, ValueHandle rhs) { auto type = lhs.getType(); return type.isa() ? createFComparisonExpr(CmpFPredicate::OEQ, lhs, rhs) - : createIComparisonExpr(CmpIPredicate::EQ, lhs, rhs); + : createIComparisonExpr(CmpIPredicate::eq, lhs, rhs); } ValueHandle mlir::edsc::op::operator!=(ValueHandle lhs, ValueHandle rhs) { auto type = lhs.getType(); return type.isa() ? createFComparisonExpr(CmpFPredicate::ONE, lhs, rhs) - : createIComparisonExpr(CmpIPredicate::NE, lhs, rhs); + : createIComparisonExpr(CmpIPredicate::ne, lhs, rhs); } ValueHandle mlir::edsc::op::operator<(ValueHandle lhs, ValueHandle rhs) { auto type = lhs.getType(); @@ -474,23 +474,23 @@ ValueHandle mlir::edsc::op::operator<(ValueHandle lhs, ValueHandle rhs) { ? createFComparisonExpr(CmpFPredicate::OLT, lhs, rhs) : // TODO(ntv,zinenko): signed by default, how about unsigned? - createIComparisonExpr(CmpIPredicate::SLT, lhs, rhs); + createIComparisonExpr(CmpIPredicate::slt, lhs, rhs); } ValueHandle mlir::edsc::op::operator<=(ValueHandle lhs, ValueHandle rhs) { auto type = lhs.getType(); return type.isa() ? createFComparisonExpr(CmpFPredicate::OLE, lhs, rhs) - : createIComparisonExpr(CmpIPredicate::SLE, lhs, rhs); + : createIComparisonExpr(CmpIPredicate::sle, lhs, rhs); } ValueHandle mlir::edsc::op::operator>(ValueHandle lhs, ValueHandle rhs) { auto type = lhs.getType(); return type.isa() ? createFComparisonExpr(CmpFPredicate::OGT, lhs, rhs) - : createIComparisonExpr(CmpIPredicate::SGT, lhs, rhs); + : createIComparisonExpr(CmpIPredicate::sgt, lhs, rhs); } ValueHandle mlir::edsc::op::operator>=(ValueHandle lhs, ValueHandle rhs) { auto type = lhs.getType(); return type.isa() ? createFComparisonExpr(CmpFPredicate::OGE, lhs, rhs) - : createIComparisonExpr(CmpIPredicate::SGE, lhs, rhs); + : createIComparisonExpr(CmpIPredicate::sge, lhs, rhs); } diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 405116e72e7c..0ee1220b7201 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -747,7 +747,7 @@ static Loops stripmineSink(loop::ForOp forOp, Value *factor, // Insert newForOp before the terminator of `t`. OpBuilder b(t.getBodyBuilder()); Value *stepped = b.create(t.getLoc(), iv, forOp.step()); - Value *less = b.create(t.getLoc(), CmpIPredicate::SLT, + Value *less = b.create(t.getLoc(), CmpIPredicate::slt, forOp.upperBound(), stepped); Value *ub = b.create(t.getLoc(), less, forOp.upperBound(), stepped); diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir index 74dd41294229..c8fffc9c88b9 100644 --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -201,7 +201,7 @@ func @func_with_ops(i32) { func @func_with_ops(i32) { ^bb0(%a : i32): - // expected-error@+1 {{'predicate' attribute value out of range}} + // expected-error@+1 {{failed to satisfy constraint: allowed 64-bit integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}} %r = "std.cmpi"(%a, %a) {predicate = 42} : (i32, i32) -> i1 } @@ -241,7 +241,7 @@ func @func_with_ops(i32, i32) { func @func_with_ops(i32, i32) { ^bb0(%a : i32, %b : i32): - // expected-error@+1 {{requires an integer attribute named 'predicate'}} + // expected-error@+1 {{requires attribute 'predicate'}} %r = "std.cmpi"(%a, %b) {foo = 1} : (i32, i32) -> i1 }