LLVM Dialect: define ICmpPredicate in ODS

Use the recently introduced enum-gen functionality to define the predicate
attribute of the ICmp LLVM dialect operation directly in ODS.  This removes the
need for manually-coded string-to-integer conversion functions and contributes
to the overall homogenization of the operation definitions.

PiperOrigin-RevId: 258143923
This commit is contained in:
Alex Zinenko 2019-07-15 05:45:08 -07:00 committed by Mehdi Amini
parent cca53e8527
commit c3d166c532
4 changed files with 34 additions and 38 deletions

View File

@ -1,6 +1,8 @@
set(LLVM_TARGET_DEFINITIONS LLVMOps.td)
mlir_tablegen(LLVMOps.h.inc -gen-op-decls)
mlir_tablegen(LLVMOps.cpp.inc -gen-op-defs)
mlir_tablegen(LLVMOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(LLVMOpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRLLVMOpsIncGen)
set(LLVM_TARGET_DEFINITIONS NVVMOps.td)
mlir_tablegen(NVVMOps.h.inc -gen-op-decls)

View File

@ -34,6 +34,8 @@
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "mlir/LLVMIR/LLVMOpsEnums.h.inc"
namespace llvm {
class Type;
class LLVMContext;

View File

@ -133,9 +133,28 @@ def LLVM_AndOp : LLVM_ArithmeticOp<"and", "CreateAnd">;
def LLVM_OrOp : LLVM_ArithmeticOp<"or", "CreateOr">;
def LLVM_XOrOp : LLVM_ArithmeticOp<"xor", "CreateXor">;
def ICmpPredicateEQ : I64EnumAttrCase<"eq", 0>;
def ICmpPredicateNE : I64EnumAttrCase<"ne", 1>;
def ICmpPredicateSLT : I64EnumAttrCase<"slt", 2>;
def ICmpPredicateSLE : I64EnumAttrCase<"sle", 3>;
def ICmpPredicateSGT : I64EnumAttrCase<"sgt", 4>;
def ICmpPredicateSGE : I64EnumAttrCase<"sge", 5>;
def ICmpPredicateULT : I64EnumAttrCase<"ult", 6>;
def ICmpPredicateULE : I64EnumAttrCase<"ule", 7>;
def ICmpPredicateUGT : I64EnumAttrCase<"ugt", 8>;
def ICmpPredicateUGE : I64EnumAttrCase<"uge", 9>;
def ICmpPredicate : I64EnumAttr<
"ICmpPredicate",
"llvm.icmp comparison predicate",
[ICmpPredicateEQ, ICmpPredicateNE, ICmpPredicateSLT, ICmpPredicateSLE,
ICmpPredicateSGT, ICmpPredicateSGE, ICmpPredicateULT, ICmpPredicateULE,
ICmpPredicateUGT, ICmpPredicateUGE]> {
let cppNamespace = "mlir::LLVM";
}
// Other integer operations.
def LLVM_ICmpOp : LLVM_OneResultOp<"icmp", [NoSideEffect]>,
Arguments<(ins I64Attr:$predicate, LLVM_Type:$lhs,
Arguments<(ins ICmpPredicate:$predicate, LLVM_Type:$lhs,
LLVM_Type:$rhs)> {
let llvmBuilder = [{
$res = builder.CreateICmp(getLLVMCmpPredicate(

View File

@ -35,46 +35,17 @@
using namespace mlir;
using namespace mlir::LLVM;
#include "mlir/LLVMIR/LLVMOpsEnums.cpp.inc"
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::ICmpOp.
//===----------------------------------------------------------------------===//
// Return an array of mnemonics for ICmpPredicates indexed by its value.
static const char *const *getICmpPredicateNames() {
static const char *predicateNames[]{/*EQ*/ "eq",
/*NE*/ "ne",
/*SLT*/ "slt",
/*SLE*/ "sle",
/*SGT*/ "sgt",
/*SGE*/ "sge",
/*ULT*/ "ult",
/*ULE*/ "ule",
/*UGT*/ "ugt",
/*UGE*/ "uge"};
return predicateNames;
}
// Returns a value of the ICmp predicate corresponding to the given mnemonic.
// Returns -1 if there is no such mnemonic.
static int getICmpPredicateByName(StringRef name) {
return llvm::StringSwitch<int>(name)
.Case("eq", 0)
.Case("ne", 1)
.Case("slt", 2)
.Case("sle", 3)
.Case("sgt", 4)
.Case("sge", 5)
.Case("ult", 6)
.Case("ule", 7)
.Case("ugt", 8)
.Case("uge", 9)
.Default(-1);
}
static void printICmpOp(OpAsmPrinter *p, ICmpOp &op) {
*p << op.getOperationName() << " \""
<< getICmpPredicateNames()[op.predicate().getZExtValue()] << "\" "
<< *op.getOperand(0) << ", " << *op.getOperand(1);
<< stringifyICmpPredicate(
static_cast<ICmpPredicate>(op.predicate().getZExtValue()))
<< "\" " << *op.getOperand(0) << ", " << *op.getOperand(1);
p->printOptionalAttrDict(op.getAttrs(), {"predicate"});
*p << " : " << op.lhs()->getType();
}
@ -104,13 +75,15 @@ static ParseResult parseICmpOp(OpAsmParser *parser, OperationState *result) {
if (!predicateStr)
return parser->emitError(predicateLoc,
"expected 'predicate' attribute of string type");
int predicateValue = getICmpPredicateByName(predicateStr.getValue());
if (predicateValue == -1)
Optional<ICmpPredicate> predicateValue =
symbolizeICmpPredicate(predicateStr.getValue());
if (!predicateValue)
return parser->emitError(predicateLoc)
<< "'" << predicateStr.getValue()
<< "' is an incorrect value of the 'predicate' attribute";
attrs[0].second = parser->getBuilder().getI64IntegerAttr(predicateValue);
attrs[0].second = parser->getBuilder().getI64IntegerAttr(
static_cast<int64_t>(predicateValue.getValue()));
// The result type is either i1 or a vector type <? x i1> if the inputs are
// vectors.