forked from OSchip/llvm-project
NFC: Convert CmpIPredicate in StandardOps to use EnumAttr
This turns several hand-written functions to auto-generated ones. PiperOrigin-RevId: 280684326
This commit is contained in:
parent
9d7039b001
commit
a0986bf43d
|
@ -870,37 +870,37 @@ PYBIND11_MODULE(pybind, m) {
|
||||||
.def("__lt__",
|
.def("__lt__",
|
||||||
[](PythonValueHandle lhs,
|
[](PythonValueHandle lhs,
|
||||||
PythonValueHandle rhs) -> PythonValueHandle {
|
PythonValueHandle rhs) -> PythonValueHandle {
|
||||||
return ValueHandle::create<CmpIOp>(CmpIPredicate::SLT, lhs.value,
|
return ValueHandle::create<CmpIOp>(CmpIPredicate::slt, lhs.value,
|
||||||
rhs.value);
|
rhs.value);
|
||||||
})
|
})
|
||||||
.def("__le__",
|
.def("__le__",
|
||||||
[](PythonValueHandle lhs,
|
[](PythonValueHandle lhs,
|
||||||
PythonValueHandle rhs) -> PythonValueHandle {
|
PythonValueHandle rhs) -> PythonValueHandle {
|
||||||
return ValueHandle::create<CmpIOp>(CmpIPredicate::SLE, lhs.value,
|
return ValueHandle::create<CmpIOp>(CmpIPredicate::sle, lhs.value,
|
||||||
rhs.value);
|
rhs.value);
|
||||||
})
|
})
|
||||||
.def("__gt__",
|
.def("__gt__",
|
||||||
[](PythonValueHandle lhs,
|
[](PythonValueHandle lhs,
|
||||||
PythonValueHandle rhs) -> PythonValueHandle {
|
PythonValueHandle rhs) -> PythonValueHandle {
|
||||||
return ValueHandle::create<CmpIOp>(CmpIPredicate::SGT, lhs.value,
|
return ValueHandle::create<CmpIOp>(CmpIPredicate::sgt, lhs.value,
|
||||||
rhs.value);
|
rhs.value);
|
||||||
})
|
})
|
||||||
.def("__ge__",
|
.def("__ge__",
|
||||||
[](PythonValueHandle lhs,
|
[](PythonValueHandle lhs,
|
||||||
PythonValueHandle rhs) -> PythonValueHandle {
|
PythonValueHandle rhs) -> PythonValueHandle {
|
||||||
return ValueHandle::create<CmpIOp>(CmpIPredicate::SGE, lhs.value,
|
return ValueHandle::create<CmpIOp>(CmpIPredicate::sge, lhs.value,
|
||||||
rhs.value);
|
rhs.value);
|
||||||
})
|
})
|
||||||
.def("__eq__",
|
.def("__eq__",
|
||||||
[](PythonValueHandle lhs,
|
[](PythonValueHandle lhs,
|
||||||
PythonValueHandle rhs) -> PythonValueHandle {
|
PythonValueHandle rhs) -> PythonValueHandle {
|
||||||
return ValueHandle::create<CmpIOp>(CmpIPredicate::EQ, lhs.value,
|
return ValueHandle::create<CmpIOp>(CmpIPredicate::eq, lhs.value,
|
||||||
rhs.value);
|
rhs.value);
|
||||||
})
|
})
|
||||||
.def("__ne__",
|
.def("__ne__",
|
||||||
[](PythonValueHandle lhs,
|
[](PythonValueHandle lhs,
|
||||||
PythonValueHandle rhs) -> PythonValueHandle {
|
PythonValueHandle rhs) -> PythonValueHandle {
|
||||||
return ValueHandle::create<CmpIOp>(CmpIPredicate::NE, lhs.value,
|
return ValueHandle::create<CmpIOp>(CmpIPredicate::ne, lhs.value,
|
||||||
rhs.value);
|
rhs.value);
|
||||||
})
|
})
|
||||||
.def("__invert__",
|
.def("__invert__",
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
set(LLVM_TARGET_DEFINITIONS Ops.td)
|
set(LLVM_TARGET_DEFINITIONS Ops.td)
|
||||||
mlir_tablegen(Ops.h.inc -gen-op-decls)
|
mlir_tablegen(Ops.h.inc -gen-op-decls)
|
||||||
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
|
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
|
||||||
|
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
|
||||||
|
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
|
||||||
add_public_tablegen_target(MLIRStandardOpsIncGen)
|
add_public_tablegen_target(MLIRStandardOpsIncGen)
|
||||||
|
|
|
@ -30,6 +30,9 @@
|
||||||
#include "mlir/IR/OpDefinition.h"
|
#include "mlir/IR/OpDefinition.h"
|
||||||
#include "mlir/IR/StandardTypes.h"
|
#include "mlir/IR/StandardTypes.h"
|
||||||
|
|
||||||
|
// Pull in all enum type definitions and utility function declarations.
|
||||||
|
#include "mlir/Dialect/StandardOps/OpsEnums.h.inc"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
class AffineMap;
|
class AffineMap;
|
||||||
class Builder;
|
class Builder;
|
||||||
|
@ -42,27 +45,6 @@ public:
|
||||||
static StringRef getDialectNamespace() { return "std"; }
|
static StringRef getDialectNamespace() { return "std"; }
|
||||||
};
|
};
|
||||||
|
|
||||||
/// The predicate indicates the type of the comparison to perform:
|
|
||||||
/// (in)equality; (un)signed less/greater than (or equal to).
|
|
||||||
enum class CmpIPredicate {
|
|
||||||
FirstValidValue,
|
|
||||||
// (In)equality comparisons.
|
|
||||||
EQ = FirstValidValue,
|
|
||||||
NE,
|
|
||||||
// Signed comparisons.
|
|
||||||
SLT,
|
|
||||||
SLE,
|
|
||||||
SGT,
|
|
||||||
SGE,
|
|
||||||
// Unsigned comparisons.
|
|
||||||
ULT,
|
|
||||||
ULE,
|
|
||||||
UGT,
|
|
||||||
UGE,
|
|
||||||
// Number of predicates.
|
|
||||||
NumPredicates
|
|
||||||
};
|
|
||||||
|
|
||||||
/// The predicate indicates the type of the comparison to perform:
|
/// The predicate indicates the type of the comparison to perform:
|
||||||
/// (un)orderedness, (in)equality and less/greater than (or equal to) as
|
/// (un)orderedness, (in)equality and less/greater than (or equal to) as
|
||||||
/// well as predicates that are always true or false.
|
/// well as predicates that are always true or false.
|
||||||
|
|
|
@ -345,6 +345,24 @@ def CallIndirectOp : Std_Op<"call_indirect", [CallOpInterface]> {
|
||||||
let hasCanonicalizer = 1;
|
let hasCanonicalizer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def CMPI_P_EQ : I64EnumAttrCase<"eq", 0>;
|
||||||
|
def CMPI_P_NE : I64EnumAttrCase<"ne", 1>;
|
||||||
|
def CMPI_P_SLT : I64EnumAttrCase<"slt", 2>;
|
||||||
|
def CMPI_P_SLE : I64EnumAttrCase<"sle", 3>;
|
||||||
|
def CMPI_P_SGT : I64EnumAttrCase<"sgt", 4>;
|
||||||
|
def CMPI_P_SGE : I64EnumAttrCase<"sge", 5>;
|
||||||
|
def CMPI_P_ULT : I64EnumAttrCase<"ult", 6>;
|
||||||
|
def CMPI_P_ULE : I64EnumAttrCase<"ule", 7>;
|
||||||
|
def CMPI_P_UGT : I64EnumAttrCase<"ugt", 8>;
|
||||||
|
def CMPI_P_UGE : I64EnumAttrCase<"uge", 9>;
|
||||||
|
|
||||||
|
def CmpIPredicateAttr : I64EnumAttr<
|
||||||
|
"CmpIPredicate", "",
|
||||||
|
[CMPI_P_EQ, CMPI_P_NE, CMPI_P_SLT, CMPI_P_SLE, CMPI_P_SGT,
|
||||||
|
CMPI_P_SGE, CMPI_P_ULT, CMPI_P_ULE, CMPI_P_UGT, CMPI_P_UGE]> {
|
||||||
|
let cppNamespace = "::mlir";
|
||||||
|
}
|
||||||
|
|
||||||
def CmpIOp : Std_Op<"cmpi",
|
def CmpIOp : Std_Op<"cmpi",
|
||||||
[NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]> {
|
[NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]> {
|
||||||
let summary = "integer comparison operation";
|
let summary = "integer comparison operation";
|
||||||
|
@ -369,7 +387,11 @@ def CmpIOp : Std_Op<"cmpi",
|
||||||
%r3 = "std.cmpi"(%0, %1){predicate: 0} : (i8, i8) -> i1
|
%r3 = "std.cmpi"(%0, %1){predicate: 0} : (i8, i8) -> i1
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins IntegerLike:$lhs, IntegerLike:$rhs);
|
let arguments = (ins
|
||||||
|
CmpIPredicateAttr:$predicate,
|
||||||
|
IntegerLike:$lhs,
|
||||||
|
IntegerLike:$rhs
|
||||||
|
);
|
||||||
let results = (outs BoolLike);
|
let results = (outs BoolLike);
|
||||||
|
|
||||||
let builders = [OpBuilder<
|
let builders = [OpBuilder<
|
||||||
|
@ -388,6 +410,8 @@ def CmpIOp : Std_Op<"cmpi",
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
|
|
||||||
|
let verifier = [{ return success(); }];
|
||||||
|
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -97,7 +97,7 @@ public:
|
||||||
Value *remainder = builder.create<RemISOp>(loc, lhs, rhs);
|
Value *remainder = builder.create<RemISOp>(loc, lhs, rhs);
|
||||||
Value *zeroCst = builder.create<ConstantIndexOp>(loc, 0);
|
Value *zeroCst = builder.create<ConstantIndexOp>(loc, 0);
|
||||||
Value *isRemainderNegative =
|
Value *isRemainderNegative =
|
||||||
builder.create<CmpIOp>(loc, CmpIPredicate::SLT, remainder, zeroCst);
|
builder.create<CmpIOp>(loc, CmpIPredicate::slt, remainder, zeroCst);
|
||||||
Value *correctedRemainder = builder.create<AddIOp>(loc, remainder, rhs);
|
Value *correctedRemainder = builder.create<AddIOp>(loc, remainder, rhs);
|
||||||
Value *result = builder.create<SelectOp>(loc, isRemainderNegative,
|
Value *result = builder.create<SelectOp>(loc, isRemainderNegative,
|
||||||
correctedRemainder, remainder);
|
correctedRemainder, remainder);
|
||||||
|
@ -134,7 +134,7 @@ public:
|
||||||
Value *zeroCst = builder.create<ConstantIndexOp>(loc, 0);
|
Value *zeroCst = builder.create<ConstantIndexOp>(loc, 0);
|
||||||
Value *noneCst = builder.create<ConstantIndexOp>(loc, -1);
|
Value *noneCst = builder.create<ConstantIndexOp>(loc, -1);
|
||||||
Value *negative =
|
Value *negative =
|
||||||
builder.create<CmpIOp>(loc, CmpIPredicate::SLT, lhs, zeroCst);
|
builder.create<CmpIOp>(loc, CmpIPredicate::slt, lhs, zeroCst);
|
||||||
Value *negatedDecremented = builder.create<SubIOp>(loc, noneCst, lhs);
|
Value *negatedDecremented = builder.create<SubIOp>(loc, noneCst, lhs);
|
||||||
Value *dividend =
|
Value *dividend =
|
||||||
builder.create<SelectOp>(loc, negative, negatedDecremented, lhs);
|
builder.create<SelectOp>(loc, negative, negatedDecremented, lhs);
|
||||||
|
@ -173,7 +173,7 @@ public:
|
||||||
Value *zeroCst = builder.create<ConstantIndexOp>(loc, 0);
|
Value *zeroCst = builder.create<ConstantIndexOp>(loc, 0);
|
||||||
Value *oneCst = builder.create<ConstantIndexOp>(loc, 1);
|
Value *oneCst = builder.create<ConstantIndexOp>(loc, 1);
|
||||||
Value *nonPositive =
|
Value *nonPositive =
|
||||||
builder.create<CmpIOp>(loc, CmpIPredicate::SLE, lhs, zeroCst);
|
builder.create<CmpIOp>(loc, CmpIPredicate::sle, lhs, zeroCst);
|
||||||
Value *negated = builder.create<SubIOp>(loc, zeroCst, lhs);
|
Value *negated = builder.create<SubIOp>(loc, zeroCst, lhs);
|
||||||
Value *decremented = builder.create<SubIOp>(loc, lhs, oneCst);
|
Value *decremented = builder.create<SubIOp>(loc, lhs, oneCst);
|
||||||
Value *dividend =
|
Value *dividend =
|
||||||
|
@ -277,7 +277,7 @@ Value *mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) {
|
||||||
boundOperands);
|
boundOperands);
|
||||||
if (!lbValues)
|
if (!lbValues)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::SGT, *lbValues,
|
return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::sgt, *lbValues,
|
||||||
builder);
|
builder);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -290,7 +290,7 @@ Value *mlir::lowerAffineUpperBound(AffineForOp op, OpBuilder &builder) {
|
||||||
boundOperands);
|
boundOperands);
|
||||||
if (!ubValues)
|
if (!ubValues)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::SLT, *ubValues,
|
return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::slt, *ubValues,
|
||||||
builder);
|
builder);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -352,7 +352,7 @@ public:
|
||||||
operandsRef.drop_front(numDims));
|
operandsRef.drop_front(numDims));
|
||||||
if (!affResult)
|
if (!affResult)
|
||||||
return matchFailure();
|
return matchFailure();
|
||||||
auto pred = isEquality ? CmpIPredicate::EQ : CmpIPredicate::SGE;
|
auto pred = isEquality ? CmpIPredicate::eq : CmpIPredicate::sge;
|
||||||
Value *cmpVal =
|
Value *cmpVal =
|
||||||
rewriter.create<CmpIOp>(loc, pred, affResult, zeroConstant);
|
rewriter.create<CmpIOp>(loc, pred, affResult, zeroConstant);
|
||||||
cond =
|
cond =
|
||||||
|
|
|
@ -205,7 +205,7 @@ ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const {
|
||||||
// With the body block done, we can fill in the condition block.
|
// With the body block done, we can fill in the condition block.
|
||||||
rewriter.setInsertionPointToEnd(conditionBlock);
|
rewriter.setInsertionPointToEnd(conditionBlock);
|
||||||
auto comparison =
|
auto comparison =
|
||||||
rewriter.create<CmpIOp>(loc, CmpIPredicate::SLT, iv, upperBound);
|
rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, iv, upperBound);
|
||||||
|
|
||||||
rewriter.create<CondBranchOp>(loc, comparison, firstBodyBlock,
|
rewriter.create<CondBranchOp>(loc, comparison, firstBodyBlock,
|
||||||
ArrayRef<Value *>(), endBlock,
|
ArrayRef<Value *>(), endBlock,
|
||||||
|
|
|
@ -90,12 +90,12 @@ public:
|
||||||
cmpIOpOperands.rhs()); \
|
cmpIOpOperands.rhs()); \
|
||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
|
|
||||||
DISPATCH(CmpIPredicate::EQ, spirv::IEqualOp);
|
DISPATCH(CmpIPredicate::eq, spirv::IEqualOp);
|
||||||
DISPATCH(CmpIPredicate::NE, spirv::INotEqualOp);
|
DISPATCH(CmpIPredicate::ne, spirv::INotEqualOp);
|
||||||
DISPATCH(CmpIPredicate::SLT, spirv::SLessThanOp);
|
DISPATCH(CmpIPredicate::slt, spirv::SLessThanOp);
|
||||||
DISPATCH(CmpIPredicate::SLE, spirv::SLessThanEqualOp);
|
DISPATCH(CmpIPredicate::sle, spirv::SLessThanEqualOp);
|
||||||
DISPATCH(CmpIPredicate::SGT, spirv::SGreaterThanOp);
|
DISPATCH(CmpIPredicate::sgt, spirv::SGreaterThanOp);
|
||||||
DISPATCH(CmpIPredicate::SGE, spirv::SGreaterThanEqualOp);
|
DISPATCH(CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
|
||||||
|
|
||||||
#undef DISPATCH
|
#undef DISPATCH
|
||||||
|
|
||||||
|
|
|
@ -35,6 +35,9 @@
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
|
||||||
|
// Pull in all enum type definitions and utility function declarations.
|
||||||
|
#include "mlir/Dialect/StandardOps/OpsEnums.cpp.inc"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -699,43 +702,6 @@ static Type getI1SameShape(Builder *build, Type type) {
|
||||||
// CmpIOp
|
// CmpIOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// Returns an array of mnemonics for CmpIPredicates indexed by values thereof.
|
|
||||||
static inline const char *const *getCmpIPredicateNames() {
|
|
||||||
static const char *predicateNames[]{
|
|
||||||
/*EQ*/ "eq",
|
|
||||||
/*NE*/ "ne",
|
|
||||||
/*SLT*/ "slt",
|
|
||||||
/*SLE*/ "sle",
|
|
||||||
/*SGT*/ "sgt",
|
|
||||||
/*SGE*/ "sge",
|
|
||||||
/*ULT*/ "ult",
|
|
||||||
/*ULE*/ "ule",
|
|
||||||
/*UGT*/ "ugt",
|
|
||||||
/*UGE*/ "uge",
|
|
||||||
};
|
|
||||||
static_assert(std::extent<decltype(predicateNames)>::value ==
|
|
||||||
(size_t)CmpIPredicate::NumPredicates,
|
|
||||||
"wrong number of predicate names");
|
|
||||||
return predicateNames;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns a value of the predicate corresponding to the given mnemonic.
|
|
||||||
// Returns NumPredicates (one-past-end) if there is no such mnemonic.
|
|
||||||
CmpIPredicate CmpIOp::getPredicateByName(StringRef name) {
|
|
||||||
return llvm::StringSwitch<CmpIPredicate>(name)
|
|
||||||
.Case("eq", CmpIPredicate::EQ)
|
|
||||||
.Case("ne", CmpIPredicate::NE)
|
|
||||||
.Case("slt", CmpIPredicate::SLT)
|
|
||||||
.Case("sle", CmpIPredicate::SLE)
|
|
||||||
.Case("sgt", CmpIPredicate::SGT)
|
|
||||||
.Case("sge", CmpIPredicate::SGE)
|
|
||||||
.Case("ult", CmpIPredicate::ULT)
|
|
||||||
.Case("ule", CmpIPredicate::ULE)
|
|
||||||
.Case("ugt", CmpIPredicate::UGT)
|
|
||||||
.Case("uge", CmpIPredicate::UGE)
|
|
||||||
.Default(CmpIPredicate::NumPredicates);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void buildCmpIOp(Builder *build, OperationState &result,
|
static void buildCmpIOp(Builder *build, OperationState &result,
|
||||||
CmpIPredicate predicate, Value *lhs, Value *rhs) {
|
CmpIPredicate predicate, Value *lhs, Value *rhs) {
|
||||||
result.addOperands({lhs, rhs});
|
result.addOperands({lhs, rhs});
|
||||||
|
@ -763,8 +729,8 @@ static ParseResult parseCmpIOp(OpAsmParser &parser, OperationState &result) {
|
||||||
|
|
||||||
// Rewrite string attribute to an enum value.
|
// Rewrite string attribute to an enum value.
|
||||||
StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue();
|
StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue();
|
||||||
auto predicate = CmpIOp::getPredicateByName(predicateName);
|
Optional<CmpIPredicate> predicate = symbolizeCmpIPredicate(predicateName);
|
||||||
if (predicate == CmpIPredicate::NumPredicates)
|
if (!predicate.hasValue())
|
||||||
return parser.emitError(parser.getNameLoc())
|
return parser.emitError(parser.getNameLoc())
|
||||||
<< "unknown comparison predicate \"" << predicateName << "\"";
|
<< "unknown comparison predicate \"" << predicateName << "\"";
|
||||||
|
|
||||||
|
@ -774,7 +740,7 @@ static ParseResult parseCmpIOp(OpAsmParser &parser, OperationState &result) {
|
||||||
return parser.emitError(parser.getNameLoc(),
|
return parser.emitError(parser.getNameLoc(),
|
||||||
"expected type with valid i1 shape");
|
"expected type with valid i1 shape");
|
||||||
|
|
||||||
attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(predicate));
|
attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(*predicate));
|
||||||
result.attributes = attrs;
|
result.attributes = attrs;
|
||||||
|
|
||||||
result.addTypes({i1Type});
|
result.addTypes({i1Type});
|
||||||
|
@ -784,15 +750,11 @@ static ParseResult parseCmpIOp(OpAsmParser &parser, OperationState &result) {
|
||||||
static void print(OpAsmPrinter &p, CmpIOp op) {
|
static void print(OpAsmPrinter &p, CmpIOp op) {
|
||||||
p << "cmpi ";
|
p << "cmpi ";
|
||||||
|
|
||||||
|
Builder b(op.getContext());
|
||||||
auto predicateValue =
|
auto predicateValue =
|
||||||
op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName()).getInt();
|
op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName()).getInt();
|
||||||
assert(predicateValue >= static_cast<int>(CmpIPredicate::FirstValidValue) &&
|
p << '"' << stringifyCmpIPredicate(static_cast<CmpIPredicate>(predicateValue))
|
||||||
predicateValue < static_cast<int>(CmpIPredicate::NumPredicates) &&
|
<< '"';
|
||||||
"unknown predicate index");
|
|
||||||
Builder b(op.getContext());
|
|
||||||
auto predicateStringAttr =
|
|
||||||
b.getStringAttr(getCmpIPredicateNames()[predicateValue]);
|
|
||||||
p.printAttribute(predicateStringAttr);
|
|
||||||
|
|
||||||
p << ", ";
|
p << ", ";
|
||||||
p.printOperand(op.lhs());
|
p.printOperand(op.lhs());
|
||||||
|
@ -803,43 +765,30 @@ static void print(OpAsmPrinter &p, CmpIOp op) {
|
||||||
p << " : " << op.lhs()->getType();
|
p << " : " << op.lhs()->getType();
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult verify(CmpIOp op) {
|
|
||||||
auto predicateAttr =
|
|
||||||
op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName());
|
|
||||||
if (!predicateAttr)
|
|
||||||
return op.emitOpError("requires an integer attribute named 'predicate'");
|
|
||||||
auto predicate = predicateAttr.getInt();
|
|
||||||
if (predicate < (int64_t)CmpIPredicate::FirstValidValue ||
|
|
||||||
predicate >= (int64_t)CmpIPredicate::NumPredicates)
|
|
||||||
return op.emitOpError("'predicate' attribute value out of range");
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
|
// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
|
||||||
// comparison predicates.
|
// comparison predicates.
|
||||||
static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
|
static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
|
||||||
const APInt &rhs) {
|
const APInt &rhs) {
|
||||||
switch (predicate) {
|
switch (predicate) {
|
||||||
case CmpIPredicate::EQ:
|
case CmpIPredicate::eq:
|
||||||
return lhs.eq(rhs);
|
return lhs.eq(rhs);
|
||||||
case CmpIPredicate::NE:
|
case CmpIPredicate::ne:
|
||||||
return lhs.ne(rhs);
|
return lhs.ne(rhs);
|
||||||
case CmpIPredicate::SLT:
|
case CmpIPredicate::slt:
|
||||||
return lhs.slt(rhs);
|
return lhs.slt(rhs);
|
||||||
case CmpIPredicate::SLE:
|
case CmpIPredicate::sle:
|
||||||
return lhs.sle(rhs);
|
return lhs.sle(rhs);
|
||||||
case CmpIPredicate::SGT:
|
case CmpIPredicate::sgt:
|
||||||
return lhs.sgt(rhs);
|
return lhs.sgt(rhs);
|
||||||
case CmpIPredicate::SGE:
|
case CmpIPredicate::sge:
|
||||||
return lhs.sge(rhs);
|
return lhs.sge(rhs);
|
||||||
case CmpIPredicate::ULT:
|
case CmpIPredicate::ult:
|
||||||
return lhs.ult(rhs);
|
return lhs.ult(rhs);
|
||||||
case CmpIPredicate::ULE:
|
case CmpIPredicate::ule:
|
||||||
return lhs.ule(rhs);
|
return lhs.ule(rhs);
|
||||||
case CmpIPredicate::UGT:
|
case CmpIPredicate::ugt:
|
||||||
return lhs.ugt(rhs);
|
return lhs.ugt(rhs);
|
||||||
case CmpIPredicate::UGE:
|
case CmpIPredicate::uge:
|
||||||
return lhs.uge(rhs);
|
return lhs.uge(rhs);
|
||||||
default:
|
default:
|
||||||
llvm_unreachable("unknown comparison predicate");
|
llvm_unreachable("unknown comparison predicate");
|
||||||
|
|
|
@ -460,13 +460,13 @@ ValueHandle mlir::edsc::op::operator==(ValueHandle lhs, ValueHandle rhs) {
|
||||||
auto type = lhs.getType();
|
auto type = lhs.getType();
|
||||||
return type.isa<FloatType>()
|
return type.isa<FloatType>()
|
||||||
? createFComparisonExpr(CmpFPredicate::OEQ, lhs, rhs)
|
? createFComparisonExpr(CmpFPredicate::OEQ, lhs, rhs)
|
||||||
: createIComparisonExpr(CmpIPredicate::EQ, lhs, rhs);
|
: createIComparisonExpr(CmpIPredicate::eq, lhs, rhs);
|
||||||
}
|
}
|
||||||
ValueHandle mlir::edsc::op::operator!=(ValueHandle lhs, ValueHandle rhs) {
|
ValueHandle mlir::edsc::op::operator!=(ValueHandle lhs, ValueHandle rhs) {
|
||||||
auto type = lhs.getType();
|
auto type = lhs.getType();
|
||||||
return type.isa<FloatType>()
|
return type.isa<FloatType>()
|
||||||
? createFComparisonExpr(CmpFPredicate::ONE, lhs, rhs)
|
? createFComparisonExpr(CmpFPredicate::ONE, lhs, rhs)
|
||||||
: createIComparisonExpr(CmpIPredicate::NE, lhs, rhs);
|
: createIComparisonExpr(CmpIPredicate::ne, lhs, rhs);
|
||||||
}
|
}
|
||||||
ValueHandle mlir::edsc::op::operator<(ValueHandle lhs, ValueHandle rhs) {
|
ValueHandle mlir::edsc::op::operator<(ValueHandle lhs, ValueHandle rhs) {
|
||||||
auto type = lhs.getType();
|
auto type = lhs.getType();
|
||||||
|
@ -474,23 +474,23 @@ ValueHandle mlir::edsc::op::operator<(ValueHandle lhs, ValueHandle rhs) {
|
||||||
? createFComparisonExpr(CmpFPredicate::OLT, lhs, rhs)
|
? createFComparisonExpr(CmpFPredicate::OLT, lhs, rhs)
|
||||||
:
|
:
|
||||||
// TODO(ntv,zinenko): signed by default, how about unsigned?
|
// TODO(ntv,zinenko): signed by default, how about unsigned?
|
||||||
createIComparisonExpr(CmpIPredicate::SLT, lhs, rhs);
|
createIComparisonExpr(CmpIPredicate::slt, lhs, rhs);
|
||||||
}
|
}
|
||||||
ValueHandle mlir::edsc::op::operator<=(ValueHandle lhs, ValueHandle rhs) {
|
ValueHandle mlir::edsc::op::operator<=(ValueHandle lhs, ValueHandle rhs) {
|
||||||
auto type = lhs.getType();
|
auto type = lhs.getType();
|
||||||
return type.isa<FloatType>()
|
return type.isa<FloatType>()
|
||||||
? createFComparisonExpr(CmpFPredicate::OLE, lhs, rhs)
|
? createFComparisonExpr(CmpFPredicate::OLE, lhs, rhs)
|
||||||
: createIComparisonExpr(CmpIPredicate::SLE, lhs, rhs);
|
: createIComparisonExpr(CmpIPredicate::sle, lhs, rhs);
|
||||||
}
|
}
|
||||||
ValueHandle mlir::edsc::op::operator>(ValueHandle lhs, ValueHandle rhs) {
|
ValueHandle mlir::edsc::op::operator>(ValueHandle lhs, ValueHandle rhs) {
|
||||||
auto type = lhs.getType();
|
auto type = lhs.getType();
|
||||||
return type.isa<FloatType>()
|
return type.isa<FloatType>()
|
||||||
? createFComparisonExpr(CmpFPredicate::OGT, lhs, rhs)
|
? createFComparisonExpr(CmpFPredicate::OGT, lhs, rhs)
|
||||||
: createIComparisonExpr(CmpIPredicate::SGT, lhs, rhs);
|
: createIComparisonExpr(CmpIPredicate::sgt, lhs, rhs);
|
||||||
}
|
}
|
||||||
ValueHandle mlir::edsc::op::operator>=(ValueHandle lhs, ValueHandle rhs) {
|
ValueHandle mlir::edsc::op::operator>=(ValueHandle lhs, ValueHandle rhs) {
|
||||||
auto type = lhs.getType();
|
auto type = lhs.getType();
|
||||||
return type.isa<FloatType>()
|
return type.isa<FloatType>()
|
||||||
? createFComparisonExpr(CmpFPredicate::OGE, lhs, rhs)
|
? createFComparisonExpr(CmpFPredicate::OGE, lhs, rhs)
|
||||||
: createIComparisonExpr(CmpIPredicate::SGE, lhs, rhs);
|
: createIComparisonExpr(CmpIPredicate::sge, lhs, rhs);
|
||||||
}
|
}
|
||||||
|
|
|
@ -747,7 +747,7 @@ static Loops stripmineSink(loop::ForOp forOp, Value *factor,
|
||||||
// Insert newForOp before the terminator of `t`.
|
// Insert newForOp before the terminator of `t`.
|
||||||
OpBuilder b(t.getBodyBuilder());
|
OpBuilder b(t.getBodyBuilder());
|
||||||
Value *stepped = b.create<AddIOp>(t.getLoc(), iv, forOp.step());
|
Value *stepped = b.create<AddIOp>(t.getLoc(), iv, forOp.step());
|
||||||
Value *less = b.create<CmpIOp>(t.getLoc(), CmpIPredicate::SLT,
|
Value *less = b.create<CmpIOp>(t.getLoc(), CmpIPredicate::slt,
|
||||||
forOp.upperBound(), stepped);
|
forOp.upperBound(), stepped);
|
||||||
Value *ub =
|
Value *ub =
|
||||||
b.create<SelectOp>(t.getLoc(), less, forOp.upperBound(), stepped);
|
b.create<SelectOp>(t.getLoc(), less, forOp.upperBound(), stepped);
|
||||||
|
|
|
@ -201,7 +201,7 @@ func @func_with_ops(i32) {
|
||||||
|
|
||||||
func @func_with_ops(i32) {
|
func @func_with_ops(i32) {
|
||||||
^bb0(%a : i32):
|
^bb0(%a : i32):
|
||||||
// expected-error@+1 {{'predicate' attribute value out of range}}
|
// expected-error@+1 {{failed to satisfy constraint: allowed 64-bit integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}}
|
||||||
%r = "std.cmpi"(%a, %a) {predicate = 42} : (i32, i32) -> i1
|
%r = "std.cmpi"(%a, %a) {predicate = 42} : (i32, i32) -> i1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -241,7 +241,7 @@ func @func_with_ops(i32, i32) {
|
||||||
|
|
||||||
func @func_with_ops(i32, i32) {
|
func @func_with_ops(i32, i32) {
|
||||||
^bb0(%a : i32, %b : i32):
|
^bb0(%a : i32, %b : i32):
|
||||||
// expected-error@+1 {{requires an integer attribute named 'predicate'}}
|
// expected-error@+1 {{requires attribute 'predicate'}}
|
||||||
%r = "std.cmpi"(%a, %b) {foo = 1} : (i32, i32) -> i1
|
%r = "std.cmpi"(%a, %b) {foo = 1} : (i32, i32) -> i1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue