[mlir][arith] Fix CmpIOP folding for vector types.

Previously, the folding assumed that it always operates on scalar types.

Differential Revision: https://reviews.llvm.org/D116151
This commit is contained in:
Adrian Kuegel 2021-12-22 18:09:59 +01:00
parent 4639461531
commit 4a10457d33
2 changed files with 35 additions and 1 deletions

View File

@ -1056,13 +1056,21 @@ static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
llvm_unreachable("unknown cmpi predicate kind");
}
static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
auto boolAttr = BoolAttr::get(ctx, value);
ShapedType shapedType = type.dyn_cast_or_null<ShapedType>();
if (!shapedType)
return boolAttr;
return DenseElementsAttr::get(shapedType, boolAttr);
}
OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "cmpi takes two operands");
// cmpi(pred, x, x)
if (getLhs() == getRhs()) {
auto val = applyCmpPredicateToEqualOperands(getPredicate());
return BoolAttr::get(getContext(), val);
return getBoolAttribute(getType(), getContext(), val);
}
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();

View File

@ -22,6 +22,32 @@ func @cmpi_equal_operands(%arg0: i64)
: i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
}
// Test case: Folding of comparisons with equal vector operands.
// CHECK-LABEL: @cmpi_equal_vector_operands
// CHECK-DAG: %[[T:.*]] = arith.constant dense<true>
// CHECK-DAG: %[[F:.*]] = arith.constant dense<false>
// CHECK: return %[[T]], %[[T]], %[[T]], %[[T]], %[[T]],
// CHECK-SAME: %[[F]], %[[F]], %[[F]], %[[F]], %[[F]]
func @cmpi_equal_vector_operands(%arg0: vector<1x8xi64>)
-> (vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>,
vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>,
vector<1x8xi1>, vector<1x8xi1>) {
%0 = arith.cmpi eq, %arg0, %arg0 : vector<1x8xi64>
%1 = arith.cmpi sle, %arg0, %arg0 : vector<1x8xi64>
%2 = arith.cmpi sge, %arg0, %arg0 : vector<1x8xi64>
%3 = arith.cmpi ule, %arg0, %arg0 : vector<1x8xi64>
%4 = arith.cmpi uge, %arg0, %arg0 : vector<1x8xi64>
%5 = arith.cmpi ne, %arg0, %arg0 : vector<1x8xi64>
%6 = arith.cmpi slt, %arg0, %arg0 : vector<1x8xi64>
%7 = arith.cmpi sgt, %arg0, %arg0 : vector<1x8xi64>
%8 = arith.cmpi ult, %arg0, %arg0 : vector<1x8xi64>
%9 = arith.cmpi ugt, %arg0, %arg0 : vector<1x8xi64>
return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9
: vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>,
vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>,
vector<1x8xi1>, vector<1x8xi1>
}
// -----
// CHECK-LABEL: @indexCastOfSignExtend