[MLIR] Add division normalization by GCD in `getDivRepr` fn.

This commits adds division normalization in  the `getDivRepr` function which extracts
 the gcd from the dividend and divisor and normalizes them.

Signed-off-by: Prashant Kumar <pk5561@gmail.com>

Reviewed By: bondhugula

Differential Revision: https://reviews.llvm.org/D115595
This commit is contained in:
Prashant Kumar 2022-01-06 16:12:41 +05:30 committed by Groverkss
parent 0fa174398b
commit df29318e66
2 changed files with 109 additions and 6 deletions

View File

@ -13,9 +13,38 @@
#include "mlir/Analysis/Presburger/Utils.h"
#include "mlir/Analysis/Presburger/IntegerPolyhedron.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/MathExtras.h"
using namespace mlir;
/// Normalize a division's `dividend` and the `divisor` by their GCD. For
/// example: if the dividend and divisor are [2,0,4] and 4 respectively,
/// they get normalized to [1,0,2] and 2.
static void normalizeDivisionByGCD(SmallVectorImpl<int64_t> &dividend,
unsigned &divisor) {
if (divisor == 0 || dividend.empty())
return;
int64_t gcd = llvm::greatestCommonDivisor(dividend.front(), int64_t(divisor));
// The reason for ignoring the constant term is as follows.
// For a division:
// floor((a + m.f(x))/(m.d))
// It can be replaced by:
// floor((floor(a/m) + f(x))/d)
// Since `{a/m}/d` in the dividend satisfies 0 <= {a/m}/d < 1/d, it will not
// influence the result of the floor division and thus, can be ignored.
for (size_t i = 1, m = dividend.size() - 1; i < m; i++) {
gcd = llvm::greatestCommonDivisor(dividend[i], gcd);
if (gcd == 1)
return;
}
// Normalize the dividend and the denominator.
std::transform(dividend.begin(), dividend.end(), dividend.begin(),
[gcd](int64_t &n) { return floor(n / gcd); });
divisor /= gcd;
}
/// Check if the pos^th identifier can be represented as a division using upper
/// bound inequality at position `ubIneq` and lower bound inequality at position
/// `lbIneq`.
@ -52,7 +81,8 @@ using namespace mlir;
/// -divisor * id + expr - c >= 0 <-- Upper bound for 'id'
///
/// If successful, `expr` is set to dividend of the division and `divisor` is
/// set to the denominator of the division.
/// set to the denominator of the division. The final division expression is
/// normalized by GCD.
static LogicalResult getDivRepr(const IntegerPolyhedron &cst, unsigned pos,
unsigned ubIneq, unsigned lbIneq,
SmallVector<int64_t, 8> &expr,
@ -101,6 +131,7 @@ static LogicalResult getDivRepr(const IntegerPolyhedron &cst, unsigned pos,
// constant term of `expr`, minus `c`. From this,
// constant term of `expr` = constant term of upper bound + `c`.
expr.back() = cst.atIneq(ubIneq, cst.getNumCols() - 1) + c;
normalizeDivisionByGCD(expr, divisor);
return success();
}

View File

@ -592,12 +592,12 @@ TEST(FlatAffineConstraintsTest, computeLocalReprConstantFloorDiv) {
fac.addInequality({1, 2, -8, 1, 10});
fac.addEquality({1, 2, -4, 1, 10});
fac.addLocalFloorDiv({0, 0, 0, 0, 10}, 30);
fac.addLocalFloorDiv({0, 0, 0, 0, 0, 99}, 101);
fac.addLocalFloorDiv({0, 0, 0, 0, 100}, 30);
fac.addLocalFloorDiv({0, 0, 0, 0, 0, 206}, 101);
std::vector<SmallVector<int64_t, 8>> divisions = {{0, 0, 0, 0, 0, 0, 10},
{0, 0, 0, 0, 0, 0, 99}};
SmallVector<unsigned, 8> denoms = {30, 101};
std::vector<SmallVector<int64_t, 8>> divisions = {{0, 0, 0, 0, 0, 0, 3},
{0, 0, 0, 0, 0, 0, 2}};
SmallVector<unsigned, 8> denoms = {1, 1};
// Check if floordivs with constant numerator can be computed.
checkDivisionRepresentation(fac, divisions, denoms);
@ -750,6 +750,31 @@ TEST(FlatAffineConstraintsTest, mergeDivisionsSimple) {
EXPECT_EQ(fac1.getNumLocalIds(), 2u);
EXPECT_EQ(fac2.getNumLocalIds(), 2u);
}
{
// Division Normalization test.
// (x) : (exists z, y = [x / 2] : x = 3y and x + z + 1 >= 0).
FlatAffineConstraints fac1(1, 0, 1);
// This division would be normalized.
fac1.addLocalFloorDiv({3, 0, 0}, 6); // y = [3x / 6] -> [x/2].
fac1.addEquality({1, 0, -3, 0}); // x = 3z.
fac1.addInequality({1, 1, 0, 1}); // x + y + 1 >= 0.
// (x) : (exists y = [x / 2], z : x = 5y).
FlatAffineConstraints fac2(1);
fac2.addLocalFloorDiv({1, 0}, 2); // y = [x / 2].
fac2.addEquality({1, -5, 0}); // x = 5y.
fac2.appendLocalId(); // Add local id z.
fac1.mergeLocalIds(fac2);
// Local space should be same.
EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds());
// One division should be matched + 2 unmatched local ids.
EXPECT_EQ(fac1.getNumLocalIds(), 3u);
EXPECT_EQ(fac2.getNumLocalIds(), 3u);
}
}
TEST(FlatAffineConstraintsTest, mergeDivisionsNestedDivsions) {
@ -800,6 +825,29 @@ TEST(FlatAffineConstraintsTest, mergeDivisionsNestedDivsions) {
EXPECT_EQ(fac1.getNumLocalIds(), 3u);
EXPECT_EQ(fac2.getNumLocalIds(), 3u);
}
{
// (x) : (exists y = [x / 2], z = [x + y / 3]: y + z >= x).
FlatAffineConstraints fac1(1);
fac1.addLocalFloorDiv({2, 0}, 4); // y = [2x / 4] -> [x / 2].
fac1.addLocalFloorDiv({1, 1, 0}, 3); // z = [x + y / 3].
fac1.addInequality({-1, 1, 1, 0}); // y + z >= x.
// (x) : (exists y = [x / 2], z = [x + y / 3]: y + z <= x).
FlatAffineConstraints fac2(1);
fac2.addLocalFloorDiv({1, 0}, 2); // y = [x / 2].
// This division would be normalized.
fac2.addLocalFloorDiv({3, 3, 0}, 9); // z = [3x + 3y / 9] -> [x + y / 3].
fac2.addInequality({1, -1, -1, 0}); // y + z <= x.
fac1.mergeLocalIds(fac2);
// Local space should be same.
EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds());
// 2 divisions should be matched.
EXPECT_EQ(fac1.getNumLocalIds(), 2u);
EXPECT_EQ(fac2.getNumLocalIds(), 2u);
}
}
TEST(FlatAffineConstraintsTest, mergeDivisionsConstants) {
@ -821,6 +869,30 @@ TEST(FlatAffineConstraintsTest, mergeDivisionsConstants) {
// Local space should be same.
EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds());
// 2 divisions should be matched.
EXPECT_EQ(fac1.getNumLocalIds(), 2u);
EXPECT_EQ(fac2.getNumLocalIds(), 2u);
}
{
// (x) : (exists y = [x + 1 / 3], z = [x + 2 / 3]: y + z >= x).
FlatAffineConstraints fac1(1);
fac1.addLocalFloorDiv({1, 1}, 2); // y = [x + 1 / 2].
// Normalization test.
fac1.addLocalFloorDiv({3, 0, 6}, 9); // z = [3x + 6 / 9] -> [x + 2 / 3].
fac1.addInequality({-1, 1, 1, 0}); // y + z >= x.
// (x) : (exists y = [x + 1 / 3], z = [x + 2 / 3]: y + z <= x).
FlatAffineConstraints fac2(1);
// Normalization test.
fac2.addLocalFloorDiv({2, 2}, 4); // y = [2x + 2 / 4] -> [x + 1 / 2].
fac2.addLocalFloorDiv({1, 0, 2}, 3); // z = [x + 2 / 3].
fac2.addInequality({1, -1, -1, 0}); // y + z <= x.
fac1.mergeLocalIds(fac2);
// Local space should be same.
EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds());
// 2 divisions should be matched.
EXPECT_EQ(fac1.getNumLocalIds(), 2u);
EXPECT_EQ(fac2.getNumLocalIds(), 2u);