[MLIR][Presburger] introduce SlowMPInt, an auto-resizing APInt for fully correct signed integer computations

The Presburger library currently uses int64_t throughout for its integers.
This runs the risk of silently producing incorrect results when overflows occur.
Fixing this issue requires some sort of multiprecision integer
that transparently supports aribtrary arithmetic computations.

The class SlowMPInt provides this functionality, and is intended to be used
as the slow path fallback for a more optimized upcoming class, MPInt, that optimizes
for the Presburger library's workloads.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D123758
This commit is contained in:
Arjun P 2022-06-22 18:29:46 +02:00
parent cff4f04e2e
commit 628a2c14e3
5 changed files with 527 additions and 0 deletions

View File

@ -0,0 +1,135 @@
//===- SlowMPInt.h - MLIR SlowMPInt Class -----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is a simple class to represent arbitrary precision signed integers.
// Unlike APInt, one does not have to specify a fixed maximum size, and the
// integer can take on any arbitrary values.
//
// This class is to be used as a fallback slow path for the MPInt class, and
// is not intended to be used directly.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_ANALYSIS_PRESBURGER_SLOWMPINT_H
#define MLIR_ANALYSIS_PRESBURGER_SLOWMPINT_H
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/Support/raw_ostream.h"
namespace mlir {
namespace presburger {
namespace detail {
/// A simple class providing multi-precision arithmetic. Internally, it stores
/// an APInt, whose width is doubled whenever an overflow occurs at a certain
/// width. The default constructor sets the initial width to 64. SlowMPInt is
/// primarily intended to be used as a slow fallback path for the upcoming MPInt
/// class.
class SlowMPInt {
private:
llvm::APInt val;
public:
explicit SlowMPInt(int64_t val);
SlowMPInt();
explicit SlowMPInt(const llvm::APInt &val);
SlowMPInt &operator=(int64_t val);
explicit operator int64_t() const;
SlowMPInt operator-() const;
bool operator==(const SlowMPInt &o) const;
bool operator!=(const SlowMPInt &o) const;
bool operator>(const SlowMPInt &o) const;
bool operator<(const SlowMPInt &o) const;
bool operator<=(const SlowMPInt &o) const;
bool operator>=(const SlowMPInt &o) const;
SlowMPInt operator+(const SlowMPInt &o) const;
SlowMPInt operator-(const SlowMPInt &o) const;
SlowMPInt operator*(const SlowMPInt &o) const;
SlowMPInt operator/(const SlowMPInt &o) const;
SlowMPInt operator%(const SlowMPInt &o) const;
SlowMPInt &operator+=(const SlowMPInt &o);
SlowMPInt &operator-=(const SlowMPInt &o);
SlowMPInt &operator*=(const SlowMPInt &o);
SlowMPInt &operator/=(const SlowMPInt &o);
SlowMPInt &operator%=(const SlowMPInt &o);
SlowMPInt &operator++();
SlowMPInt &operator--();
friend SlowMPInt abs(const SlowMPInt &x);
friend SlowMPInt ceilDiv(const SlowMPInt &lhs, const SlowMPInt &rhs);
friend SlowMPInt floorDiv(const SlowMPInt &lhs, const SlowMPInt &rhs);
friend SlowMPInt gcd(const SlowMPInt &a, const SlowMPInt &b);
/// Overload to compute a hash_code for a SlowMPInt value.
friend llvm::hash_code hash_value(const SlowMPInt &x); // NOLINT
void print(llvm::raw_ostream &os) const;
void dump() const;
unsigned getBitWidth() const { return val.getBitWidth(); }
};
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const SlowMPInt &x);
/// Returns the remainder of dividing LHS by RHS.
///
/// The RHS is always expected to be positive, and the result
/// is always non-negative.
SlowMPInt mod(const SlowMPInt &lhs, const SlowMPInt &rhs);
/// Returns the least common multiple of 'a' and 'b'.
SlowMPInt lcm(const SlowMPInt &a, const SlowMPInt &b);
/// Redeclarations of friend declarations above to
/// make it discoverable by lookups.
SlowMPInt abs(const SlowMPInt &x);
SlowMPInt ceilDiv(const SlowMPInt &lhs, const SlowMPInt &rhs);
SlowMPInt floorDiv(const SlowMPInt &lhs, const SlowMPInt &rhs);
SlowMPInt gcd(const SlowMPInt &a, const SlowMPInt &b);
llvm::hash_code hash_value(const SlowMPInt &x); // NOLINT
/// ---------------------------------------------------------------------------
/// Convenience operator overloads for int64_t.
/// ---------------------------------------------------------------------------
SlowMPInt &operator+=(SlowMPInt &a, int64_t b);
SlowMPInt &operator-=(SlowMPInt &a, int64_t b);
SlowMPInt &operator*=(SlowMPInt &a, int64_t b);
SlowMPInt &operator/=(SlowMPInt &a, int64_t b);
SlowMPInt &operator%=(SlowMPInt &a, int64_t b);
bool operator==(const SlowMPInt &a, int64_t b);
bool operator!=(const SlowMPInt &a, int64_t b);
bool operator>(const SlowMPInt &a, int64_t b);
bool operator<(const SlowMPInt &a, int64_t b);
bool operator<=(const SlowMPInt &a, int64_t b);
bool operator>=(const SlowMPInt &a, int64_t b);
SlowMPInt operator+(const SlowMPInt &a, int64_t b);
SlowMPInt operator-(const SlowMPInt &a, int64_t b);
SlowMPInt operator*(const SlowMPInt &a, int64_t b);
SlowMPInt operator/(const SlowMPInt &a, int64_t b);
SlowMPInt operator%(const SlowMPInt &a, int64_t b);
bool operator==(int64_t a, const SlowMPInt &b);
bool operator!=(int64_t a, const SlowMPInt &b);
bool operator>(int64_t a, const SlowMPInt &b);
bool operator<(int64_t a, const SlowMPInt &b);
bool operator<=(int64_t a, const SlowMPInt &b);
bool operator>=(int64_t a, const SlowMPInt &b);
SlowMPInt operator+(int64_t a, const SlowMPInt &b);
SlowMPInt operator-(int64_t a, const SlowMPInt &b);
SlowMPInt operator*(int64_t a, const SlowMPInt &b);
SlowMPInt operator/(int64_t a, const SlowMPInt &b);
SlowMPInt operator%(int64_t a, const SlowMPInt &b);
} // namespace detail
} // namespace presburger
} // namespace mlir
#endif // MLIR_ANALYSIS_PRESBURGER_SLOWMPINT_H

View File

@ -6,6 +6,7 @@ add_mlir_library(MLIRPresburger
PresburgerSpace.cpp
PWMAFunction.cpp
Simplex.cpp
SlowMPInt.cpp
Utils.cpp
LINK_LIBS PUBLIC

View File

@ -0,0 +1,278 @@
//===- SlowMPInt.cpp - MLIR SlowMPInt Class -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/Presburger/SlowMPInt.h"
#include "llvm/Support/MathExtras.h"
using namespace mlir;
using namespace presburger;
using namespace detail;
SlowMPInt::SlowMPInt(int64_t val) : val(64, val, /*isSigned=*/true) {}
SlowMPInt::SlowMPInt() : SlowMPInt(0) {}
SlowMPInt::SlowMPInt(const llvm::APInt &val) : val(val) {}
SlowMPInt &SlowMPInt::operator=(int64_t val) { return *this = SlowMPInt(val); }
SlowMPInt::operator int64_t() const { return val.getSExtValue(); }
llvm::hash_code detail::hash_value(const SlowMPInt &x) {
return hash_value(x.val);
}
/// ---------------------------------------------------------------------------
/// Printing.
/// ---------------------------------------------------------------------------
void SlowMPInt::print(llvm::raw_ostream &os) const { os << val; }
void SlowMPInt::dump() const { print(llvm::errs()); }
llvm::raw_ostream &detail::operator<<(llvm::raw_ostream &os,
const SlowMPInt &x) {
x.print(os);
return os;
}
/// ---------------------------------------------------------------------------
/// Convenience operator overloads for int64_t.
/// ---------------------------------------------------------------------------
SlowMPInt &detail::operator+=(SlowMPInt &a, int64_t b) {
return a += SlowMPInt(b);
}
SlowMPInt &detail::operator-=(SlowMPInt &a, int64_t b) {
return a -= SlowMPInt(b);
}
SlowMPInt &detail::operator*=(SlowMPInt &a, int64_t b) {
return a *= SlowMPInt(b);
}
SlowMPInt &detail::operator/=(SlowMPInt &a, int64_t b) {
return a /= SlowMPInt(b);
}
SlowMPInt &detail::operator%=(SlowMPInt &a, int64_t b) {
return a %= SlowMPInt(b);
}
bool detail::operator==(const SlowMPInt &a, int64_t b) {
return a == SlowMPInt(b);
}
bool detail::operator!=(const SlowMPInt &a, int64_t b) {
return a != SlowMPInt(b);
}
bool detail::operator>(const SlowMPInt &a, int64_t b) {
return a > SlowMPInt(b);
}
bool detail::operator<(const SlowMPInt &a, int64_t b) {
return a < SlowMPInt(b);
}
bool detail::operator<=(const SlowMPInt &a, int64_t b) {
return a <= SlowMPInt(b);
}
bool detail::operator>=(const SlowMPInt &a, int64_t b) {
return a >= SlowMPInt(b);
}
SlowMPInt detail::operator+(const SlowMPInt &a, int64_t b) {
return a + SlowMPInt(b);
}
SlowMPInt detail::operator-(const SlowMPInt &a, int64_t b) {
return a - SlowMPInt(b);
}
SlowMPInt detail::operator*(const SlowMPInt &a, int64_t b) {
return a * SlowMPInt(b);
}
SlowMPInt detail::operator/(const SlowMPInt &a, int64_t b) {
return a / SlowMPInt(b);
}
SlowMPInt detail::operator%(const SlowMPInt &a, int64_t b) {
return a % SlowMPInt(b);
}
bool detail::operator==(int64_t a, const SlowMPInt &b) {
return SlowMPInt(a) == b;
}
bool detail::operator!=(int64_t a, const SlowMPInt &b) {
return SlowMPInt(a) != b;
}
bool detail::operator>(int64_t a, const SlowMPInt &b) {
return SlowMPInt(a) > b;
}
bool detail::operator<(int64_t a, const SlowMPInt &b) {
return SlowMPInt(a) < b;
}
bool detail::operator<=(int64_t a, const SlowMPInt &b) {
return SlowMPInt(a) <= b;
}
bool detail::operator>=(int64_t a, const SlowMPInt &b) {
return SlowMPInt(a) >= b;
}
SlowMPInt detail::operator+(int64_t a, const SlowMPInt &b) {
return SlowMPInt(a) + b;
}
SlowMPInt detail::operator-(int64_t a, const SlowMPInt &b) {
return SlowMPInt(a) - b;
}
SlowMPInt detail::operator*(int64_t a, const SlowMPInt &b) {
return SlowMPInt(a) * b;
}
SlowMPInt detail::operator/(int64_t a, const SlowMPInt &b) {
return SlowMPInt(a) / b;
}
SlowMPInt detail::operator%(int64_t a, const SlowMPInt &b) {
return SlowMPInt(a) % b;
}
static unsigned getMaxWidth(const APInt &a, const APInt &b) {
return std::max(a.getBitWidth(), b.getBitWidth());
}
/// ---------------------------------------------------------------------------
/// Comparison operators.
/// ---------------------------------------------------------------------------
// TODO: consider instead making APInt::compare available and using that.
bool SlowMPInt::operator==(const SlowMPInt &o) const {
unsigned width = getMaxWidth(val, o.val);
return val.sext(width) == o.val.sext(width);
}
bool SlowMPInt::operator!=(const SlowMPInt &o) const {
unsigned width = getMaxWidth(val, o.val);
return val.sext(width) != o.val.sext(width);
}
bool SlowMPInt::operator>(const SlowMPInt &o) const {
unsigned width = getMaxWidth(val, o.val);
return val.sext(width).sgt(o.val.sext(width));
}
bool SlowMPInt::operator<(const SlowMPInt &o) const {
unsigned width = getMaxWidth(val, o.val);
return val.sext(width).slt(o.val.sext(width));
}
bool SlowMPInt::operator<=(const SlowMPInt &o) const {
unsigned width = getMaxWidth(val, o.val);
return val.sext(width).sle(o.val.sext(width));
}
bool SlowMPInt::operator>=(const SlowMPInt &o) const {
unsigned width = getMaxWidth(val, o.val);
return val.sext(width).sge(o.val.sext(width));
}
/// ---------------------------------------------------------------------------
/// Arithmetic operators.
/// ---------------------------------------------------------------------------
/// Bring a and b to have the same width and then call op(a, b, overflow).
/// If the overflow bit becomes set, resize a and b to double the width and
/// call op(a, b, overflow), returning its result. The operation with double
/// widths should not also overflow.
APInt runOpWithExpandOnOverflow(
const APInt &a, const APInt &b,
llvm::function_ref<APInt(const APInt &, const APInt &, bool &overflow)>
op) {
bool overflow;
unsigned width = getMaxWidth(a, b);
APInt ret = op(a.sext(width), b.sext(width), overflow);
if (!overflow)
return ret;
width *= 2;
ret = op(a.sext(width), b.sext(width), overflow);
assert(!overflow && "double width should be sufficient to avoid overflow!");
return ret;
}
SlowMPInt SlowMPInt::operator+(const SlowMPInt &o) const {
return SlowMPInt(
runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::sadd_ov)));
}
SlowMPInt SlowMPInt::operator-(const SlowMPInt &o) const {
return SlowMPInt(
runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::ssub_ov)));
}
SlowMPInt SlowMPInt::operator*(const SlowMPInt &o) const {
return SlowMPInt(
runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::smul_ov)));
}
SlowMPInt SlowMPInt::operator/(const SlowMPInt &o) const {
return SlowMPInt(
runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::sdiv_ov)));
}
SlowMPInt detail::abs(const SlowMPInt &x) { return x >= 0 ? x : -x; }
SlowMPInt detail::ceilDiv(const SlowMPInt &lhs, const SlowMPInt &rhs) {
if (rhs == -1)
return -lhs;
return SlowMPInt(
llvm::APIntOps::RoundingSDiv(lhs.val, rhs.val, APInt::Rounding::UP));
}
SlowMPInt detail::floorDiv(const SlowMPInt &lhs, const SlowMPInt &rhs) {
if (rhs == -1)
return -lhs;
return SlowMPInt(
llvm::APIntOps::RoundingSDiv(lhs.val, rhs.val, APInt::Rounding::DOWN));
}
// The RHS is always expected to be positive, and the result
/// is always non-negative.
SlowMPInt detail::mod(const SlowMPInt &lhs, const SlowMPInt &rhs) {
assert(rhs >= 1 && "mod is only supported for positive divisors!");
return lhs % rhs < 0 ? lhs % rhs + rhs : lhs % rhs;
}
SlowMPInt detail::gcd(const SlowMPInt &a, const SlowMPInt &b) {
return SlowMPInt(
llvm::APIntOps::GreatestCommonDivisor(a.val.abs(), b.val.abs()));
}
/// Returns the least common multiple of 'a' and 'b'.
SlowMPInt detail::lcm(const SlowMPInt &a, const SlowMPInt &b) {
SlowMPInt x = abs(a);
SlowMPInt y = abs(b);
return (x * y) / gcd(x, y);
}
/// This operation cannot overflow.
SlowMPInt SlowMPInt::operator%(const SlowMPInt &o) const {
unsigned width = std::max(val.getBitWidth(), o.val.getBitWidth());
return SlowMPInt(val.sext(width).srem(o.val.sext(width)));
}
SlowMPInt SlowMPInt::operator-() const {
if (val.isMinSignedValue()) {
/// Overflow only occurs when the value is the minimum possible value.
APInt ret = val.sext(2 * val.getBitWidth());
return SlowMPInt(-ret);
}
return SlowMPInt(-val);
}
/// ---------------------------------------------------------------------------
/// Assignment operators, preincrement, predecrement.
/// ---------------------------------------------------------------------------
SlowMPInt &SlowMPInt::operator+=(const SlowMPInt &o) {
*this = *this + o;
return *this;
}
SlowMPInt &SlowMPInt::operator-=(const SlowMPInt &o) {
*this = *this - o;
return *this;
}
SlowMPInt &SlowMPInt::operator*=(const SlowMPInt &o) {
*this = *this * o;
return *this;
}
SlowMPInt &SlowMPInt::operator/=(const SlowMPInt &o) {
*this = *this / o;
return *this;
}
SlowMPInt &SlowMPInt::operator%=(const SlowMPInt &o) {
*this = *this % o;
return *this;
}
SlowMPInt &SlowMPInt::operator++() {
*this += 1;
return *this;
}
SlowMPInt &SlowMPInt::operator--() {
*this -= 1;
return *this;
}

View File

@ -7,6 +7,7 @@ add_mlir_unittest(MLIRPresburgerTests
PresburgerSpaceTest.cpp
PWMAFunctionTest.cpp
SimplexTest.cpp
SlowMPIntTest.cpp
../../Dialect/Affine/Analysis/AffineStructuresParser.cpp
)

View File

@ -0,0 +1,112 @@
//===- SlowMPIntTest.cpp - Tests for SlowMPInt ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/Presburger/SlowMPInt.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
using namespace mlir;
using namespace presburger;
using detail::SlowMPInt;
TEST(SlowMPIntTest, ops) {
SlowMPInt two(2), five(5), seven(7), ten(10);
EXPECT_EQ(five + five, ten);
EXPECT_EQ(five * five, 2 * ten + five);
EXPECT_EQ(five * five, 3 * ten - five);
EXPECT_EQ(five * two, ten);
EXPECT_EQ(five / two, two);
EXPECT_EQ(five % two, two / two);
EXPECT_EQ(-ten % seven, -10 % 7);
EXPECT_EQ(ten % -seven, 10 % -7);
EXPECT_EQ(-ten % -seven, -10 % -7);
EXPECT_EQ(ten % seven, 10 % 7);
EXPECT_EQ(-ten / seven, -10 / 7);
EXPECT_EQ(ten / -seven, 10 / -7);
EXPECT_EQ(-ten / -seven, -10 / -7);
EXPECT_EQ(ten / seven, 10 / 7);
SlowMPInt x = ten;
x += five;
EXPECT_EQ(x, 15);
x *= two;
EXPECT_EQ(x, 30);
x /= seven;
EXPECT_EQ(x, 4);
x -= two * 10;
EXPECT_EQ(x, -16);
x *= 2 * two;
EXPECT_EQ(x, -64);
x /= two / -2;
EXPECT_EQ(x, 64);
EXPECT_LE(ten, ten);
EXPECT_GE(ten, ten);
EXPECT_EQ(ten, ten);
EXPECT_FALSE(ten != ten);
EXPECT_FALSE(ten < ten);
EXPECT_FALSE(ten > ten);
EXPECT_LT(five, ten);
EXPECT_GT(ten, five);
}
TEST(SlowMPIntTest, ops64Overloads) {
SlowMPInt two(2), five(5), seven(7), ten(10);
EXPECT_EQ(five + 5, ten);
EXPECT_EQ(five + 5, 5 + five);
EXPECT_EQ(five * 5, 2 * ten + 5);
EXPECT_EQ(five * 5, 3 * ten - 5);
EXPECT_EQ(five * two, ten);
EXPECT_EQ(5 / two, 2);
EXPECT_EQ(five / 2, 2);
EXPECT_EQ(2 % two, 0);
EXPECT_EQ(2 - two, 0);
EXPECT_EQ(2 % two, two % 2);
SlowMPInt x = ten;
x += 5;
EXPECT_EQ(x, 15);
x *= 2;
EXPECT_EQ(x, 30);
x /= 7;
EXPECT_EQ(x, 4);
x -= 20;
EXPECT_EQ(x, -16);
x *= 4;
EXPECT_EQ(x, -64);
x /= -1;
EXPECT_EQ(x, 64);
EXPECT_LE(ten, 10);
EXPECT_GE(ten, 10);
EXPECT_EQ(ten, 10);
EXPECT_FALSE(ten != 10);
EXPECT_FALSE(ten < 10);
EXPECT_FALSE(ten > 10);
EXPECT_LT(five, 10);
EXPECT_GT(ten, 5);
EXPECT_LE(10, ten);
EXPECT_GE(10, ten);
EXPECT_EQ(10, ten);
EXPECT_FALSE(10 != ten);
EXPECT_FALSE(10 < ten);
EXPECT_FALSE(10 > ten);
EXPECT_LT(5, ten);
EXPECT_GT(10, five);
}
TEST(SlowMPIntTest, overflows) {
SlowMPInt x(1ll << 60);
EXPECT_EQ((x * x - x * x * x * x) / (x * x * x), 1 - (1ll << 60));
SlowMPInt y(1ll << 62);
EXPECT_EQ((y + y + y + y + y + y) / y, 6);
EXPECT_EQ(-(2 * (-y)), 2 * y); // -(-2^63) overflow.
}