NFC: Convert CmpIPredicate in StandardOps to use EnumAttr

This turns several hand-written functions to auto-generated ones.

PiperOrigin-RevId: 280684326
This commit is contained in:
Lei Zhang 2019-11-15 10:16:33 -08:00 committed by A. Unique TensorFlower
parent 9d7039b001
commit a0986bf43d
11 changed files with 77 additions and 120 deletions

View File

@ -870,37 +870,37 @@ PYBIND11_MODULE(pybind, m) {
.def("__lt__",
[](PythonValueHandle lhs,
PythonValueHandle rhs) -> PythonValueHandle {
return ValueHandle::create<CmpIOp>(CmpIPredicate::SLT, lhs.value,
return ValueHandle::create<CmpIOp>(CmpIPredicate::slt, lhs.value,
rhs.value);
})
.def("__le__",
[](PythonValueHandle lhs,
PythonValueHandle rhs) -> PythonValueHandle {
return ValueHandle::create<CmpIOp>(CmpIPredicate::SLE, lhs.value,
return ValueHandle::create<CmpIOp>(CmpIPredicate::sle, lhs.value,
rhs.value);
})
.def("__gt__",
[](PythonValueHandle lhs,
PythonValueHandle rhs) -> PythonValueHandle {
return ValueHandle::create<CmpIOp>(CmpIPredicate::SGT, lhs.value,
return ValueHandle::create<CmpIOp>(CmpIPredicate::sgt, lhs.value,
rhs.value);
})
.def("__ge__",
[](PythonValueHandle lhs,
PythonValueHandle rhs) -> PythonValueHandle {
return ValueHandle::create<CmpIOp>(CmpIPredicate::SGE, lhs.value,
return ValueHandle::create<CmpIOp>(CmpIPredicate::sge, lhs.value,
rhs.value);
})
.def("__eq__",
[](PythonValueHandle lhs,
PythonValueHandle rhs) -> PythonValueHandle {
return ValueHandle::create<CmpIOp>(CmpIPredicate::EQ, lhs.value,
return ValueHandle::create<CmpIOp>(CmpIPredicate::eq, lhs.value,
rhs.value);
})
.def("__ne__",
[](PythonValueHandle lhs,
PythonValueHandle rhs) -> PythonValueHandle {
return ValueHandle::create<CmpIOp>(CmpIPredicate::NE, lhs.value,
return ValueHandle::create<CmpIOp>(CmpIPredicate::ne, lhs.value,
rhs.value);
})
.def("__invert__",

View File

@ -1,4 +1,6 @@
set(LLVM_TARGET_DEFINITIONS Ops.td)
mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRStandardOpsIncGen)

View File

@ -30,6 +30,9 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
// Pull in all enum type definitions and utility function declarations.
#include "mlir/Dialect/StandardOps/OpsEnums.h.inc"
namespace mlir {
class AffineMap;
class Builder;
@ -42,27 +45,6 @@ public:
static StringRef getDialectNamespace() { return "std"; }
};
/// The predicate indicates the type of the comparison to perform:
/// (in)equality; (un)signed less/greater than (or equal to).
enum class CmpIPredicate {
FirstValidValue,
// (In)equality comparisons.
EQ = FirstValidValue,
NE,
// Signed comparisons.
SLT,
SLE,
SGT,
SGE,
// Unsigned comparisons.
ULT,
ULE,
UGT,
UGE,
// Number of predicates.
NumPredicates
};
/// The predicate indicates the type of the comparison to perform:
/// (un)orderedness, (in)equality and less/greater than (or equal to) as
/// well as predicates that are always true or false.

View File

@ -345,6 +345,24 @@ def CallIndirectOp : Std_Op<"call_indirect", [CallOpInterface]> {
let hasCanonicalizer = 1;
}
def CMPI_P_EQ : I64EnumAttrCase<"eq", 0>;
def CMPI_P_NE : I64EnumAttrCase<"ne", 1>;
def CMPI_P_SLT : I64EnumAttrCase<"slt", 2>;
def CMPI_P_SLE : I64EnumAttrCase<"sle", 3>;
def CMPI_P_SGT : I64EnumAttrCase<"sgt", 4>;
def CMPI_P_SGE : I64EnumAttrCase<"sge", 5>;
def CMPI_P_ULT : I64EnumAttrCase<"ult", 6>;
def CMPI_P_ULE : I64EnumAttrCase<"ule", 7>;
def CMPI_P_UGT : I64EnumAttrCase<"ugt", 8>;
def CMPI_P_UGE : I64EnumAttrCase<"uge", 9>;
def CmpIPredicateAttr : I64EnumAttr<
"CmpIPredicate", "",
[CMPI_P_EQ, CMPI_P_NE, CMPI_P_SLT, CMPI_P_SLE, CMPI_P_SGT,
CMPI_P_SGE, CMPI_P_ULT, CMPI_P_ULE, CMPI_P_UGT, CMPI_P_UGE]> {
let cppNamespace = "::mlir";
}
def CmpIOp : Std_Op<"cmpi",
[NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]> {
let summary = "integer comparison operation";
@ -369,7 +387,11 @@ def CmpIOp : Std_Op<"cmpi",
%r3 = "std.cmpi"(%0, %1){predicate: 0} : (i8, i8) -> i1
}];
let arguments = (ins IntegerLike:$lhs, IntegerLike:$rhs);
let arguments = (ins
CmpIPredicateAttr:$predicate,
IntegerLike:$lhs,
IntegerLike:$rhs
);
let results = (outs BoolLike);
let builders = [OpBuilder<
@ -388,6 +410,8 @@ def CmpIOp : Std_Op<"cmpi",
}
}];
let verifier = [{ return success(); }];
let hasFolder = 1;
}

View File

@ -97,7 +97,7 @@ public:
Value *remainder = builder.create<RemISOp>(loc, lhs, rhs);
Value *zeroCst = builder.create<ConstantIndexOp>(loc, 0);
Value *isRemainderNegative =
builder.create<CmpIOp>(loc, CmpIPredicate::SLT, remainder, zeroCst);
builder.create<CmpIOp>(loc, CmpIPredicate::slt, remainder, zeroCst);
Value *correctedRemainder = builder.create<AddIOp>(loc, remainder, rhs);
Value *result = builder.create<SelectOp>(loc, isRemainderNegative,
correctedRemainder, remainder);
@ -134,7 +134,7 @@ public:
Value *zeroCst = builder.create<ConstantIndexOp>(loc, 0);
Value *noneCst = builder.create<ConstantIndexOp>(loc, -1);
Value *negative =
builder.create<CmpIOp>(loc, CmpIPredicate::SLT, lhs, zeroCst);
builder.create<CmpIOp>(loc, CmpIPredicate::slt, lhs, zeroCst);
Value *negatedDecremented = builder.create<SubIOp>(loc, noneCst, lhs);
Value *dividend =
builder.create<SelectOp>(loc, negative, negatedDecremented, lhs);
@ -173,7 +173,7 @@ public:
Value *zeroCst = builder.create<ConstantIndexOp>(loc, 0);
Value *oneCst = builder.create<ConstantIndexOp>(loc, 1);
Value *nonPositive =
builder.create<CmpIOp>(loc, CmpIPredicate::SLE, lhs, zeroCst);
builder.create<CmpIOp>(loc, CmpIPredicate::sle, lhs, zeroCst);
Value *negated = builder.create<SubIOp>(loc, zeroCst, lhs);
Value *decremented = builder.create<SubIOp>(loc, lhs, oneCst);
Value *dividend =
@ -277,7 +277,7 @@ Value *mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) {
boundOperands);
if (!lbValues)
return nullptr;
return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::SGT, *lbValues,
return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::sgt, *lbValues,
builder);
}
@ -290,7 +290,7 @@ Value *mlir::lowerAffineUpperBound(AffineForOp op, OpBuilder &builder) {
boundOperands);
if (!ubValues)
return nullptr;
return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::SLT, *ubValues,
return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::slt, *ubValues,
builder);
}
@ -352,7 +352,7 @@ public:
operandsRef.drop_front(numDims));
if (!affResult)
return matchFailure();
auto pred = isEquality ? CmpIPredicate::EQ : CmpIPredicate::SGE;
auto pred = isEquality ? CmpIPredicate::eq : CmpIPredicate::sge;
Value *cmpVal =
rewriter.create<CmpIOp>(loc, pred, affResult, zeroConstant);
cond =

View File

@ -205,7 +205,7 @@ ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const {
// With the body block done, we can fill in the condition block.
rewriter.setInsertionPointToEnd(conditionBlock);
auto comparison =
rewriter.create<CmpIOp>(loc, CmpIPredicate::SLT, iv, upperBound);
rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, iv, upperBound);
rewriter.create<CondBranchOp>(loc, comparison, firstBodyBlock,
ArrayRef<Value *>(), endBlock,

View File

@ -90,12 +90,12 @@ public:
cmpIOpOperands.rhs()); \
return matchSuccess();
DISPATCH(CmpIPredicate::EQ, spirv::IEqualOp);
DISPATCH(CmpIPredicate::NE, spirv::INotEqualOp);
DISPATCH(CmpIPredicate::SLT, spirv::SLessThanOp);
DISPATCH(CmpIPredicate::SLE, spirv::SLessThanEqualOp);
DISPATCH(CmpIPredicate::SGT, spirv::SGreaterThanOp);
DISPATCH(CmpIPredicate::SGE, spirv::SGreaterThanEqualOp);
DISPATCH(CmpIPredicate::eq, spirv::IEqualOp);
DISPATCH(CmpIPredicate::ne, spirv::INotEqualOp);
DISPATCH(CmpIPredicate::slt, spirv::SLessThanOp);
DISPATCH(CmpIPredicate::sle, spirv::SLessThanEqualOp);
DISPATCH(CmpIPredicate::sgt, spirv::SGreaterThanOp);
DISPATCH(CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
#undef DISPATCH

View File

@ -35,6 +35,9 @@
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
// Pull in all enum type definitions and utility function declarations.
#include "mlir/Dialect/StandardOps/OpsEnums.cpp.inc"
using namespace mlir;
//===----------------------------------------------------------------------===//
@ -699,43 +702,6 @@ static Type getI1SameShape(Builder *build, Type type) {
// CmpIOp
//===----------------------------------------------------------------------===//
// Returns an array of mnemonics for CmpIPredicates indexed by values thereof.
static inline const char *const *getCmpIPredicateNames() {
static const char *predicateNames[]{
/*EQ*/ "eq",
/*NE*/ "ne",
/*SLT*/ "slt",
/*SLE*/ "sle",
/*SGT*/ "sgt",
/*SGE*/ "sge",
/*ULT*/ "ult",
/*ULE*/ "ule",
/*UGT*/ "ugt",
/*UGE*/ "uge",
};
static_assert(std::extent<decltype(predicateNames)>::value ==
(size_t)CmpIPredicate::NumPredicates,
"wrong number of predicate names");
return predicateNames;
}
// Returns a value of the predicate corresponding to the given mnemonic.
// Returns NumPredicates (one-past-end) if there is no such mnemonic.
CmpIPredicate CmpIOp::getPredicateByName(StringRef name) {
return llvm::StringSwitch<CmpIPredicate>(name)
.Case("eq", CmpIPredicate::EQ)
.Case("ne", CmpIPredicate::NE)
.Case("slt", CmpIPredicate::SLT)
.Case("sle", CmpIPredicate::SLE)
.Case("sgt", CmpIPredicate::SGT)
.Case("sge", CmpIPredicate::SGE)
.Case("ult", CmpIPredicate::ULT)
.Case("ule", CmpIPredicate::ULE)
.Case("ugt", CmpIPredicate::UGT)
.Case("uge", CmpIPredicate::UGE)
.Default(CmpIPredicate::NumPredicates);
}
static void buildCmpIOp(Builder *build, OperationState &result,
CmpIPredicate predicate, Value *lhs, Value *rhs) {
result.addOperands({lhs, rhs});
@ -763,8 +729,8 @@ static ParseResult parseCmpIOp(OpAsmParser &parser, OperationState &result) {
// Rewrite string attribute to an enum value.
StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue();
auto predicate = CmpIOp::getPredicateByName(predicateName);
if (predicate == CmpIPredicate::NumPredicates)
Optional<CmpIPredicate> predicate = symbolizeCmpIPredicate(predicateName);
if (!predicate.hasValue())
return parser.emitError(parser.getNameLoc())
<< "unknown comparison predicate \"" << predicateName << "\"";
@ -774,7 +740,7 @@ static ParseResult parseCmpIOp(OpAsmParser &parser, OperationState &result) {
return parser.emitError(parser.getNameLoc(),
"expected type with valid i1 shape");
attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(predicate));
attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(*predicate));
result.attributes = attrs;
result.addTypes({i1Type});
@ -784,15 +750,11 @@ static ParseResult parseCmpIOp(OpAsmParser &parser, OperationState &result) {
static void print(OpAsmPrinter &p, CmpIOp op) {
p << "cmpi ";
Builder b(op.getContext());
auto predicateValue =
op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName()).getInt();
assert(predicateValue >= static_cast<int>(CmpIPredicate::FirstValidValue) &&
predicateValue < static_cast<int>(CmpIPredicate::NumPredicates) &&
"unknown predicate index");
Builder b(op.getContext());
auto predicateStringAttr =
b.getStringAttr(getCmpIPredicateNames()[predicateValue]);
p.printAttribute(predicateStringAttr);
p << '"' << stringifyCmpIPredicate(static_cast<CmpIPredicate>(predicateValue))
<< '"';
p << ", ";
p.printOperand(op.lhs());
@ -803,43 +765,30 @@ static void print(OpAsmPrinter &p, CmpIOp op) {
p << " : " << op.lhs()->getType();
}
static LogicalResult verify(CmpIOp op) {
auto predicateAttr =
op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName());
if (!predicateAttr)
return op.emitOpError("requires an integer attribute named 'predicate'");
auto predicate = predicateAttr.getInt();
if (predicate < (int64_t)CmpIPredicate::FirstValidValue ||
predicate >= (int64_t)CmpIPredicate::NumPredicates)
return op.emitOpError("'predicate' attribute value out of range");
return success();
}
// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
// comparison predicates.
static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
const APInt &rhs) {
switch (predicate) {
case CmpIPredicate::EQ:
case CmpIPredicate::eq:
return lhs.eq(rhs);
case CmpIPredicate::NE:
case CmpIPredicate::ne:
return lhs.ne(rhs);
case CmpIPredicate::SLT:
case CmpIPredicate::slt:
return lhs.slt(rhs);
case CmpIPredicate::SLE:
case CmpIPredicate::sle:
return lhs.sle(rhs);
case CmpIPredicate::SGT:
case CmpIPredicate::sgt:
return lhs.sgt(rhs);
case CmpIPredicate::SGE:
case CmpIPredicate::sge:
return lhs.sge(rhs);
case CmpIPredicate::ULT:
case CmpIPredicate::ult:
return lhs.ult(rhs);
case CmpIPredicate::ULE:
case CmpIPredicate::ule:
return lhs.ule(rhs);
case CmpIPredicate::UGT:
case CmpIPredicate::ugt:
return lhs.ugt(rhs);
case CmpIPredicate::UGE:
case CmpIPredicate::uge:
return lhs.uge(rhs);
default:
llvm_unreachable("unknown comparison predicate");

View File

@ -460,13 +460,13 @@ ValueHandle mlir::edsc::op::operator==(ValueHandle lhs, ValueHandle rhs) {
auto type = lhs.getType();
return type.isa<FloatType>()
? createFComparisonExpr(CmpFPredicate::OEQ, lhs, rhs)
: createIComparisonExpr(CmpIPredicate::EQ, lhs, rhs);
: createIComparisonExpr(CmpIPredicate::eq, lhs, rhs);
}
ValueHandle mlir::edsc::op::operator!=(ValueHandle lhs, ValueHandle rhs) {
auto type = lhs.getType();
return type.isa<FloatType>()
? createFComparisonExpr(CmpFPredicate::ONE, lhs, rhs)
: createIComparisonExpr(CmpIPredicate::NE, lhs, rhs);
: createIComparisonExpr(CmpIPredicate::ne, lhs, rhs);
}
ValueHandle mlir::edsc::op::operator<(ValueHandle lhs, ValueHandle rhs) {
auto type = lhs.getType();
@ -474,23 +474,23 @@ ValueHandle mlir::edsc::op::operator<(ValueHandle lhs, ValueHandle rhs) {
? createFComparisonExpr(CmpFPredicate::OLT, lhs, rhs)
:
// TODO(ntv,zinenko): signed by default, how about unsigned?
createIComparisonExpr(CmpIPredicate::SLT, lhs, rhs);
createIComparisonExpr(CmpIPredicate::slt, lhs, rhs);
}
ValueHandle mlir::edsc::op::operator<=(ValueHandle lhs, ValueHandle rhs) {
auto type = lhs.getType();
return type.isa<FloatType>()
? createFComparisonExpr(CmpFPredicate::OLE, lhs, rhs)
: createIComparisonExpr(CmpIPredicate::SLE, lhs, rhs);
: createIComparisonExpr(CmpIPredicate::sle, lhs, rhs);
}
ValueHandle mlir::edsc::op::operator>(ValueHandle lhs, ValueHandle rhs) {
auto type = lhs.getType();
return type.isa<FloatType>()
? createFComparisonExpr(CmpFPredicate::OGT, lhs, rhs)
: createIComparisonExpr(CmpIPredicate::SGT, lhs, rhs);
: createIComparisonExpr(CmpIPredicate::sgt, lhs, rhs);
}
ValueHandle mlir::edsc::op::operator>=(ValueHandle lhs, ValueHandle rhs) {
auto type = lhs.getType();
return type.isa<FloatType>()
? createFComparisonExpr(CmpFPredicate::OGE, lhs, rhs)
: createIComparisonExpr(CmpIPredicate::SGE, lhs, rhs);
: createIComparisonExpr(CmpIPredicate::sge, lhs, rhs);
}

View File

@ -747,7 +747,7 @@ static Loops stripmineSink(loop::ForOp forOp, Value *factor,
// Insert newForOp before the terminator of `t`.
OpBuilder b(t.getBodyBuilder());
Value *stepped = b.create<AddIOp>(t.getLoc(), iv, forOp.step());
Value *less = b.create<CmpIOp>(t.getLoc(), CmpIPredicate::SLT,
Value *less = b.create<CmpIOp>(t.getLoc(), CmpIPredicate::slt,
forOp.upperBound(), stepped);
Value *ub =
b.create<SelectOp>(t.getLoc(), less, forOp.upperBound(), stepped);

View File

@ -201,7 +201,7 @@ func @func_with_ops(i32) {
func @func_with_ops(i32) {
^bb0(%a : i32):
// expected-error@+1 {{'predicate' attribute value out of range}}
// expected-error@+1 {{failed to satisfy constraint: allowed 64-bit integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}}
%r = "std.cmpi"(%a, %a) {predicate = 42} : (i32, i32) -> i1
}
@ -241,7 +241,7 @@ func @func_with_ops(i32, i32) {
func @func_with_ops(i32, i32) {
^bb0(%a : i32, %b : i32):
// expected-error@+1 {{requires an integer attribute named 'predicate'}}
// expected-error@+1 {{requires attribute 'predicate'}}
%r = "std.cmpi"(%a, %b) {foo = 1} : (i32, i32) -> i1
}