forked from OSchip/llvm-project
[mlir][Arith] Fix up integer range inference for truncation
Attempting to apply the range analysis to real code revealed that trunci wasn't correctly handling the case where truncation would create wider ranges - for example, if we truncate [255, 257] : i16 to i8, the result can be 255, 0, or 1, which isn't a contiguous range of values. The previous implementation would naively map this to [255, 1], which would cause issues with unsigned ranges and unification.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D130501
parent 7fc52d7c8b
commit 938fe9f277
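For illustration only (not part of the commit), a minimal standalone C++ snippet showing the problem described above: each value in the i16 range [255, 257] truncates to 255, 0, or 1, so the truncated endpoints alone do not describe the result set.

#include <cstdint>
#include <cstdio>

int main() {
  // Truncation to i8 keeps only the low 8 bits of each 16-bit value.
  for (uint16_t v = 255; v <= 257; ++v)
    std::printf("%u -> %u\n", unsigned(v), unsigned(uint8_t(v)));
  // Prints: 255 -> 255, 256 -> 0, 257 -> 1 -- not a contiguous i8 range.
  return 0;
}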
@@ -503,10 +503,37 @@ void arith::ExtSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 static ConstantIntRanges truncIRange(const ConstantIntRanges &range,
                                      Type destType) {
   unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
-  APInt umin = range.umin().trunc(destWidth);
-  APInt umax = range.umax().trunc(destWidth);
-  APInt smin = range.smin().trunc(destWidth);
-  APInt smax = range.smax().trunc(destWidth);
+  // If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb],
+  // the range of the resulting value is not contiguous and includes 0.
+  // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2],
+  // but you can't truncate [255, 257] similarly.
+  bool hasUnsignedRollover =
+      range.umin().lshr(destWidth) != range.umax().lshr(destWidth);
+  APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth)
+                                   : range.umin().trunc(destWidth);
+  APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth)
+                                   : range.umax().trunc(destWidth);
+
+  // Signed post-truncation rollover will not occur when either:
+  // - The high parts of the min and max, plus the sign bit, are the same
+  // - The high halves + sign bit of the min and max are either all 1s or all 0s
+  //   and you won't create a [positive, negative] range by truncating.
+  // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8
+  // but not [255, 257]_i16 to a range of i8s. You can also truncate
+  // [-256, -256]_i16 to [0, 0]_i8, but not [-257, -255]_i16.
+  // You can't truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e)
+  // would truncate to 0x7e, which is greater than 0.
+  APInt sminHighPart = range.smin().ashr(destWidth - 1);
+  APInt smaxHighPart = range.smax().ashr(destWidth - 1);
+  bool hasSignedOverflow =
+      (sminHighPart != smaxHighPart) &&
+      !(sminHighPart.isAllOnes() &&
+        (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) &&
+      !(sminHighPart.isZero() && smaxHighPart.isZero());
+  APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth)
+                                 : range.smin().trunc(destWidth);
+  APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth)
+                                 : range.smax().trunc(destWidth);
   return {umin, umax, smin, smax};
 }
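A rough sketch of the unsigned half of the new logic (illustration only, not the committed code): the widths are fixed to i16 -> i8, plain integers stand in for APInt, and the helper name truncUnsigned is made up. If the bits that truncation discards differ between the two bounds, the result wraps and must be widened to the full unsigned range.

#include <cstdint>
#include <cstdio>

// Simplified i16 -> i8 version of the hasUnsignedRollover check above.
static void truncUnsigned(uint16_t umin, uint16_t umax) {
  bool hasUnsignedRollover = (umin >> 8) != (umax >> 8);
  uint8_t lo = hasUnsignedRollover ? 0 : uint8_t(umin);
  uint8_t hi = hasUnsignedRollover ? UINT8_MAX : uint8_t(umax);
  std::printf("[%u, %u] -> [%u, %u]\n", unsigned(umin), unsigned(umax),
              unsigned(lo), unsigned(hi));
}

int main() {
  truncUnsigned(255, 257); // discarded high bytes differ: widened to [0, 255]
  truncUnsigned(256, 258); // same high byte: truncates exactly to [0, 2]
  return 0;
}

The signed check in the diff follows the same idea, comparing everything at and above the destination's sign bit instead of the bits above the destination width.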
@@ -463,14 +463,15 @@ func.func @trunci(%arg0 : i32) -> i1 {
   %c-14_i16 = arith.constant -14 : i16
   %ci16_smin = arith.constant 0xffff8000 : i32
   %0 = arith.minsi %arg0, %c-14_i32 : i32
-  %1 = arith.trunci %0 : i32 to i16
-  %2 = arith.cmpi sle, %1, %c-14_i16 : i16
-  %3 = arith.extsi %1 : i16 to i32
-  %4 = arith.cmpi sle, %3, %c-14_i32 : i32
-  %5 = arith.cmpi sge, %3, %ci16_smin : i32
-  %6 = arith.andi %2, %4 : i1
-  %7 = arith.andi %6, %5 : i1
-  func.return %7 : i1
+  %1 = arith.maxsi %0, %ci16_smin : i32
+  %2 = arith.trunci %1 : i32 to i16
+  %3 = arith.cmpi sle, %2, %c-14_i16 : i16
+  %4 = arith.extsi %2 : i16 to i32
+  %5 = arith.cmpi sle, %4, %c-14_i32 : i32
+  %6 = arith.cmpi sge, %4, %ci16_smin : i32
+  %7 = arith.andi %3, %5 : i1
+  %8 = arith.andi %7, %6 : i1
+  func.return %8 : i1
 }

 // CHECK-LABEL: func @index_cast
@@ -645,3 +646,69 @@ func.func @loop_bound_not_inferred_with_branch(%arg0 : index, %arg1 : i1) -> i1
   func.return %8 : i1
 }

+// Test for a bug where the naive implementation of truncation led to the cast
+// value being set to [0, 0].
+// CHECK-LABEL: func.func @truncation_spillover
+// CHECK: %[[unreplaced:.*]] = arith.index_cast
+// CHECK: memref.store %[[unreplaced]]
+func.func @truncation_spillover(%arg0 : memref<?xi32>) -> index {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c49 = arith.constant 49 : index
+  %0 = scf.for %arg1 = %c0 to %c2 step %c1 iter_args(%arg2 = %c0) -> index {
+    %1 = arith.divsi %arg2, %c49 : index
+    %2 = arith.index_cast %1 : index to i32
+    memref.store %2, %arg0[%c0] : memref<?xi32>
+    %3 = arith.addi %arg2, %arg1 : index
+    scf.yield %3 : index
+  }
+  func.return %0 : index
+}
+
+// CHECK-LABEL: func.func @trunc_catches_overflow
+// CHECK: %[[sge:.*]] = arith.cmpi sge
+// CHECK: return %[[sge]]
+func.func @trunc_catches_overflow(%arg0 : i16) -> i1 {
+  %c0_i16 = arith.constant 0 : i16
+  %c130_i16 = arith.constant 130 : i16
+  %c0_i8 = arith.constant 0 : i8
+  %0 = arith.maxui %arg0, %c0_i16 : i16
+  %1 = arith.minui %0, %c130_i16 : i16
+  %2 = arith.trunci %1 : i16 to i8
+  %3 = arith.cmpi sge, %2, %c0_i8 : i8
+  %4 = arith.cmpi uge, %2, %c0_i8 : i8
+  %5 = arith.andi %3, %4 : i1
+  func.return %5 : i1
+}
+
+// CHECK-LABEL: func.func @trunc_respects_same_high_half
+// CHECK: %[[false:.*]] = arith.constant false
+// CHECK: return %[[false]]
+func.func @trunc_respects_same_high_half(%arg0 : i16) -> i1 {
+  %c256_i16 = arith.constant 256 : i16
+  %c257_i16 = arith.constant 257 : i16
+  %c2_i8 = arith.constant 2 : i8
+  %0 = arith.maxui %arg0, %c256_i16 : i16
+  %1 = arith.minui %0, %c257_i16 : i16
+  %2 = arith.trunci %1 : i16 to i8
+  %3 = arith.cmpi sge, %2, %c2_i8 : i8
+  func.return %3 : i1
+}
+
+// CHECK-LABEL: func.func @trunc_handles_small_signed_ranges
+// CHECK: %[[true:.*]] = arith.constant true
+// CHECK: return %[[true]]
+func.func @trunc_handles_small_signed_ranges(%arg0 : i16) -> i1 {
+  %c-2_i16 = arith.constant -2 : i16
+  %c2_i16 = arith.constant 2 : i16
+  %c-2_i8 = arith.constant -2 : i8
+  %c2_i8 = arith.constant 2 : i8
+  %0 = arith.maxsi %arg0, %c-2_i16 : i16
+  %1 = arith.minsi %0, %c2_i16 : i16
+  %2 = arith.trunci %1 : i16 to i8
+  %3 = arith.cmpi sge, %2, %c-2_i8 : i8
+  %4 = arith.cmpi sle, %2, %c2_i8 : i8
+  %5 = arith.andi %3, %4 : i1
+  func.return %5 : i1
+}
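As a cross-check of the expectations in the three new trunc_* tests, a similarly simplified i16 -> i8 sketch of the signed-rollover condition (illustration only: plain integers instead of APInt, a made-up helper name, and it assumes arithmetic right shift and wrapping narrowing conversions, as on mainstream compilers).

#include <cstdint>
#include <cstdio>

// Simplified i16 -> i8 version of the hasSignedOverflow check: shift away
// everything below the destination sign bit and compare what is left.
static void truncSigned(int16_t smin, int16_t smax) {
  int16_t lo = int16_t(smin >> 7), hi = int16_t(smax >> 7);
  bool overflow = (lo != hi) && !(lo == -1 && (hi == -1 || hi == 0)) &&
                  !(lo == 0 && hi == 0);
  int8_t outMin = overflow ? INT8_MIN : int8_t(smin);
  int8_t outMax = overflow ? INT8_MAX : int8_t(smax);
  std::printf("[%d, %d]_i16 -> [%d, %d]_i8\n", smin, smax, int(outMin),
              int(outMax));
}

int main() {
  truncSigned(0, 130);   // rollover: widens to [-128, 127], so the "sge 0"
                         // compare in @trunc_catches_overflow cannot fold
  truncSigned(256, 257); // same high half: exact [0, 1], so "sge 2" in
                         // @trunc_respects_same_high_half folds to false
  truncSigned(-2, 2);    // all-ones/zero high halves: exact [-2, 2], so both
                         // compares in @trunc_handles_small_signed_ranges fold to true
  return 0;
}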