forked from OSchip/llvm-project
[mlir][arith] Fix CmpIOP folding for vector types.
Previously, the folding assumed that it always operates on scalar types. Differential Revision: https://reviews.llvm.org/D116151
This commit is contained in:
parent
4639461531
commit
4a10457d33
|
@ -1056,13 +1056,21 @@ static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
|
|||
llvm_unreachable("unknown cmpi predicate kind");
|
||||
}
|
||||
|
||||
static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
|
||||
auto boolAttr = BoolAttr::get(ctx, value);
|
||||
ShapedType shapedType = type.dyn_cast_or_null<ShapedType>();
|
||||
if (!shapedType)
|
||||
return boolAttr;
|
||||
return DenseElementsAttr::get(shapedType, boolAttr);
|
||||
}
|
||||
|
||||
OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2 && "cmpi takes two operands");
|
||||
|
||||
// cmpi(pred, x, x)
|
||||
if (getLhs() == getRhs()) {
|
||||
auto val = applyCmpPredicateToEqualOperands(getPredicate());
|
||||
return BoolAttr::get(getContext(), val);
|
||||
return getBoolAttribute(getType(), getContext(), val);
|
||||
}
|
||||
|
||||
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
|
||||
|
|
|
@ -22,6 +22,32 @@ func @cmpi_equal_operands(%arg0: i64)
|
|||
: i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
|
||||
}
|
||||
|
||||
// Test case: Folding of comparisons with equal vector operands.
|
||||
// CHECK-LABEL: @cmpi_equal_vector_operands
|
||||
// CHECK-DAG: %[[T:.*]] = arith.constant dense<true>
|
||||
// CHECK-DAG: %[[F:.*]] = arith.constant dense<false>
|
||||
// CHECK: return %[[T]], %[[T]], %[[T]], %[[T]], %[[T]],
|
||||
// CHECK-SAME: %[[F]], %[[F]], %[[F]], %[[F]], %[[F]]
|
||||
func @cmpi_equal_vector_operands(%arg0: vector<1x8xi64>)
|
||||
-> (vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>,
|
||||
vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>,
|
||||
vector<1x8xi1>, vector<1x8xi1>) {
|
||||
%0 = arith.cmpi eq, %arg0, %arg0 : vector<1x8xi64>
|
||||
%1 = arith.cmpi sle, %arg0, %arg0 : vector<1x8xi64>
|
||||
%2 = arith.cmpi sge, %arg0, %arg0 : vector<1x8xi64>
|
||||
%3 = arith.cmpi ule, %arg0, %arg0 : vector<1x8xi64>
|
||||
%4 = arith.cmpi uge, %arg0, %arg0 : vector<1x8xi64>
|
||||
%5 = arith.cmpi ne, %arg0, %arg0 : vector<1x8xi64>
|
||||
%6 = arith.cmpi slt, %arg0, %arg0 : vector<1x8xi64>
|
||||
%7 = arith.cmpi sgt, %arg0, %arg0 : vector<1x8xi64>
|
||||
%8 = arith.cmpi ult, %arg0, %arg0 : vector<1x8xi64>
|
||||
%9 = arith.cmpi ugt, %arg0, %arg0 : vector<1x8xi64>
|
||||
return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9
|
||||
: vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>,
|
||||
vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>,
|
||||
vector<1x8xi1>, vector<1x8xi1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @indexCastOfSignExtend
|
||||
|
|
Loading…
Reference in New Issue