Add simple constant folding hook for CmpIOp

Integer comparisons can be constant folded if both of their arguments are known
constants, which we can compare in the compiler.  This requires implementing
all comparison predicates, but thanks to consistency between LLVM and MLIR
comparison predicates, we have a one-to-one correspondence between predicates
and llvm::APInt comparison functions.  Constant folding of comparsions with
maximum/minimum values of the integer type are left for future work.

This will be used to test the lowering of mod/floordiv/ceildiv in affine
expressions at compile time.

PiperOrigin-RevId: 228077580
This commit is contained in:
Alex Zinenko 2019-01-06 14:09:15 -08:00 committed by jpienaar
parent caa7e70627
commit c47ed53211
3 changed files with 72 additions and 0 deletions

View File

@ -199,6 +199,8 @@ public:
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
bool verify() const;
Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const;
private:
friend class OperationInst;

View File

@ -571,6 +571,50 @@ bool CmpIOp::verify() const {
return false;
}
// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
// comparison predicates.
static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
const APInt &rhs) {
switch (predicate) {
case CmpIPredicate::EQ:
return lhs.eq(rhs);
case CmpIPredicate::NE:
return lhs.ne(rhs);
case CmpIPredicate::SLT:
return lhs.slt(rhs);
case CmpIPredicate::SLE:
return lhs.sle(rhs);
case CmpIPredicate::SGT:
return lhs.sgt(rhs);
case CmpIPredicate::SGE:
return lhs.sge(rhs);
case CmpIPredicate::ULT:
return lhs.ult(rhs);
case CmpIPredicate::ULE:
return lhs.ule(rhs);
case CmpIPredicate::UGT:
return lhs.ugt(rhs);
case CmpIPredicate::UGE:
return lhs.uge(rhs);
default:
llvm_unreachable("unknown comparison predicate");
}
}
// Constant folding hook for comparisons.
Attribute CmpIOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const {
assert(operands.size() == 2 && "cmpi takes two arguments");
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
if (!lhs || !rhs)
return {};
auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
return IntegerAttr::get(IntegerType::get(1, context), APInt(1, val));
}
//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

View File

@ -227,3 +227,29 @@ func @dim(%x : tensor<8x4xf32>) -> index {
return %0 : index
}
// CHECK-LABEL: func @cmpi
func @cmpi() -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) {
%c42 = constant 42 : i32
%cm1 = constant -1 : i32
// CHECK-NEXT: %false = constant 0 : i1
%0 = cmpi "eq", %c42, %cm1 : i32
// CHECK-NEXT: %true = constant 1 : i1
%1 = cmpi "ne", %c42, %cm1 : i32
// CHECK-NEXT: %false_0 = constant 0 : i1
%2 = cmpi "slt", %c42, %cm1 : i32
// CHECK-NEXT: %false_1 = constant 0 : i1
%3 = cmpi "sle", %c42, %cm1 : i32
// CHECK-NEXT: %true_2 = constant 1 : i1
%4 = cmpi "sgt", %c42, %cm1 : i32
// CHECK-NEXT: %true_3 = constant 1 : i1
%5 = cmpi "sge", %c42, %cm1 : i32
// CHECK-NEXT: %true_4 = constant 1 : i1
%6 = cmpi "ult", %c42, %cm1 : i32
// CHECK-NEXT: %true_5 = constant 1 : i1
%7 = cmpi "ule", %c42, %cm1 : i32
// CHECK-NEXT: %false_6 = constant 0 : i1
%8 = cmpi "ugt", %c42, %cm1 : i32
// CHECK-NEXT: %false_7 = constant 0 : i1
%9 = cmpi "uge", %c42, %cm1 : i32
return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9 : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
}