diff --git a/mlir/include/mlir/LLVMIR/CMakeLists.txt b/mlir/include/mlir/LLVMIR/CMakeLists.txt index 31771a35f164..1d7d06bc25c1 100644 --- a/mlir/include/mlir/LLVMIR/CMakeLists.txt +++ b/mlir/include/mlir/LLVMIR/CMakeLists.txt @@ -1,6 +1,8 @@ set(LLVM_TARGET_DEFINITIONS LLVMOps.td) mlir_tablegen(LLVMOps.h.inc -gen-op-decls) mlir_tablegen(LLVMOps.cpp.inc -gen-op-defs) +mlir_tablegen(LLVMOpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(LLVMOpsEnums.cpp.inc -gen-enum-defs) add_public_tablegen_target(MLIRLLVMOpsIncGen) set(LLVM_TARGET_DEFINITIONS NVVMOps.td) mlir_tablegen(NVVMOps.h.inc -gen-op-decls) diff --git a/mlir/include/mlir/LLVMIR/LLVMDialect.h b/mlir/include/mlir/LLVMIR/LLVMDialect.h index c0b8347e36c4..2f98828b1027 100644 --- a/mlir/include/mlir/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/LLVMIR/LLVMDialect.h @@ -34,6 +34,8 @@ #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" +#include "mlir/LLVMIR/LLVMOpsEnums.h.inc" + namespace llvm { class Type; class LLVMContext; diff --git a/mlir/include/mlir/LLVMIR/LLVMOps.td b/mlir/include/mlir/LLVMIR/LLVMOps.td index 5581193f62b1..a67f4627c2dc 100644 --- a/mlir/include/mlir/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/LLVMIR/LLVMOps.td @@ -133,9 +133,28 @@ def LLVM_AndOp : LLVM_ArithmeticOp<"and", "CreateAnd">; def LLVM_OrOp : LLVM_ArithmeticOp<"or", "CreateOr">; def LLVM_XOrOp : LLVM_ArithmeticOp<"xor", "CreateXor">; +def ICmpPredicateEQ : I64EnumAttrCase<"eq", 0>; +def ICmpPredicateNE : I64EnumAttrCase<"ne", 1>; +def ICmpPredicateSLT : I64EnumAttrCase<"slt", 2>; +def ICmpPredicateSLE : I64EnumAttrCase<"sle", 3>; +def ICmpPredicateSGT : I64EnumAttrCase<"sgt", 4>; +def ICmpPredicateSGE : I64EnumAttrCase<"sge", 5>; +def ICmpPredicateULT : I64EnumAttrCase<"ult", 6>; +def ICmpPredicateULE : I64EnumAttrCase<"ule", 7>; +def ICmpPredicateUGT : I64EnumAttrCase<"ugt", 8>; +def ICmpPredicateUGE : I64EnumAttrCase<"uge", 9>; +def ICmpPredicate : I64EnumAttr< + "ICmpPredicate", + "llvm.icmp comparison predicate", + [ICmpPredicateEQ, ICmpPredicateNE, ICmpPredicateSLT, ICmpPredicateSLE, + ICmpPredicateSGT, ICmpPredicateSGE, ICmpPredicateULT, ICmpPredicateULE, + ICmpPredicateUGT, ICmpPredicateUGE]> { + let cppNamespace = "mlir::LLVM"; +} + // Other integer operations. def LLVM_ICmpOp : LLVM_OneResultOp<"icmp", [NoSideEffect]>, - Arguments<(ins I64Attr:$predicate, LLVM_Type:$lhs, + Arguments<(ins ICmpPredicate:$predicate, LLVM_Type:$lhs, LLVM_Type:$rhs)> { let llvmBuilder = [{ $res = builder.CreateICmp(getLLVMCmpPredicate( diff --git a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp index 39cd03f51de1..2444eb17df25 100644 --- a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp @@ -35,46 +35,17 @@ using namespace mlir; using namespace mlir::LLVM; +#include "mlir/LLVMIR/LLVMOpsEnums.cpp.inc" + //===----------------------------------------------------------------------===// // Printing/parsing for LLVM::ICmpOp. //===----------------------------------------------------------------------===// -// Return an array of mnemonics for ICmpPredicates indexed by its value. -static const char *const *getICmpPredicateNames() { - static const char *predicateNames[]{/*EQ*/ "eq", - /*NE*/ "ne", - /*SLT*/ "slt", - /*SLE*/ "sle", - /*SGT*/ "sgt", - /*SGE*/ "sge", - /*ULT*/ "ult", - /*ULE*/ "ule", - /*UGT*/ "ugt", - /*UGE*/ "uge"}; - return predicateNames; -} - -// Returns a value of the ICmp predicate corresponding to the given mnemonic. -// Returns -1 if there is no such mnemonic. -static int getICmpPredicateByName(StringRef name) { - return llvm::StringSwitch(name) - .Case("eq", 0) - .Case("ne", 1) - .Case("slt", 2) - .Case("sle", 3) - .Case("sgt", 4) - .Case("sge", 5) - .Case("ult", 6) - .Case("ule", 7) - .Case("ugt", 8) - .Case("uge", 9) - .Default(-1); -} - static void printICmpOp(OpAsmPrinter *p, ICmpOp &op) { *p << op.getOperationName() << " \"" - << getICmpPredicateNames()[op.predicate().getZExtValue()] << "\" " - << *op.getOperand(0) << ", " << *op.getOperand(1); + << stringifyICmpPredicate( + static_cast(op.predicate().getZExtValue())) + << "\" " << *op.getOperand(0) << ", " << *op.getOperand(1); p->printOptionalAttrDict(op.getAttrs(), {"predicate"}); *p << " : " << op.lhs()->getType(); } @@ -104,13 +75,15 @@ static ParseResult parseICmpOp(OpAsmParser *parser, OperationState *result) { if (!predicateStr) return parser->emitError(predicateLoc, "expected 'predicate' attribute of string type"); - int predicateValue = getICmpPredicateByName(predicateStr.getValue()); - if (predicateValue == -1) + Optional predicateValue = + symbolizeICmpPredicate(predicateStr.getValue()); + if (!predicateValue) return parser->emitError(predicateLoc) << "'" << predicateStr.getValue() << "' is an incorrect value of the 'predicate' attribute"; - attrs[0].second = parser->getBuilder().getI64IntegerAttr(predicateValue); + attrs[0].second = parser->getBuilder().getI64IntegerAttr( + static_cast(predicateValue.getValue())); // The result type is either i1 or a vector type if the inputs are // vectors.