[MLIR][Presburger] Support lexicographic max/min union of two PWMAFunction

This patch implements a lexicographic max/min union of two PWMAFunctions.

The lexmax/lexmin union of two functions is defined as a function defined on
the union of the input domains of both functions, such that when only one of the
functions are defined, it outputs the same as that function, and if both are
defined, it outputs the lexmax/lexmin of the two outputs. On points where
neither function is defined, the union is not defined either.

Reviewed By: arjunp

Differential Revision: https://reviews.llvm.org/D128829
This commit is contained in:
Groverkss 2022-07-06 16:08:15 +01:00
parent b5b6d3a41b
commit a18f843f07
5 changed files with 534 additions and 2 deletions

View File

@ -118,7 +118,7 @@ public:
/// intersection with no simplification of any sort attempted.
void append(const IntegerRelation &other);
/// Return the intersection of the two sets.
/// Return the intersection of the two relations.
/// If there are locals, they will be merged.
IntegerRelation intersect(IntegerRelation other) const;
@ -608,6 +608,10 @@ public:
/// `PresburgerSet`, `unboundedDomain`.
SymbolicLexMin findSymbolicIntegerLexMin() const;
/// Return the set difference of this set and the given set, i.e.,
/// return `this \ set`.
PresburgerRelation subtract(const PresburgerRelation &set) const;
void print(raw_ostream &os) const;
void dump() const;
@ -790,6 +794,14 @@ public:
/// column position (i.e., not relative to the kind of variable) of the
/// first added variable.
unsigned insertVar(VarKind kind, unsigned pos, unsigned num = 1) override;
/// Return the intersection of the two relations.
/// If there are locals, they will be merged.
IntegerPolyhedron intersect(const IntegerPolyhedron &other) const;
/// Return the set difference of this set and the given set, i.e.,
/// return `this \ set`.
PresburgerSet subtract(const PresburgerSet &other) const;
};
} // namespace presburger

View File

@ -54,8 +54,15 @@ public:
bool isConsistent() const {
return output.getNumColumns() == domainSet.getNumVars() + 1;
}
const IntegerPolyhedron &getDomain() const { return domainSet; }
/// Get the space of the input domain of this function.
const PresburgerSpace &getDomainSpace() const { return domainSet.getSpace(); }
/// Get the input domain of this function.
const IntegerPolyhedron &getDomain() const { return domainSet; }
/// Get a matrix with each row representing row^th output expression.
const Matrix &getOutputMatrix() const { return output; }
/// Get the `i^th` output expression.
ArrayRef<int64_t> getOutputExpr(unsigned i) const { return output.getRow(i); }
/// Insert `num` variables of the specified kind at position `pos`.
/// Positions are relative to the kind of variable. The coefficient columns
@ -138,6 +145,7 @@ public:
void addPiece(const MultiAffineFunction &piece);
void addPiece(const IntegerPolyhedron &domain, const Matrix &output);
void addPiece(const PresburgerSet &domain, const Matrix &output);
const MultiAffineFunction &getPiece(unsigned i) const { return pieces[i]; }
unsigned getNumPieces() const { return pieces.size(); }
@ -163,10 +171,41 @@ public:
/// TODO: refactor so that this can be accomplished through removeVarRange.
void truncateOutput(unsigned count);
/// Return a function defined on the union of the domains of this and func,
/// such that when only one of the functions is defined, it outputs the same
/// as that function, and if both are defined, it outputs the lexmax/lexmin of
/// the two outputs. On points where neither function is defined, the returned
/// function is not defined either.
///
/// Currently this does not support PWMAFunctions which have pieces containing
/// local variables.
/// TODO: Support local variables in peices.
PWMAFunction unionLexMin(const PWMAFunction &func);
PWMAFunction unionLexMax(const PWMAFunction &func);
void print(raw_ostream &os) const;
void dump() const;
private:
/// Return a function defined on the union of the domains of `this` and
/// `func`, such that when only one of the functions is defined, it outputs
/// the same as that function, and if neither is defined, the returned
/// function is not defined either.
///
/// The provided `tiebreak` function determines which of the two functions'
/// output should be used on inputs where both the functions are defined. More
/// precisely, given two `MultiAffineFunction`s `mafA` and `mafB`, `tiebreak`
/// returns the subset of the intersection of the two functions' domains where
/// the output of `mafA` should be used.
///
/// The PresburgerSet returned by `tiebreak` should be disjoint.
/// TODO: Remove this constraint of returning disjoint set.
PWMAFunction
unionFunction(const PWMAFunction &func,
llvm::function_ref<PresburgerSet(MultiAffineFunction mafA,
MultiAffineFunction mafB)>
tiebreak) const;
PresburgerSpace space;
/// The list of pieces in this piece-wise MultiAffineFunction.

View File

@ -252,6 +252,11 @@ SymbolicLexMin IntegerRelation::findSymbolicIntegerLexMin() const {
return result;
}
PresburgerRelation
IntegerRelation::subtract(const PresburgerRelation &set) const {
return PresburgerRelation(*this).subtract(set);
}
unsigned IntegerRelation::insertVar(VarKind kind, unsigned pos, unsigned num) {
assert(pos <= getNumVarKind(kind));
@ -2284,3 +2289,11 @@ unsigned IntegerPolyhedron::insertVar(VarKind kind, unsigned pos,
"Domain has to be zero in a set");
return IntegerRelation::insertVar(kind, pos, num);
}
IntegerPolyhedron
IntegerPolyhedron::intersect(const IntegerPolyhedron &other) const {
return IntegerPolyhedron(IntegerRelation::intersect(other));
}
PresburgerSet IntegerPolyhedron::subtract(const PresburgerSet &other) const {
return PresburgerSet(IntegerRelation::subtract(other));
}

View File

@ -211,6 +211,11 @@ void PWMAFunction::addPiece(const IntegerPolyhedron &domain,
addPiece(MultiAffineFunction(domain, output));
}
void PWMAFunction::addPiece(const PresburgerSet &domain, const Matrix &output) {
for (const IntegerRelation &newDom : domain.getAllDisjuncts())
addPiece(IntegerPolyhedron(newDom), output);
}
void PWMAFunction::print(raw_ostream &os) const {
os << pieces.size() << " pieces:\n";
for (const MultiAffineFunction &piece : pieces)
@ -218,3 +223,138 @@ void PWMAFunction::print(raw_ostream &os) const {
}
void PWMAFunction::dump() const { print(llvm::errs()); }
PWMAFunction PWMAFunction::unionFunction(
const PWMAFunction &func,
llvm::function_ref<PresburgerSet(MultiAffineFunction maf1,
MultiAffineFunction maf2)>
tiebreak) const {
assert(getNumOutputs() == func.getNumOutputs() &&
"Number of outputs of functions should be same.");
assert(getSpace().isCompatible(func.getSpace()) &&
"Space is not compatible.");
// The algorithm used here is as follows:
// - Add the output of funcB for the part of the domain where both funcA and
// funcB are defined, and `tiebreak` chooses the output of funcB.
// - Add the output of funcA, where funcB is not defined or `tiebreak` chooses
// funcA over funcB.
// - Add the output of funcB, where funcA is not defined.
// Add parts of the common domain where funcB's output is used. Also
// add all the parts where funcA's output is used, both common and non-common.
PWMAFunction result(getSpace(), getNumOutputs());
for (const MultiAffineFunction &funcA : pieces) {
PresburgerSet dom(funcA.getDomain());
for (const MultiAffineFunction &funcB : func.pieces) {
PresburgerSet better = tiebreak(funcB, funcA);
// Add the output of funcB, where it is better than output of funcA.
// The disjuncts in "better" will be disjoint as tiebreak should gurantee
// that.
result.addPiece(better, funcB.getOutputMatrix());
dom = dom.subtract(better);
}
// Add output of funcA, where it is better than funcB, or funcB is not
// defined.
//
// `dom` here is guranteed to be disjoint from already added pieces
// because because the pieces added before are either:
// - Subsets of the domain of other MAFs in `this`, which are guranteed
// to be disjoint from `dom`, or
// - They are one of the pieces added for `funcB`, and we have been
// subtracting all such pieces from `dom`, so `dom` is disjoint from those
// pieces as well.
result.addPiece(dom, funcA.getOutputMatrix());
}
// Add parts of funcB which are not shared with funcA.
PresburgerSet dom = getDomain();
for (const MultiAffineFunction &funcB : func.pieces)
result.addPiece(funcB.getDomain().subtract(dom), funcB.getOutputMatrix());
return result;
}
/// A tiebreak function which breaks ties by comparing the outputs
/// lexicographically. If `lexMin` is true, then the ties are broken by
/// taking the lexicographically smaller output and otherwise, by taking the
/// lexicographically larger output.
template <bool lexMin>
static PresburgerSet tiebreakLex(const MultiAffineFunction &mafA,
const MultiAffineFunction &mafB) {
// TODO: Support local variables here.
assert(mafA.getDomainSpace().isCompatible(mafB.getDomainSpace()) &&
"Domain spaces should be compatible.");
assert(mafA.getNumOutputs() == mafB.getNumOutputs() &&
"Number of outputs of both functions should be same.");
assert(mafA.getDomain().getNumLocalVars() == 0 &&
"Local variables are not supported yet.");
PresburgerSpace compatibleSpace = mafA.getDomain().getSpaceWithoutLocals();
const PresburgerSpace &space = mafA.getDomain().getSpace();
// We first create the set `result`, corresponding to the set where output
// of mafA is lexicographically larger/smaller than mafB. This is done by
// creating a PresburgerSet with the following constraints:
//
// (outA[0] > outB[0]) U
// (outA[0] = outB[0], outA[1] > outA[1]) U
// (outA[0] = outB[0], outA[1] = outA[1], outA[2] > outA[2]) U
// ...
// (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] > outB[n-1])
//
// where `n` is the number of outputs.
// If `lexMin` is set, the complement inequality is used:
//
// (outA[0] < outB[0]) U
// (outA[0] = outB[0], outA[1] < outA[1]) U
// (outA[0] = outB[0], outA[1] = outA[1], outA[2] < outA[2]) U
// ...
// (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] < outB[n-1])
PresburgerSet result = PresburgerSet::getEmpty(compatibleSpace);
IntegerPolyhedron levelSet(/*numReservedInequalities=*/1,
/*numReservedEqualities=*/mafA.getNumOutputs(),
/*numReservedCols=*/space.getNumVars() + 1, space);
for (unsigned level = 0; level < mafA.getNumOutputs(); ++level) {
// Create the expression `outA - outB` for this level.
SmallVector<int64_t, 8> subExpr =
subtract(mafA.getOutputExpr(level), mafB.getOutputExpr(level));
if (lexMin) {
// For lexMin, we add an upper bound of -1:
// outA - outB <= -1
// outA <= outB - 1
// outA < outB
levelSet.addBound(IntegerPolyhedron::BoundType::UB, subExpr, -1);
} else {
// For lexMax, we add a lower bound of 1:
// outA - outB >= 1
// outA > outB + 1
// outA > outB
levelSet.addBound(IntegerPolyhedron::BoundType::LB, subExpr, 1);
}
// Union the set with the result.
result.unionInPlace(levelSet);
// There is only 1 inequality in `levelSet`, so the index is always 0.
levelSet.removeInequality(0);
// Add equality `outA - outB == 0` for this level for next iteration.
levelSet.addEquality(subExpr);
}
// We then intersect `result` with the domain of mafA and mafB, to only
// tiebreak on the domain where both are defined.
result = result.intersect(PresburgerSet(mafA.getDomain()))
.intersect(PresburgerSet(mafB.getDomain()));
return result;
}
PWMAFunction PWMAFunction::unionLexMin(const PWMAFunction &func) {
return unionFunction(func, tiebreakLex</*lexMin=*/true>);
}
PWMAFunction PWMAFunction::unionLexMax(const PWMAFunction &func) {
return unionFunction(func, tiebreakLex</*lexMin=*/false>);
}

View File

@ -189,3 +189,331 @@ TEST(PWMAFunction, eliminateRedundantLocalIdRegressionTest) {
});
EXPECT_TRUE(pwmafA.isEqual(pwmafB));
}
TEST(PWMAFunction, unionLexMaxSimple) {
// func2 is better than func1, but func2's domain is empty.
{
PWMAFunction func1 = parsePWMAF(
/*numInputs=*/1, /*numOutputs=*/1,
{
{"(x) : ()", {{0, 1}}},
});
PWMAFunction func2 = parsePWMAF(
/*numInputs=*/1, /*numOutputs=*/1,
{
{"(x) : (1 == 0)", {{0, 2}}},
});
EXPECT_TRUE(func1.unionLexMax(func2).isEqual(func1));
EXPECT_TRUE(func2.unionLexMax(func1).isEqual(func1));
}
// func2 is better than func1 on a subset of func1.
{
PWMAFunction func1 = parsePWMAF(
/*numInputs=*/1, /*numOutputs=*/1,
{
{"(x) : ()", {{0, 1}}},
});
PWMAFunction func2 = parsePWMAF(
/*numInputs=*/1, /*numOutputs=*/1,
{
{"(x) : (x >= 0, 10 - x >= 0)", {{0, 2}}},
});
PWMAFunction result = parsePWMAF(
/*numInputs=*/1, /*numOutputs=*/1,
{
{"(x) : (-1 - x >= 0)", {{0, 1}}},
{"(x) : (x >= 0, 10 - x >= 0)", {{0, 2}}},
{"(x) : (x - 11 >= 0)", {{0, 1}}},
});
EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
EXPECT_TRUE(func2.unionLexMax(func1).isEqual(result));
}
// func1 and func2 are defined over the whole domain with different outputs.
{
PWMAFunction func1 = parsePWMAF(
/*numInputs=*/1, /*numOutputs=*/1,
{
{"(x) : ()", {{1, 0}}},
});
PWMAFunction func2 = parsePWMAF(
/*numInputs=*/1, /*numOutputs=*/1,
{
{"(x) : ()", {{-1, 0}}},
});
PWMAFunction result = parsePWMAF(
/*numInputs=*/1, /*numOutputs=*/1,
{
{"(x) : (x >= 0)", {{1, 0}}},
{"(x) : (-1 - x >= 0)", {{-1, 0}}},
});
EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
EXPECT_TRUE(func2.unionLexMax(func1).isEqual(result));
}
// func1 and func2 have disjoint domains.
{
PWMAFunction func1 = parsePWMAF(
/*numInputs=*/1, /*numOutputs=*/1,
{
{"(x) : (x >= 0, 10 - x >= 0)", {{0, 1}}},
{"(x) : (x - 71 >= 0, 80 - x >= 0)", {{0, 1}}},
});
PWMAFunction func2 = parsePWMAF(
/*numInputs=*/1, /*numOutputs=*/1,
{
{"(x) : (x - 20 >= 0, 41 - x >= 0)", {{0, 2}}},
{"(x) : (x - 101 >= 0, 120 - x >= 0)", {{0, 2}}},
});
PWMAFunction result = parsePWMAF(
/*numInputs=*/1, /*numOutputs=*/1,
{
{"(x) : (x >= 0, 10 - x >= 0)", {{0, 1}}},
{"(x) : (x - 71 >= 0, 80 - x >= 0)", {{0, 1}}},
{"(x) : (x - 20 >= 0, 41 - x >= 0)", {{0, 2}}},
{"(x) : (x - 101 >= 0, 120 - x >= 0)", {{0, 2}}},
});
EXPECT_TRUE(func1.unionLexMin(func2).isEqual(result));
EXPECT_TRUE(func2.unionLexMin(func1).isEqual(result));
}
}
TEST(PWMAFunction, unionLexMinSimple) {
// func2 is better than func1, but func2's domain is empty.
{
PWMAFunction func1 = parsePWMAF(
/*numInputs=*/1, /*numOutputs=*/1,
{
{"(x) : ()", {{0, -1}}},
});
PWMAFunction func2 = parsePWMAF(
/*numInputs=*/1, /*numOutputs=*/1,
{
{"(x) : (1 == 0)", {{0, -2}}},
});
EXPECT_TRUE(func1.unionLexMin(func2).isEqual(func1));
EXPECT_TRUE(func2.unionLexMin(func1).isEqual(func1));
}
// func2 is better than func1 on a subset of func1.
{
PWMAFunction func1 = parsePWMAF(
/*numInputs=*/1, /*numOutputs=*/1,
{
{"(x) : ()", {{0, -1}}},
});
PWMAFunction func2 = parsePWMAF(
/*numInputs=*/1, /*numOutputs=*/1,
{
{"(x) : (x >= 0, 10 - x >= 0)", {{0, -2}}},
});
PWMAFunction result = parsePWMAF(
/*numInputs=*/1, /*numOutputs=*/1,
{
{"(x) : (-1 - x >= 0)", {{0, -1}}},
{"(x) : (x >= 0, 10 - x >= 0)", {{0, -2}}},
{"(x) : (x - 11 >= 0)", {{0, -1}}},
});
EXPECT_TRUE(func1.unionLexMin(func2).isEqual(result));
EXPECT_TRUE(func2.unionLexMin(func1).isEqual(result));
}
// func1 and func2 are defined over the whole domain with different outputs.
{
PWMAFunction func1 = parsePWMAF(
/*numInputs=*/1, /*numOutputs=*/1,
{
{"(x) : ()", {{-1, 0}}},
});
PWMAFunction func2 = parsePWMAF(
/*numInputs=*/1, /*numOutputs=*/1,
{
{"(x) : ()", {{1, 0}}},
});
PWMAFunction result = parsePWMAF(
/*numInputs=*/1, /*numOutputs=*/1,
{
{"(x) : (x >= 0)", {{-1, 0}}},
{"(x) : (-1 - x >= 0)", {{1, 0}}},
});
EXPECT_TRUE(func1.unionLexMin(func2).isEqual(result));
EXPECT_TRUE(func2.unionLexMin(func1).isEqual(result));
}
}
TEST(PWMAFunction, unionLexMaxComplex) {
// Union of function containing 4 different pieces of output.
//
// x >= 21 --> func1 (func2 not defined)
// x <= 0 --> func2 (func1 not defined)
// 10 <= x <= 20, y > 0 --> func1 (x + y > x - y for y > 0)
// 10 <= x <= 20, y <= 0 --> func2 (x + y <= x - y for y <= 0)
{
PWMAFunction func1 = parsePWMAF(
/*numInputs=*/2, /*numOutputs=*/1,
{
{"(x, y) : (x >= 10)", {{1, 1, 0}}},
});
PWMAFunction func2 = parsePWMAF(
/*numInputs=*/2, /*numOutputs=*/1,
{
{"(x, y) : (x <= 20)", {{1, -1, 0}}},
});
PWMAFunction result = parsePWMAF(/*numInputs=*/2, /*numOutputs=*/1,
{{"(x, y) : (x >= 10, x <= 20, y >= 1)",
{
{1, 1, 0},
}},
{"(x, y) : (x >= 21)",
{
{1, 1, 0},
}},
{"(x, y) : (x <= 9)",
{
{1, -1, 0},
}},
{"(x, y) : (x >= 10, x <= 20, y <= 0)",
{
{1, -1, 0},
}}});
EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
}
// Functions with more than one output, with contribution from both functions.
//
// If y >= 1, func1 is better because in the first output,
// x + y (func1) > x (func2), when y >= 1
//
// If y == 0, the first output is same for both functions, so we look at the
// second output. -2x + 4 (func1) > 2x - 2 (func2) when 0 <= x <= 1, so we
// take func1 for this domain and func2 for the remaining.
{
PWMAFunction func1 = parsePWMAF(
/*numInputs=*/2, /*numOutputs=*/2,
{
{"(x, y) : (x >= 0, y >= 0)", {{1, 1, 0}, {-2, 0, 4}}},
});
PWMAFunction func2 = parsePWMAF(
/*numInputs=*/2, /*numOutputs=*/2,
{
{"(x, y) : (x >= 0, y >= 0)", {{1, 0, 0}, {2, 0, -2}}},
});
PWMAFunction result = parsePWMAF(/*numInputs=*/2, /*numOutputs=*/2,
{{"(x, y) : (x >= 0, y >= 1)",
{
{1, 1, 0},
{-2, 0, 4},
}},
{"(x, y) : (x >= 0, x <= 1, y == 0)",
{
{1, 1, 0},
{-2, 0, 4},
}},
{"(x, y) : (x >= 2, y == 0)",
{
{1, 0, 0},
{2, 0, -2},
}}});
EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
EXPECT_TRUE(func2.unionLexMax(func1).isEqual(result));
}
// Function with three boolean variables `a, b, c` used to control which
// output will be taken lexicographically.
//
// a == 1 --> Take func2
// a == 0, b == 1 --> Take func1
// a == 0, b == 0, c == 1 --> Take func2
{
PWMAFunction func1 = parsePWMAF(
/*numInputs=*/3, /*numOutputs=*/3,
{
{"(a, b, c) : (a >= 0, 1 - a >= 0, b >= 0, 1 - b >= 0, c "
">= 0, 1 - c >= 0)",
{{0, 0, 0, 0}, {0, 1, 0, 0}, {0, 0, 0, 0}}},
});
PWMAFunction func2 = parsePWMAF(
/*numInputs=*/3, /*numOutputs=*/3,
{
{"(a, b, c) : (a >= 0, 1 - a >= 0, b >= 0, 1 - b >= 0, c >= 0, 1 - "
"c >= 0)",
{{1, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 1, 0}}},
});
PWMAFunction result = parsePWMAF(
/*numInputs=*/3, /*numOutputs=*/3,
{
{"(a, b, c) : (a - 1 == 0, b >= 0, 1 - b >= 0, c >= 0, 1 - c >= 0)",
{{1, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 1, 0}}},
{"(a, b, c) : (a == 0, b - 1 == 0, c >= 0, 1 - c >= 0)",
{{0, 0, 0, 0}, {0, 1, 0, 0}, {0, 0, 0, 0}}},
{"(a, b, c) : (a == 0, b == 0, c >= 0, 1 - c >= 0)",
{{1, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 1, 0}}},
});
EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
EXPECT_TRUE(func2.unionLexMax(func1).isEqual(result));
}
}
TEST(PWMAFunction, unionLexMinComplex) {
// Regression test checking if lexicographic tiebreak produces disjoint
// domains.
//
// If x == 1, func1 is better since in the first output,
// -x (func1) is < 0 (func2) when x == 1.
//
// If x == 0, func1 and func2 both have the same first output. So we take a
// look at the second output. func2 is better since in the second output,
// y - 1 (func2) is < y (func1).
PWMAFunction func1 = parsePWMAF(
/*numInputs=*/2, /*numOutputs=*/2,
{
{"(x, y) : (x >= 0, x <= 1, y >= 0, y <= 1)",
{{-1, 0, 0}, {0, 1, 0}}},
});
PWMAFunction func2 = parsePWMAF(
/*numInputs=*/2, /*numOutputs=*/2,
{
{"(x, y) : (x >= 0, x <= 1, y >= 0, y <= 1)",
{{0, 0, 0}, {0, 1, -1}}},
});
PWMAFunction result = parsePWMAF(
/*numInputs=*/2, /*numOutputs=*/2,
{
{"(x, y) : (x == 1, y >= 0, y <= 1)", {{-1, 0, 0}, {0, 1, 0}}},
{"(x, y) : (x == 0, y >= 0, y <= 1)", {{0, 0, 0}, {0, 1, -1}}},
});
EXPECT_TRUE(func1.unionLexMin(func2).isEqual(result));
EXPECT_TRUE(func2.unionLexMin(func1).isEqual(result));
}