[mlir] Remove SDBM

This data structure and algorithm collection is no longer in use.

Reviewed By: bondhugula

Differential Revision: https://reviews.llvm.org/D105102
This commit is contained in:
Alex Zinenko 2021-06-29 14:13:01 +02:00
parent 47215e1c62
commit 355216380b
19 changed files with 0 additions and 2949 deletions

View File

@ -1,197 +0,0 @@
//===- SDBM.h - MLIR SDBM declaration ---------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A striped difference-bound matrix (SDBM) is a set in Z^N (or R^N) defined
// as {(x_1, ... x_n) | f(x_1, ... x_n) >= 0} where f is an SDBM expression.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_SDBM_SDBM_H
#define MLIR_DIALECT_SDBM_SDBM_H
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMap.h"
namespace mlir {
class MLIRContext;
class SDBMDialect;
class SDBMExpr;
class SDBMTermExpr;
/// A utility class for SDBM to represent an integer with potentially infinite
/// positive value. This uses the largest value of int64_t to represent infinity
/// and redefines the arithmetic operators so that the infinity "saturates":
/// inf + x = inf,
/// inf - x = inf.
/// If a sum of two finite values reaches the largest value of int64_t, the
/// behavior of IntInfty is undefined (in practice, it asserts), similarly to
/// regular signed integer overflow.
class IntInfty {
public:
constexpr static int64_t infty = std::numeric_limits<int64_t>::max();
/*implicit*/ IntInfty(int64_t v) : value(v) {}
IntInfty &operator=(int64_t v) {
value = v;
return *this;
}
static IntInfty infinity() { return IntInfty(infty); }
int64_t getValue() const { return value; }
explicit operator int64_t() const { return value; }
bool isFinite() { return value != infty; }
private:
int64_t value;
};
inline IntInfty operator+(IntInfty lhs, IntInfty rhs) {
if (!lhs.isFinite() || !rhs.isFinite())
return IntInfty::infty;
// Check for overflows, treating the sum of two values adding up to INT_MAX as
// overflow. Convert values to unsigned to get an extra bit and avoid the
// undefined behavior of signed integer overflows.
assert((lhs.getValue() <= 0 || rhs.getValue() <= 0 ||
static_cast<uint64_t>(lhs.getValue()) +
static_cast<uint64_t>(rhs.getValue()) <
static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) &&
"IntInfty overflow");
// Check for underflows by converting values to unsigned to avoid undefined
// behavior of signed integers perform the addition (bitwise result is same
// because numbers are required to be two's complement in C++) and check if
// the sign bit remains negative.
assert((lhs.getValue() >= 0 || rhs.getValue() >= 0 ||
((static_cast<uint64_t>(lhs.getValue()) +
static_cast<uint64_t>(rhs.getValue())) >>
63) == 1) &&
"IntInfty underflow");
return lhs.getValue() + rhs.getValue();
}
inline bool operator<(IntInfty lhs, IntInfty rhs) {
return lhs.getValue() < rhs.getValue();
}
inline bool operator<=(IntInfty lhs, IntInfty rhs) {
return lhs.getValue() <= rhs.getValue();
}
inline bool operator==(IntInfty lhs, IntInfty rhs) {
return lhs.getValue() == rhs.getValue();
}
inline bool operator!=(IntInfty lhs, IntInfty rhs) { return !(lhs == rhs); }
/// Striped difference-bound matrix is a representation of an integer set bound
/// by a system of SDBMExprs interpreted as inequalities "expr <= 0".
class SDBM {
public:
/// Obtain an SDBM from a list of SDBM expressions treated as inequalities and
/// equalities with zero.
static SDBM get(ArrayRef<SDBMExpr> inequalities,
ArrayRef<SDBMExpr> equalities);
void getSDBMExpressions(SDBMDialect *dialect,
SmallVectorImpl<SDBMExpr> &inequalities,
SmallVectorImpl<SDBMExpr> &equalities);
void print(raw_ostream &os);
void dump();
IntInfty operator()(int i, int j) { return at(i, j); }
private:
/// Get the given element of the difference bounds matrix. First index
/// corresponds to the negative term of the difference, second index
/// corresponds to the positive term of the difference.
IntInfty &at(int i, int j) { return matrix[i * getNumVariables() + j]; }
/// Populate `inequalities` and `equalities` based on the values at(row,col)
/// and at(col,row) of the DBM. Depending on the values being finite and
/// being subsumed by stripe expressions, this may or may not add elements to
/// the lists of equalities and inequalities.
void convertDBMElement(unsigned row, unsigned col, SDBMTermExpr rowExpr,
SDBMTermExpr colExpr,
SmallVectorImpl<SDBMExpr> &inequalities,
SmallVectorImpl<SDBMExpr> &equalities);
/// Populate `inequalities` based on the value at(pos,pos) of the DBM. Only
/// adds new inequalities if the inequality is not trivially true.
void convertDBMDiagonalElement(unsigned pos, SDBMTermExpr expr,
SmallVectorImpl<SDBMExpr> &inequalities);
/// Get the total number of elements in the matrix.
unsigned getNumVariables() const {
return 1 + numDims + numSymbols + numTemporaries;
}
/// Get the position in the matrix that corresponds to the given dimension.
unsigned getDimPosition(unsigned position) const { return 1 + position; }
/// Get the position in the matrix that corresponds to the given symbol.
unsigned getSymbolPosition(unsigned position) const {
return 1 + numDims + position;
}
/// Get the position in the matrix that corresponds to the given temporary.
unsigned getTemporaryPosition(unsigned position) const {
return 1 + numDims + numSymbols + position;
}
/// Number of dimensions in the system,
unsigned numDims;
/// Number of symbols in the system.
unsigned numSymbols;
/// Number of temporary variables in the system.
unsigned numTemporaries;
/// Difference bounds matrix, stored as a linearized row-major vector.
/// Each value in this matrix corresponds to an inequality
///
/// v@col - v@row <= at(row, col)
///
/// where v@col and v@row are the variables that correspond to the linearized
/// position in the matrix. The positions correspond to
///
/// - constant 0 (producing constraints v@col <= X and -v@row <= Y);
/// - SDBM expression dimensions (d0, d1, ...);
/// - SDBM expression symbols (s0, s1, ...);
/// - temporary variables (t0, t1, ...).
///
/// Temporary variables are introduced to represent expressions that are not
/// trivially a difference between two variables. For example, if one side of
/// a difference expression is itself a stripe expression, it will be replaced
/// with a temporary variable assigned equal to this expression.
///
/// Infinite entries in the matrix correspond correspond to an absence of a
/// constraint:
///
/// v@col - v@row <= infinity
///
/// is trivially true. Negated values at symmetric positions in the matrix
/// allow one to couple two inequalities into a single equality.
std::vector<IntInfty> matrix;
/// The mapping between the indices of variables in the DBM and the stripe
/// expressions they are equal to. These expressions are stored as they
/// appeared when constructing an SDBM from a SDBMExprs, in particular no
/// temporaries can appear in these expressions. This removes the need to
/// iteratively substitute definitions of the temporaries in the reverse
/// conversion.
DenseMap<unsigned, SDBMExpr> stripeToPoint;
};
} // namespace mlir
#endif // MLIR_DIALECT_SDBM_SDBM_H

View File

@ -1,37 +0,0 @@
//===- SDBMDialect.h - Dialect for striped DBMs -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_SDBM_SDBMDIALECT_H
#define MLIR_DIALECT_SDBM_SDBMDIALECT_H
#include "mlir/IR/Dialect.h"
#include "mlir/Support/StorageUniquer.h"
namespace mlir {
class MLIRContext;
class SDBMDialect : public Dialect {
public:
SDBMDialect(MLIRContext *context);
/// Since there are no other virtual methods in this derived class, override
/// the destructor so that key methods get defined in the corresponding
/// module.
~SDBMDialect() override;
static StringRef getDialectNamespace() { return "sdbm"; }
/// Get the uniquer for SDBM expressions. This should not be used directly.
StorageUniquer &getUniquer() { return uniquer; }
private:
StorageUniquer uniquer;
};
} // namespace mlir
#endif // MLIR_DIALECT_SDBM_SDBMDIALECT_H

View File

@ -1,576 +0,0 @@
//===- SDBMExpr.h - MLIR SDBM Expression ------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A striped difference-bound matrix (SDBM) expression is a constant expression,
// an identifier, a binary expression with constant RHS and +, stripe operators
// or a difference expression between two identifiers.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_SDBM_SDBMEXPR_H
#define MLIR_DIALECT_SDBM_SDBMEXPR_H
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMapInfo.h"
namespace mlir {
class AffineExpr;
class MLIRContext;
enum class SDBMExprKind { Add, Stripe, Diff, Constant, DimId, SymbolId, Neg };
namespace detail {
struct SDBMExprStorage;
struct SDBMBinaryExprStorage;
struct SDBMDiffExprStorage;
struct SDBMTermExprStorage;
struct SDBMConstantExprStorage;
struct SDBMNegExprStorage;
} // namespace detail
class SDBMConstantExpr;
class SDBMDialect;
class SDBMDimExpr;
class SDBMSymbolExpr;
class SDBMTermExpr;
/// Striped Difference-Bounded Matrix (SDBM) expression is a base left-hand side
/// expression for the SDBM framework. SDBM expressions are a subset of affine
/// expressions supporting low-complexity algorithms for the operations used in
/// loop transformations. In particular, are supported:
/// - constant expressions;
/// - single variables (dimensions and symbols) with +1 or -1 coefficient;
/// - stripe expressions: "x # C", where "x" is a single variable or another
/// stripe expression, "#" is the stripe operator, and "C" is a constant
/// expression; "#" is defined as x - x mod C.
/// - sum expressions between single variable/stripe expressions and constant
/// expressions;
/// - difference expressions between single variable/stripe expressions.
/// `SDBMExpr` class hierarchy provides a type-safe interface to constructing
/// and operating on SDBM expressions. For example, it requires the LHS of a
/// sum expression to be a single variable or a stripe expression. These
/// restrictions are intended to force the caller to perform the necessary
/// simplifications to stay within the SDBM domain, because SDBM expressions do
/// not combine in more cases than they do. This choice may be reconsidered in
/// the future.
///
/// SDBM expressions are grouped into the following structure
/// - expression
/// - varying
/// - direct
/// - sum <- (term, constant)
/// - term
/// - symbol
/// - dimension
/// - stripe <- (direct, constant)
/// - negation <- (direct)
/// - difference <- (direct, term)
/// - constant
/// The notation <- (...) denotes the types of subexpressions a compound
/// expression can combine. The tree of subexpressions essentially imposes the
/// following canonicalization rules:
/// - constants are always folded;
/// - constants can only appear on the RHS of an expression;
/// - double negation must be elided;
/// - an additive constant term is only allowed in a sum expression, and
/// should be sunk into the nearest such expression in the tree;
/// - zero constant expression can only appear at the top level.
///
/// `SDBMExpr` and derived classes are thin wrappers around a pointer owned by
/// an MLIRContext, and should be used by-value. They are uniqued in the
/// MLIRContext and immortal.
class SDBMExpr {
public:
using ImplType = detail::SDBMExprStorage;
SDBMExpr() : impl(nullptr) {}
/* implicit */ SDBMExpr(ImplType *expr) : impl(expr) {}
/// SDBM expressions are thin wrappers around a unique'ed immutable pointer,
/// which makes them trivially assignable and trivially copyable.
SDBMExpr(const SDBMExpr &) = default;
SDBMExpr &operator=(const SDBMExpr &) = default;
/// SDBM expressions can be compared straight-forwardly.
bool operator==(const SDBMExpr &other) const { return impl == other.impl; }
bool operator!=(const SDBMExpr &other) const { return !(*this == other); }
/// SDBM expressions are convertible to `bool`: null expressions are converted
/// to false, non-null expressions are converted to true.
explicit operator bool() const { return impl != nullptr; }
bool operator!() const { return !static_cast<bool>(*this); }
/// Negate the given SDBM expression.
SDBMExpr operator-();
/// Prints the SDBM expression.
void print(raw_ostream &os) const;
void dump() const;
/// LLVM-style casts.
template <typename U> bool isa() const { return U::isClassFor(*this); }
template <typename U> U dyn_cast() const {
if (!isa<U>())
return {};
return U(const_cast<SDBMExpr *>(this)->impl);
}
template <typename U> U cast() const {
assert(isa<U>() && "cast to incorrect subtype");
return U(const_cast<SDBMExpr *>(this)->impl);
}
/// Support for LLVM hashing.
::llvm::hash_code hash_value() const { return ::llvm::hash_value(impl); }
/// Returns the kind of the SDBM expression.
SDBMExprKind getKind() const;
/// Returns the MLIR context in which this expression lives.
MLIRContext *getContext() const;
/// Returns the SDBM dialect instance.
SDBMDialect *getDialect() const;
/// Convert the SDBM expression into an Affine expression. This always
/// succeeds because SDBM are a subset of affine.
AffineExpr getAsAffineExpr() const;
/// Try constructing an SDBM expression from the given affine expression.
/// This may fail if the affine expression is not representable as SDBM, in
/// which case llvm::None is returned. The conversion procedure recognizes
/// (nested) multiplicative ((x floordiv B) * B) and additive (x - x mod B)
/// patterns for the stripe expression.
static Optional<SDBMExpr> tryConvertAffineExpr(AffineExpr affine);
protected:
ImplType *impl;
};
/// SDBM constant expression, wraps a 64-bit integer.
class SDBMConstantExpr : public SDBMExpr {
public:
using ImplType = detail::SDBMConstantExprStorage;
using SDBMExpr::SDBMExpr;
/// Obtain or create a constant expression unique'ed in the given dialect
/// (which belongs to a context).
static SDBMConstantExpr get(SDBMDialect *dialect, int64_t value);
static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::Constant;
}
int64_t getValue() const;
};
/// SDBM varying expression can be one of:
/// - input variable expression;
/// - stripe expression;
/// - negation (product with -1) of either of the above.
/// - sum of a varying and a constant expression
/// - difference between varying expressions
class SDBMVaryingExpr : public SDBMExpr {
public:
using ImplType = detail::SDBMExprStorage;
using SDBMExpr::SDBMExpr;
static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::DimId ||
expr.getKind() == SDBMExprKind::SymbolId ||
expr.getKind() == SDBMExprKind::Neg ||
expr.getKind() == SDBMExprKind::Stripe ||
expr.getKind() == SDBMExprKind::Add ||
expr.getKind() == SDBMExprKind::Diff;
}
};
/// SDBM direct expression includes exactly one variable (symbol or dimension),
/// which is not negated in the expression. It can be one of:
/// - term expression;
/// - sum expression.
class SDBMDirectExpr : public SDBMVaryingExpr {
public:
using SDBMVaryingExpr::SDBMVaryingExpr;
/// If this is a sum expression, return its variable part, otherwise return
/// self.
SDBMTermExpr getTerm();
/// If this is a sum expression, return its constant part, otherwise return 0.
int64_t getConstant();
static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::DimId ||
expr.getKind() == SDBMExprKind::SymbolId ||
expr.getKind() == SDBMExprKind::Stripe ||
expr.getKind() == SDBMExprKind::Add;
}
};
/// SDBM term expression can be one of:
/// - single variable expression;
/// - stripe expression.
/// Stripe expressions are treated as terms since, in the SDBM domain, they are
/// attached to temporary variables and can appear anywhere a variable can.
class SDBMTermExpr : public SDBMDirectExpr {
public:
using SDBMDirectExpr::SDBMDirectExpr;
static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::DimId ||
expr.getKind() == SDBMExprKind::SymbolId ||
expr.getKind() == SDBMExprKind::Stripe;
}
};
/// SDBM sum expression. LHS is a term expression and RHS is a constant.
class SDBMSumExpr : public SDBMDirectExpr {
public:
using ImplType = detail::SDBMBinaryExprStorage;
using SDBMDirectExpr::SDBMDirectExpr;
/// Obtain or create a sum expression unique'ed in the given context.
static SDBMSumExpr get(SDBMTermExpr lhs, SDBMConstantExpr rhs);
static bool isClassFor(const SDBMExpr &expr) {
SDBMExprKind kind = expr.getKind();
return kind == SDBMExprKind::Add;
}
SDBMTermExpr getLHS() const;
SDBMConstantExpr getRHS() const;
};
/// SDBM difference expression. LHS is a direct expression, i.e. it may be a
/// sum of a term and a constant. RHS is a term expression. Thus the
/// expression (t1 - t2 + C) with term expressions t1,t2 is represented as
/// diff(sum(t1, C), t2)
/// and it is possible to extract the constant factor without negating it.
class SDBMDiffExpr : public SDBMVaryingExpr {
public:
using ImplType = detail::SDBMDiffExprStorage;
using SDBMVaryingExpr::SDBMVaryingExpr;
/// Obtain or create a difference expression unique'ed in the given context.
static SDBMDiffExpr get(SDBMDirectExpr lhs, SDBMTermExpr rhs);
static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::Diff;
}
SDBMDirectExpr getLHS() const;
SDBMTermExpr getRHS() const;
};
/// SDBM stripe expression "x # C" where "x" is a term expression, "C" is a
/// constant expression and "#" is the stripe operator defined as:
/// x # C = x - x mod C.
class SDBMStripeExpr : public SDBMTermExpr {
public:
using ImplType = detail::SDBMBinaryExprStorage;
using SDBMTermExpr::SDBMTermExpr;
static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::Stripe;
}
static SDBMStripeExpr get(SDBMDirectExpr var, SDBMConstantExpr stripeFactor);
SDBMDirectExpr getLHS() const;
SDBMConstantExpr getStripeFactor() const;
};
/// SDBM "input" variable expression can be either a dimension identifier or
/// a symbol identifier. When used to define SDBM functions, dimensions are
/// interpreted as function arguments while symbols are treated as unknown but
/// constant values, hence the name.
class SDBMInputExpr : public SDBMTermExpr {
public:
using ImplType = detail::SDBMTermExprStorage;
using SDBMTermExpr::SDBMTermExpr;
static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::DimId ||
expr.getKind() == SDBMExprKind::SymbolId;
}
unsigned getPosition() const;
};
/// SDBM dimension expression. Dimensions correspond to function arguments
/// when defining functions using SDBM expressions.
class SDBMDimExpr : public SDBMInputExpr {
public:
using ImplType = detail::SDBMTermExprStorage;
using SDBMInputExpr::SDBMInputExpr;
/// Obtain or create a dimension expression unique'ed in the given dialect
/// (which belongs to a context).
static SDBMDimExpr get(SDBMDialect *dialect, unsigned position);
static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::DimId;
}
};
/// SDBM symbol expression. Symbols correspond to symbolic constants when
/// defining functions using SDBM expressions.
class SDBMSymbolExpr : public SDBMInputExpr {
public:
using ImplType = detail::SDBMTermExprStorage;
using SDBMInputExpr::SDBMInputExpr;
/// Obtain or create a symbol expression unique'ed in the given dialect (which
/// belongs to a context).
static SDBMSymbolExpr get(SDBMDialect *dialect, unsigned position);
static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::SymbolId;
}
};
/// Negation of an SDBM variable expression. Equivalent to multiplying the
/// expression with -1 (SDBM does not support other coefficients that 1 and -1).
class SDBMNegExpr : public SDBMVaryingExpr {
public:
using ImplType = detail::SDBMNegExprStorage;
using SDBMVaryingExpr::SDBMVaryingExpr;
/// Obtain or create a negation expression unique'ed in the given context.
static SDBMNegExpr get(SDBMDirectExpr var);
static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::Neg;
}
SDBMDirectExpr getVar() const;
};
/// A visitor class for SDBM expressions. Calls the kind-specific function
/// depending on the kind of expression it visits.
template <typename Derived, typename Result = void> class SDBMVisitor {
public:
/// Visit the given SDBM expression, dispatching to kind-specific functions.
Result visit(SDBMExpr expr) {
auto *derived = static_cast<Derived *>(this);
switch (expr.getKind()) {
case SDBMExprKind::Add:
case SDBMExprKind::Diff:
case SDBMExprKind::DimId:
case SDBMExprKind::SymbolId:
case SDBMExprKind::Neg:
case SDBMExprKind::Stripe:
return derived->visitVarying(expr.cast<SDBMVaryingExpr>());
case SDBMExprKind::Constant:
return derived->visitConstant(expr.cast<SDBMConstantExpr>());
}
llvm_unreachable("unsupported SDBM expression kind");
}
/// Traverse the SDBM expression tree calling `visit` on each node
/// in depth-first preorder.
void walkPreorder(SDBMExpr expr) { return walk</*isPreorder=*/true>(expr); }
/// Traverse the SDBM expression tree calling `visit` on each node in
/// depth-first postorder.
void walkPostorder(SDBMExpr expr) { return walk</*isPreorder=*/false>(expr); }
protected:
/// Default visitors do nothing.
void visitSum(SDBMSumExpr) {}
void visitDiff(SDBMDiffExpr) {}
void visitStripe(SDBMStripeExpr) {}
void visitDim(SDBMDimExpr) {}
void visitSymbol(SDBMSymbolExpr) {}
void visitNeg(SDBMNegExpr) {}
void visitConstant(SDBMConstantExpr) {}
/// Default implementation of visitDirect dispatches to the dedicated for sums
/// or delegates to visitTerm for the other expression kinds. Concrete
/// visitors can overload it.
Result visitDirect(SDBMDirectExpr expr) {
auto *derived = static_cast<Derived *>(this);
if (auto sum = expr.dyn_cast<SDBMSumExpr>())
return derived->visitSum(sum);
else
return derived->visitTerm(expr.cast<SDBMTermExpr>());
}
/// Default implementation of visitTerm dispatches to the special functions
/// for stripes and other variables. Concrete visitors can override it.
Result visitTerm(SDBMTermExpr expr) {
auto *derived = static_cast<Derived *>(this);
if (expr.getKind() == SDBMExprKind::Stripe)
return derived->visitStripe(expr.cast<SDBMStripeExpr>());
else
return derived->visitInput(expr.cast<SDBMInputExpr>());
}
/// Default implementation of visitInput dispatches to the special
/// functions for dimensions or symbols. Concrete visitors can override it to
/// visit all variables instead.
Result visitInput(SDBMInputExpr expr) {
auto *derived = static_cast<Derived *>(this);
if (expr.getKind() == SDBMExprKind::DimId)
return derived->visitDim(expr.cast<SDBMDimExpr>());
else
return derived->visitSymbol(expr.cast<SDBMSymbolExpr>());
}
/// Default implementation of visitVarying dispatches to the special
/// functions for variables and negations thereof. Concrete visitors can
/// override it to visit all variables and negations instead.
Result visitVarying(SDBMVaryingExpr expr) {
auto *derived = static_cast<Derived *>(this);
if (auto var = expr.dyn_cast<SDBMDirectExpr>())
return derived->visitDirect(var);
else if (auto neg = expr.dyn_cast<SDBMNegExpr>())
return derived->visitNeg(neg);
else if (auto diff = expr.dyn_cast<SDBMDiffExpr>())
return derived->visitDiff(diff);
llvm_unreachable("unhandled subtype of varying SDBM expression");
}
template <bool isPreorder> void walk(SDBMExpr expr) {
if (isPreorder)
visit(expr);
if (auto sumExpr = expr.dyn_cast<SDBMSumExpr>()) {
walk<isPreorder>(sumExpr.getLHS());
walk<isPreorder>(sumExpr.getRHS());
} else if (auto diffExpr = expr.dyn_cast<SDBMDiffExpr>()) {
walk<isPreorder>(diffExpr.getLHS());
walk<isPreorder>(diffExpr.getRHS());
} else if (auto stripeExpr = expr.dyn_cast<SDBMStripeExpr>()) {
walk<isPreorder>(stripeExpr.getLHS());
walk<isPreorder>(stripeExpr.getStripeFactor());
} else if (auto negExpr = expr.dyn_cast<SDBMNegExpr>()) {
walk<isPreorder>(negExpr.getVar());
}
if (!isPreorder)
visit(expr);
}
};
/// Overloaded arithmetic operators for SDBM expressions asserting that their
/// arguments have the proper SDBM expression subtype. Perform canonicalization
/// and constant folding on these expressions.
namespace ops_assertions {
/// Add two SDBM expressions. At least one of the expressions must be a
/// constant or a negation, but both expressions cannot be negations
/// simultaneously.
SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs);
inline SDBMExpr operator+(SDBMExpr lhs, int64_t rhs) {
return lhs + SDBMConstantExpr::get(lhs.getDialect(), rhs);
}
inline SDBMExpr operator+(int64_t lhs, SDBMExpr rhs) {
return SDBMConstantExpr::get(rhs.getDialect(), lhs) + rhs;
}
/// Subtract an SDBM expression from another SDBM expression. Both expressions
/// must not be difference expressions.
SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs);
inline SDBMExpr operator-(SDBMExpr lhs, int64_t rhs) {
return lhs - SDBMConstantExpr::get(lhs.getDialect(), rhs);
}
inline SDBMExpr operator-(int64_t lhs, SDBMExpr rhs) {
return SDBMConstantExpr::get(rhs.getDialect(), lhs) - rhs;
}
/// Construct a stripe expression from a positive expression and a positive
/// constant stripe factor.
SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor);
inline SDBMExpr stripe(SDBMExpr expr, int64_t factor) {
return stripe(expr, SDBMConstantExpr::get(expr.getDialect(), factor));
}
} // namespace ops_assertions
} // end namespace mlir
namespace llvm {
// SDBMExpr hash just like pointers.
template <> struct DenseMapInfo<mlir::SDBMExpr> {
static mlir::SDBMExpr getEmptyKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlir::SDBMExpr(static_cast<mlir::SDBMExpr::ImplType *>(pointer));
}
static mlir::SDBMExpr getTombstoneKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
return mlir::SDBMExpr(static_cast<mlir::SDBMExpr::ImplType *>(pointer));
}
static unsigned getHashValue(mlir::SDBMExpr expr) {
return expr.hash_value();
}
static bool isEqual(mlir::SDBMExpr lhs, mlir::SDBMExpr rhs) {
return lhs == rhs;
}
};
// SDBMDirectExpr hash just like pointers.
template <> struct DenseMapInfo<mlir::SDBMDirectExpr> {
static mlir::SDBMDirectExpr getEmptyKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlir::SDBMDirectExpr(
static_cast<mlir::SDBMExpr::ImplType *>(pointer));
}
static mlir::SDBMDirectExpr getTombstoneKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
return mlir::SDBMDirectExpr(
static_cast<mlir::SDBMExpr::ImplType *>(pointer));
}
static unsigned getHashValue(mlir::SDBMDirectExpr expr) {
return expr.hash_value();
}
static bool isEqual(mlir::SDBMDirectExpr lhs, mlir::SDBMDirectExpr rhs) {
return lhs == rhs;
}
};
// SDBMTermExpr hash just like pointers.
template <> struct DenseMapInfo<mlir::SDBMTermExpr> {
static mlir::SDBMTermExpr getEmptyKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlir::SDBMTermExpr(static_cast<mlir::SDBMExpr::ImplType *>(pointer));
}
static mlir::SDBMTermExpr getTombstoneKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
return mlir::SDBMTermExpr(static_cast<mlir::SDBMExpr::ImplType *>(pointer));
}
static unsigned getHashValue(mlir::SDBMTermExpr expr) {
return expr.hash_value();
}
static bool isEqual(mlir::SDBMTermExpr lhs, mlir::SDBMTermExpr rhs) {
return lhs == rhs;
}
};
// SDBMConstantExpr hash just like pointers.
template <> struct DenseMapInfo<mlir::SDBMConstantExpr> {
static mlir::SDBMConstantExpr getEmptyKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlir::SDBMConstantExpr(
static_cast<mlir::SDBMExpr::ImplType *>(pointer));
}
static mlir::SDBMConstantExpr getTombstoneKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
return mlir::SDBMConstantExpr(
static_cast<mlir::SDBMExpr::ImplType *>(pointer));
}
static unsigned getHashValue(mlir::SDBMConstantExpr expr) {
return expr.hash_value();
}
static bool isEqual(mlir::SDBMConstantExpr lhs, mlir::SDBMConstantExpr rhs) {
return lhs == rhs;
}
};
} // namespace llvm
#endif // MLIR_DIALECT_SDBM_SDBMEXPR_H

View File

@ -35,7 +35,6 @@
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SDBM/SDBMDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
@ -75,7 +74,6 @@ inline void registerAllDialects(DialectRegistry &registry) {
vector::VectorDialect,
NVVM::NVVMDialect,
ROCDL::ROCDLDialect,
SDBMDialect,
shape::ShapeDialect,
sparse_tensor::SparseTensorDialect,
tensor::TensorDialect,

View File

@ -17,7 +17,6 @@ add_subdirectory(PDL)
add_subdirectory(PDLInterp)
add_subdirectory(Quant)
add_subdirectory(SCF)
add_subdirectory(SDBM)
add_subdirectory(Shape)
add_subdirectory(SparseTensor)
add_subdirectory(SPIRV)

View File

@ -1,11 +0,0 @@
add_mlir_dialect_library(MLIRSDBM
SDBM.cpp
SDBMDialect.cpp
SDBMExpr.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SDBM
LINK_LIBS PUBLIC
MLIRIR
)

View File

@ -1,551 +0,0 @@
//===- SDBM.cpp - MLIR SDBM implementation --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A striped difference-bound matrix (SDBM) is a set in Z^N (or R^N) defined
// as {(x_1, ... x_n) | f(x_1, ... x_n) >= 0} where f is an SDBM expression.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SDBM/SDBM.h"
#include "mlir/Dialect/SDBM/SDBMExpr.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
// Helper function for SDBM construction that collects information necessary to
// start building an SDBM in one sweep. In particular, it records the largest
// position of a dimension in `dim`, that of a symbol in `symbol` as well as
// collects all unique stripe expressions in `stripes`. Uses SetVector to
// ensure these expressions always have the same order.
static void collectSDBMBuildInfo(SDBMExpr expr, int &dim, int &symbol,
llvm::SmallSetVector<SDBMExpr, 8> &stripes) {
struct Visitor : public SDBMVisitor<Visitor> {
void visitDim(SDBMDimExpr dimExpr) {
int p = dimExpr.getPosition();
if (p > maxDimPosition)
maxDimPosition = p;
}
void visitSymbol(SDBMSymbolExpr symbExpr) {
int p = symbExpr.getPosition();
if (p > maxSymbPosition)
maxSymbPosition = p;
}
void visitStripe(SDBMStripeExpr stripeExpr) { stripes.insert(stripeExpr); }
Visitor(llvm::SmallSetVector<SDBMExpr, 8> &stripes) : stripes(stripes) {}
int maxDimPosition = -1;
int maxSymbPosition = -1;
llvm::SmallSetVector<SDBMExpr, 8> &stripes;
};
Visitor visitor(stripes);
visitor.walkPostorder(expr);
dim = std::max(dim, visitor.maxDimPosition);
symbol = std::max(symbol, visitor.maxSymbPosition);
}
namespace {
// Utility class for SDBMBuilder. Represents a value that can be inserted in
// the SDB matrix that corresponds to "v0 - v1 + C <= 0", where v0 and v1 is
// any combination of the positive and negative positions. Since multiple
// variables can be declared equal to the same stripe expression, the
// constraints on this expression must be reflected to all these variables. For
// example, if
// d0 = s0 # 42
// d1 = s0 # 42
// d2 = s1 # 2
// d3 = s1 # 2
// the constraint
// s0 # 42 - s1 # 2 <= C
// should be reflected in the DB matrix as
// d0 - d2 <= C
// d1 - d2 <= C
// d0 - d3 <= C
// d1 - d3 <= C
// since the DB matrix has no knowledge of the transitive equality between d0,
// d1 and s0 # 42 as well as between d2, d3 and s1 # 2. This knowledge can be
// obtained by computing a transitive closure, which is impossible until the
// DBM is actually built.
struct SDBMBuilderResult {
// Positions in the matrix of the variables taken with the "+" sign in the
// difference expression, 0 if it is a constant rather than a variable.
SmallVector<unsigned, 2> positivePos;
// Positions in the matrix of the variables taken with the "-" sign in the
// difference expression, 0 if it is a constant rather than a variable.
SmallVector<unsigned, 2> negativePos;
// Constant value in the difference expression.
int64_t value = 0;
};
// Visitor for building an SDBM from SDBM expressions. After traversing an SDBM
// expression, produces an update to the SDB matrix specifying the positions in
// the matrix and the negated value that should be stored. Both the positive
// and the negative positions may be lists of indices in cases where multiple
// variables are equal to the same stripe expression. In such cases, the update
// applies to the cross product of positions because elements involved in the
// update are (transitively) equal and should have the same constraints, but we
// may not have an explicit equality for them.
struct SDBMBuilder : public SDBMVisitor<SDBMBuilder, SDBMBuilderResult> {
public:
// A difference expression produces both the positive and the negative
// coordinate in the matrix, recursively traversing the LHS and the RHS. The
// value is the difference between values obtained from LHS and RHS.
SDBMBuilderResult visitDiff(SDBMDiffExpr diffExpr) {
auto lhs = visit(diffExpr.getLHS());
auto rhs = visit(diffExpr.getRHS());
assert(lhs.negativePos.size() == 1 && lhs.negativePos[0] == 0 &&
"unexpected negative expression in a difference expression");
assert(rhs.negativePos.size() == 1 && lhs.negativePos[0] == 0 &&
"unexpected negative expression in a difference expression");
SDBMBuilderResult result;
result.positivePos = lhs.positivePos;
result.negativePos = rhs.positivePos;
result.value = lhs.value - rhs.value;
return result;
}
// An input expression is always taken with the "+" sign and therefore
// produces a positive coordinate keeping the negative coordinate zero for an
// eventual constant.
SDBMBuilderResult visitInput(SDBMInputExpr expr) {
SDBMBuilderResult r;
r.positivePos.push_back(linearPosition(expr));
r.negativePos.push_back(0);
return r;
}
// A stripe expression is always equal to one or more variables, which may be
// temporaries, and appears with a "+" sign in the SDBM expression tree. Take
// the positions of the corresponding variables as positive coordinates.
SDBMBuilderResult visitStripe(SDBMStripeExpr expr) {
SDBMBuilderResult r;
assert(pointExprToStripe.count(expr));
r.positivePos = pointExprToStripe[expr];
r.negativePos.push_back(0);
return r;
}
// A constant expression has both coordinates at zero.
SDBMBuilderResult visitConstant(SDBMConstantExpr expr) {
SDBMBuilderResult r;
r.positivePos.push_back(0);
r.negativePos.push_back(0);
r.value = expr.getValue();
return r;
}
// A negation expression swaps the positive and the negative coordinates
// and also negates the constant value.
SDBMBuilderResult visitNeg(SDBMNegExpr expr) {
SDBMBuilderResult result = visit(expr.getVar());
std::swap(result.positivePos, result.negativePos);
result.value = -result.value;
return result;
}
// The RHS of a sum expression must be a constant and therefore must have both
// positive and negative coordinates at zero. Take the sum of the values
// between LHS and RHS and keep LHS coordinates.
SDBMBuilderResult visitSum(SDBMSumExpr expr) {
auto lhs = visit(expr.getLHS());
auto rhs = visit(expr.getRHS());
for (auto pos : rhs.negativePos) {
(void)pos;
assert(pos == 0 && "unexpected variable on the RHS of SDBM sum");
}
for (auto pos : rhs.positivePos) {
(void)pos;
assert(pos == 0 && "unexpected variable on the RHS of SDBM sum");
}
lhs.value += rhs.value;
return lhs;
}
SDBMBuilder(DenseMap<SDBMExpr, SmallVector<unsigned, 2>> &pointExprToStripe,
function_ref<unsigned(SDBMInputExpr)> callback)
: pointExprToStripe(pointExprToStripe), linearPosition(callback) {}
DenseMap<SDBMExpr, SmallVector<unsigned, 2>> &pointExprToStripe;
function_ref<unsigned(SDBMInputExpr)> linearPosition;
};
} // namespace
SDBM SDBM::get(ArrayRef<SDBMExpr> inequalities, ArrayRef<SDBMExpr> equalities) {
SDBM result;
// TODO: consider detecting equalities in the list of inequalities.
// This is potentially expensive and requires to
// - create a list of negated inequalities (may allocate under lock);
// - perform a pairwise comparison of direct and negated inequalities;
// - copy the lists of equalities and inequalities, and move entries between
// them;
// only for the purpose of sparing a temporary variable in cases where an
// implicit equality between a variable and a stripe expression is present in
// the input.
// Do the first sweep over (in)equalities to collect the information necessary
// to allocate the SDB matrix (number of dimensions, symbol and temporary
// variables required for stripe expressions).
llvm::SmallSetVector<SDBMExpr, 8> stripes;
int maxDim = -1;
int maxSymbol = -1;
for (auto expr : inequalities)
collectSDBMBuildInfo(expr, maxDim, maxSymbol, stripes);
for (auto expr : equalities)
collectSDBMBuildInfo(expr, maxDim, maxSymbol, stripes);
// Indexing of dimensions starts with 0, obtain the number of dimensions by
// incrementing the maximal position of the dimension seen in expressions.
result.numDims = maxDim + 1;
result.numSymbols = maxSymbol + 1;
result.numTemporaries = 0;
// Helper function that returns the position of the variable represented by
// an SDBM input expression.
auto linearPosition = [result](SDBMInputExpr expr) {
if (expr.isa<SDBMDimExpr>())
return result.getDimPosition(expr.getPosition());
return result.getSymbolPosition(expr.getPosition());
};
// Check if some stripe expressions are equal to another variable. In
// particular, look for the equalities of the form
// d0 - stripe-expression = 0, or
// stripe-expression - d0 = 0.
// There may be multiple variables that are equal to the same stripe
// expression. Keep track of those in pointExprToStripe.
// There may also be multiple stripe expressions equal to the same variable.
// Introduce a temporary variable for each of those.
DenseMap<SDBMExpr, SmallVector<unsigned, 2>> pointExprToStripe;
unsigned numTemporaries = 0;
auto updateStripePointMaps = [&numTemporaries, &result, &pointExprToStripe,
linearPosition](SDBMInputExpr input,
SDBMExpr expr) {
unsigned position = linearPosition(input);
if (result.stripeToPoint.count(position) &&
result.stripeToPoint[position] != expr) {
position = result.getNumVariables() + numTemporaries++;
}
pointExprToStripe[expr].push_back(position);
result.stripeToPoint.insert(std::make_pair(position, expr));
};
for (auto eq : equalities) {
auto diffExpr = eq.dyn_cast<SDBMDiffExpr>();
if (!diffExpr)
continue;
auto lhs = diffExpr.getLHS();
auto rhs = diffExpr.getRHS();
auto lhsInput = lhs.dyn_cast<SDBMInputExpr>();
auto rhsInput = rhs.dyn_cast<SDBMInputExpr>();
if (lhsInput && stripes.count(rhs))
updateStripePointMaps(lhsInput, rhs);
if (rhsInput && stripes.count(lhs))
updateStripePointMaps(rhsInput, lhs);
}
// Assign the remaining stripe expressions to temporary variables. These
// expressions are the ones that could not be associated with an existing
// variable in the previous step.
for (auto expr : stripes) {
if (pointExprToStripe.count(expr))
continue;
unsigned position = result.getNumVariables() + numTemporaries++;
pointExprToStripe[expr].push_back(position);
result.stripeToPoint.insert(std::make_pair(position, expr));
}
// Create the DBM matrix, initialized to infinity values for the least tight
// possible bound (x - y <= infinity is always true).
result.numTemporaries = numTemporaries;
result.matrix.resize(result.getNumVariables() * result.getNumVariables(),
IntInfty::infinity());
SDBMBuilder builder(pointExprToStripe, linearPosition);
// Only keep the tightest constraint. Since we transform everything into
// less-than-or-equals-to inequalities, keep the smallest constant. For
// example, if we have d0 - d1 <= 42 and d0 - d1 <= 2, we keep the latter.
// Note that the input expressions are in the shape of d0 - d1 + -42 <= 0
// so we negate the value before storing it.
// In case where the positive and the negative positions are equal, the
// corresponding expression has the form d0 - d0 + -42 <= 0. If the constant
// value is positive, the set defined by SDBM is trivially empty. We store
// this value anyway and continue processing to maintain the correspondence
// between the matrix form and the list-of-SDBMExpr form.
// TODO: we may want to reconsider this once we have canonicalization
// or simplification in place
auto updateMatrix = [](SDBM &sdbm, const SDBMBuilderResult &r) {
for (auto positivePos : r.positivePos) {
for (auto negativePos : r.negativePos) {
auto &m = sdbm.at(negativePos, positivePos);
m = m < -r.value ? m : -r.value;
}
}
};
// Do the second sweep on (in)equalities, updating the SDB matrix to reflect
// the constraints.
for (auto ineq : inequalities)
updateMatrix(result, builder.visit(ineq));
// An equality f(x) = 0 is represented as a pair of inequalities {f(x) >= 0;
// f(x) <= 0} or, alternatively, {-f(x) <= 0 and f(x) <= 0}.
for (auto eq : equalities) {
updateMatrix(result, builder.visit(eq));
updateMatrix(result, builder.visit(-eq));
}
// Add the inequalities induced by stripe equalities.
// t = x # C => t <= x <= t + C - 1
// which is equivalent to
// {t - x <= 0;
// x - t - (C - 1) <= 0}.
for (const auto &pair : result.stripeToPoint) {
auto stripe = pair.second.cast<SDBMStripeExpr>();
SDBMBuilderResult update = builder.visit(stripe.getLHS());
assert(update.negativePos.size() == 1 && update.negativePos[0] == 0 &&
"unexpected negated variable in stripe expression");
assert(update.value == 0 &&
"unexpected non-zero value in stripe expression");
update.negativePos.clear();
update.negativePos.push_back(pair.first);
update.value = -(stripe.getStripeFactor().getValue() - 1);
updateMatrix(result, update);
std::swap(update.negativePos, update.positivePos);
update.value = 0;
updateMatrix(result, update);
}
return result;
}
// Given a row and a column position in the square DBM, insert one equality
// or up to two inequalities that correspond the entries (col, row) and (row,
// col) in the DBM. `rowExpr` and `colExpr` contain the expressions such that
// colExpr - rowExpr <= V where V is the value at (row, col) in the DBM.
// If one of the expressions is derived from another using a stripe operation,
// check if the inequalities induced by the stripe operation subsume the
// inequalities defined in the DBM and if so, elide these inequalities.
void SDBM::convertDBMElement(unsigned row, unsigned col, SDBMTermExpr rowExpr,
SDBMTermExpr colExpr,
SmallVectorImpl<SDBMExpr> &inequalities,
SmallVectorImpl<SDBMExpr> &equalities) {
using ops_assertions::operator+;
using ops_assertions::operator-;
auto diffIJValue = at(col, row);
auto diffJIValue = at(row, col);
// If symmetric entries are opposite, the corresponding expressions are equal.
if (diffIJValue.isFinite() &&
diffIJValue.getValue() == -diffJIValue.getValue()) {
equalities.push_back(rowExpr - colExpr - diffIJValue.getValue());
return;
}
// Given an inequality x0 - x1 <= A, check if x0 is a stripe variable derived
// from x1: x0 = x1 # B. If so, it would imply the constraints
// x0 <= x1 <= x0 + (B - 1) <=> x0 - x1 <= 0 and x1 - x0 <= (B - 1).
// Therefore, if A >= 0, this inequality is subsumed by that implied
// by the stripe equality and thus can be elided.
// Similarly, check if x1 is a stripe variable derived from x0: x1 = x0 # C.
// If so, it would imply the constraints x1 <= x0 <= x1 + (C - 1) <=>
// <=> x1 - x0 <= 0 and x0 - x1 <= (C - 1). Therefore, if A >= (C - 1), this
// inequality can be elided.
//
// Note: x0 and x1 may be a stripe expressions themselves, we rely on stripe
// expressions being stored without temporaries on the RHS and being passed
// into this function as is.
auto canElide = [this](unsigned x0, unsigned x1, SDBMExpr x0Expr,
SDBMExpr x1Expr, int64_t value) {
if (stripeToPoint.count(x0)) {
auto stripe = stripeToPoint[x0].cast<SDBMStripeExpr>();
SDBMDirectExpr var = stripe.getLHS();
if (x1Expr == var && value >= 0)
return true;
}
if (stripeToPoint.count(x1)) {
auto stripe = stripeToPoint[x1].cast<SDBMStripeExpr>();
SDBMDirectExpr var = stripe.getLHS();
if (x0Expr == var && value >= stripe.getStripeFactor().getValue() - 1)
return true;
}
return false;
};
// Check row - col.
if (diffIJValue.isFinite() &&
!canElide(row, col, rowExpr, colExpr, diffIJValue.getValue())) {
inequalities.push_back(rowExpr - colExpr - diffIJValue.getValue());
}
// Check col - row.
if (diffJIValue.isFinite() &&
!canElide(col, row, colExpr, rowExpr, diffJIValue.getValue())) {
inequalities.push_back(colExpr - rowExpr - diffJIValue.getValue());
}
}
// The values on the main diagonal correspond to the upper bound on the
// difference between a variable and itself: d0 - d0 <= C, or alternatively
// to -C <= 0. Only construct the inequalities when C is negative, which
// are trivially false but necessary for the returned system of inequalities
// to indicate that the set it defines is empty.
void SDBM::convertDBMDiagonalElement(unsigned pos, SDBMTermExpr expr,
SmallVectorImpl<SDBMExpr> &inequalities) {
auto selfDifference = at(pos, pos);
if (selfDifference.isFinite() && selfDifference < 0) {
auto selfDifferenceValueExpr =
SDBMConstantExpr::get(expr.getDialect(), -selfDifference.getValue());
inequalities.push_back(selfDifferenceValueExpr);
}
}
void SDBM::getSDBMExpressions(SDBMDialect *dialect,
SmallVectorImpl<SDBMExpr> &inequalities,
SmallVectorImpl<SDBMExpr> &equalities) {
using ops_assertions::operator-;
using ops_assertions::operator+;
// Helper function that creates an SDBMInputExpr given the linearized position
// of variable in the DBM.
auto getInput = [dialect, this](unsigned matrixPos) -> SDBMInputExpr {
if (matrixPos < numDims)
return SDBMDimExpr::get(dialect, matrixPos);
return SDBMSymbolExpr::get(dialect, matrixPos - numDims);
};
// The top-left value corresponds to inequality 0 <= C. If C is negative, the
// set defined by SDBM is trivially empty and we add the constraint -C <= 0 to
// the list of inequalities. Otherwise, the constraint is trivially true and
// we ignore it.
auto difference = at(0, 0);
if (difference.isFinite() && difference < 0) {
inequalities.push_back(
SDBMConstantExpr::get(dialect, -difference.getValue()));
}
// Traverse the segment of the matrix that involves non-temporary variables.
unsigned numTrueVariables = numDims + numSymbols;
for (unsigned i = 0; i < numTrueVariables; ++i) {
// The first row and column represent numerical upper and lower bound on
// each variable. Transform them into inequalities if they are finite.
auto upperBound = at(0, 1 + i);
auto lowerBound = at(1 + i, 0);
auto inputExpr = getInput(i);
if (upperBound.isFinite() &&
upperBound.getValue() == -lowerBound.getValue()) {
equalities.push_back(inputExpr - upperBound.getValue());
} else if (upperBound.isFinite()) {
inequalities.push_back(inputExpr - upperBound.getValue());
} else if (lowerBound.isFinite()) {
inequalities.push_back(-inputExpr - lowerBound.getValue());
}
// Introduce trivially false inequalities if required by diagonal elements.
convertDBMDiagonalElement(1 + i, inputExpr, inequalities);
// Introduce equalities or inequalities between non-temporary variables.
for (unsigned j = 0; j < i; ++j) {
convertDBMElement(1 + i, 1 + j, getInput(i), getInput(j), inequalities,
equalities);
}
}
// Add equalities for stripe expressions that define non-temporary
// variables. Temporary variables will be substituted into their uses and
// should not appear in the resulting equalities.
for (const auto &stripePair : stripeToPoint) {
unsigned position = stripePair.first;
if (position < 1 + numTrueVariables) {
equalities.push_back(getInput(position - 1) - stripePair.second);
}
}
// Add equalities / inequalities involving temporaries by replacing the
// temporaries with stripe expressions that define them.
for (unsigned i = 1 + numTrueVariables, e = getNumVariables(); i < e; ++i) {
// Mixed constraints involving one temporary (j) and one non-temporary (i)
// variable.
for (unsigned j = 0; j < numTrueVariables; ++j) {
convertDBMElement(i, 1 + j, stripeToPoint[i].cast<SDBMStripeExpr>(),
getInput(j), inequalities, equalities);
}
// Constraints involving only temporary variables.
for (unsigned j = 1 + numTrueVariables; j < i; ++j) {
convertDBMElement(i, j, stripeToPoint[i].cast<SDBMStripeExpr>(),
stripeToPoint[j].cast<SDBMStripeExpr>(), inequalities,
equalities);
}
// Introduce trivially false inequalities if required by diagonal elements.
convertDBMDiagonalElement(i, stripeToPoint[i].cast<SDBMStripeExpr>(),
inequalities);
}
}
void SDBM::print(raw_ostream &os) {
unsigned numVariables = getNumVariables();
// Helper function that prints the name of the variable given its linearized
// position in the DBM.
auto getVarName = [this](unsigned matrixPos) -> std::string {
if (matrixPos == 0)
return "cst";
matrixPos -= 1;
if (matrixPos < numDims)
return std::string(llvm::formatv("d{0}", matrixPos));
matrixPos -= numDims;
if (matrixPos < numSymbols)
return std::string(llvm::formatv("s{0}", matrixPos));
matrixPos -= numSymbols;
return std::string(llvm::formatv("t{0}", matrixPos));
};
// Header row.
os << " cst";
for (unsigned i = 1; i < numVariables; ++i) {
os << llvm::formatv(" {0,4}", getVarName(i));
}
os << '\n';
// Data rows.
for (unsigned i = 0; i < numVariables; ++i) {
os << llvm::formatv("{0,-4}", getVarName(i));
for (unsigned j = 0; j < numVariables; ++j) {
IntInfty value = operator()(i, j);
if (!value.isFinite())
os << " inf";
else
os << llvm::formatv(" {0,4}", value.getValue());
}
os << '\n';
}
// Explanation of temporaries.
for (const auto &pair : stripeToPoint) {
os << getVarName(pair.first) << " = ";
pair.second.print(os);
os << '\n';
}
}
void SDBM::dump() { print(llvm::errs()); }

View File

@ -1,23 +0,0 @@
//===- SDBMDialect.cpp - MLIR SDBM Dialect --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SDBM/SDBMDialect.h"
#include "SDBMExprDetail.h"
using namespace mlir;
SDBMDialect::SDBMDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context, TypeID::get<SDBMDialect>()) {
uniquer.registerParametricStorageType<detail::SDBMBinaryExprStorage>();
uniquer.registerParametricStorageType<detail::SDBMConstantExprStorage>();
uniquer.registerParametricStorageType<detail::SDBMDiffExprStorage>();
uniquer.registerParametricStorageType<detail::SDBMNegExprStorage>();
uniquer.registerParametricStorageType<detail::SDBMTermExprStorage>();
}
SDBMDialect::~SDBMDialect() = default;

View File

@ -1,732 +0,0 @@
//===- SDBMExpr.cpp - MLIR SDBM Expression implementation -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A striped difference-bound matrix (SDBM) expression is a constant expression,
// an identifier, a binary expression with constant RHS and +, stripe operators
// or a difference expression between two identifiers.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SDBM/SDBMExpr.h"
#include "SDBMExprDetail.h"
#include "mlir/Dialect/SDBM/SDBMDialect.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
namespace {
/// A simple compositional matcher for AffineExpr
///
/// Example usage:
///
/// ```c++
/// AffineExprMatcher x, C, m;
/// AffineExprMatcher pattern1 = ((x % C) * m) + x;
/// AffineExprMatcher pattern2 = x + ((x % C) * m);
/// if (pattern1.match(expr) || pattern2.match(expr)) {
/// ...
/// }
/// ```
class AffineExprMatcherStorage;
class AffineExprMatcher {
public:
AffineExprMatcher();
AffineExprMatcher(const AffineExprMatcher &other);
AffineExprMatcher operator+(AffineExprMatcher other) {
return AffineExprMatcher(AffineExprKind::Add, *this, other);
}
AffineExprMatcher operator*(AffineExprMatcher other) {
return AffineExprMatcher(AffineExprKind::Mul, *this, other);
}
AffineExprMatcher floorDiv(AffineExprMatcher other) {
return AffineExprMatcher(AffineExprKind::FloorDiv, *this, other);
}
AffineExprMatcher ceilDiv(AffineExprMatcher other) {
return AffineExprMatcher(AffineExprKind::CeilDiv, *this, other);
}
AffineExprMatcher operator%(AffineExprMatcher other) {
return AffineExprMatcher(AffineExprKind::Mod, *this, other);
}
AffineExpr match(AffineExpr expr);
AffineExpr matched();
Optional<int> getMatchedConstantValue();
private:
AffineExprMatcher(AffineExprKind k, AffineExprMatcher a, AffineExprMatcher b);
AffineExprKind kind; // only used to match in binary op cases.
// A shared_ptr allows multiple references to same matcher storage without
// worrying about ownership or dealing with an arena. To be cleaned up if we
// go with this.
std::shared_ptr<AffineExprMatcherStorage> storage;
};
class AffineExprMatcherStorage {
public:
AffineExprMatcherStorage() {}
AffineExprMatcherStorage(const AffineExprMatcherStorage &other)
: subExprs(other.subExprs.begin(), other.subExprs.end()),
matched(other.matched) {}
AffineExprMatcherStorage(ArrayRef<AffineExprMatcher> exprs)
: subExprs(exprs.begin(), exprs.end()) {}
AffineExprMatcherStorage(AffineExprMatcher &a, AffineExprMatcher &b)
: subExprs({a, b}) {}
SmallVector<AffineExprMatcher, 0> subExprs;
AffineExpr matched;
};
} // namespace
AffineExprMatcher::AffineExprMatcher()
: kind(AffineExprKind::Constant), storage(new AffineExprMatcherStorage()) {}
AffineExprMatcher::AffineExprMatcher(const AffineExprMatcher &other)
: kind(other.kind), storage(other.storage) {}
Optional<int> AffineExprMatcher::getMatchedConstantValue() {
if (auto cst = storage->matched.dyn_cast<AffineConstantExpr>())
return cst.getValue();
return None;
}
AffineExpr AffineExprMatcher::match(AffineExpr expr) {
if (kind > AffineExprKind::LAST_AFFINE_BINARY_OP) {
if (storage->matched)
if (storage->matched != expr)
return AffineExpr();
storage->matched = expr;
return storage->matched;
}
if (kind != expr.getKind()) {
return AffineExpr();
}
if (auto bin = expr.dyn_cast<AffineBinaryOpExpr>()) {
if (!storage->subExprs.empty() &&
!storage->subExprs[0].match(bin.getLHS())) {
return AffineExpr();
}
if (!storage->subExprs.empty() &&
!storage->subExprs[1].match(bin.getRHS())) {
return AffineExpr();
}
if (storage->matched)
if (storage->matched != expr)
return AffineExpr();
storage->matched = expr;
return storage->matched;
}
llvm_unreachable("binary expected");
}
AffineExpr AffineExprMatcher::matched() { return storage->matched; }
AffineExprMatcher::AffineExprMatcher(AffineExprKind k, AffineExprMatcher a,
AffineExprMatcher b)
: kind(k), storage(new AffineExprMatcherStorage(a, b)) {
storage->subExprs.push_back(a);
storage->subExprs.push_back(b);
}
//===----------------------------------------------------------------------===//
// SDBMExpr
//===----------------------------------------------------------------------===//
SDBMExprKind SDBMExpr::getKind() const { return impl->getKind(); }
MLIRContext *SDBMExpr::getContext() const {
return impl->dialect->getContext();
}
SDBMDialect *SDBMExpr::getDialect() const { return impl->dialect; }
void SDBMExpr::print(raw_ostream &os) const {
struct Printer : public SDBMVisitor<Printer> {
Printer(raw_ostream &ostream) : prn(ostream) {}
void visitSum(SDBMSumExpr expr) {
visit(expr.getLHS());
prn << " + ";
visit(expr.getRHS());
}
void visitDiff(SDBMDiffExpr expr) {
visit(expr.getLHS());
prn << " - ";
visit(expr.getRHS());
}
void visitDim(SDBMDimExpr expr) { prn << 'd' << expr.getPosition(); }
void visitSymbol(SDBMSymbolExpr expr) { prn << 's' << expr.getPosition(); }
void visitStripe(SDBMStripeExpr expr) {
SDBMDirectExpr lhs = expr.getLHS();
bool isTerm = lhs.isa<SDBMTermExpr>();
if (!isTerm)
prn << '(';
visit(lhs);
if (!isTerm)
prn << ')';
prn << " # ";
visitConstant(expr.getStripeFactor());
}
void visitNeg(SDBMNegExpr expr) {
bool isSum = expr.getVar().isa<SDBMSumExpr>();
prn << '-';
if (isSum)
prn << '(';
visit(expr.getVar());
if (isSum)
prn << ')';
}
void visitConstant(SDBMConstantExpr expr) { prn << expr.getValue(); }
raw_ostream &prn;
};
Printer printer(os);
printer.visit(*this);
}
void SDBMExpr::dump() const {
print(llvm::errs());
llvm::errs() << '\n';
}
namespace {
// Helper class to perform negation of an SDBM expression.
struct SDBMNegator : public SDBMVisitor<SDBMNegator, SDBMExpr> {
// Any term expression is wrapped into a negation expression.
// -(x) = -x
SDBMExpr visitDirect(SDBMDirectExpr expr) { return SDBMNegExpr::get(expr); }
// A negation expression is unwrapped.
// -(-x) = x
SDBMExpr visitNeg(SDBMNegExpr expr) { return expr.getVar(); }
// The value of the constant is negated.
SDBMExpr visitConstant(SDBMConstantExpr expr) {
return SDBMConstantExpr::get(expr.getDialect(), -expr.getValue());
}
// Terms of a difference are interchanged. Since only the LHS of a diff
// expression is allowed to be a sum with a constant, we need to recreate the
// sum with the negated value:
// -((x + C) - y) = (y - C) - x.
SDBMExpr visitDiff(SDBMDiffExpr expr) {
// If the LHS is just a term, we can do straightforward interchange.
if (auto term = expr.getLHS().dyn_cast<SDBMTermExpr>())
return SDBMDiffExpr::get(expr.getRHS(), term);
auto sum = expr.getLHS().cast<SDBMSumExpr>();
auto cst = visitConstant(sum.getRHS()).cast<SDBMConstantExpr>();
return SDBMDiffExpr::get(SDBMSumExpr::get(expr.getRHS(), cst),
sum.getLHS());
}
};
} // namespace
SDBMExpr SDBMExpr::operator-() { return SDBMNegator().visit(*this); }
//===----------------------------------------------------------------------===//
// SDBMSumExpr
//===----------------------------------------------------------------------===//
SDBMSumExpr SDBMSumExpr::get(SDBMTermExpr lhs, SDBMConstantExpr rhs) {
assert(lhs && "expected SDBM variable expression");
assert(rhs && "expected SDBM constant");
// If LHS of a sum is another sum, fold the constant RHS parts.
if (auto lhsSum = lhs.dyn_cast<SDBMSumExpr>()) {
lhs = lhsSum.getLHS();
rhs = SDBMConstantExpr::get(rhs.getDialect(),
rhs.getValue() + lhsSum.getRHS().getValue());
}
StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
return uniquer.get<detail::SDBMBinaryExprStorage>(
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Add), lhs, rhs);
}
SDBMTermExpr SDBMSumExpr::getLHS() const {
return static_cast<ImplType *>(impl)->lhs.cast<SDBMTermExpr>();
}
SDBMConstantExpr SDBMSumExpr::getRHS() const {
return static_cast<ImplType *>(impl)->rhs;
}
AffineExpr SDBMExpr::getAsAffineExpr() const {
struct Converter : public SDBMVisitor<Converter, AffineExpr> {
AffineExpr visitSum(SDBMSumExpr expr) {
AffineExpr lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
return lhs + rhs;
}
AffineExpr visitStripe(SDBMStripeExpr expr) {
AffineExpr lhs = visit(expr.getLHS()),
rhs = visit(expr.getStripeFactor());
return lhs - (lhs % rhs);
}
AffineExpr visitDiff(SDBMDiffExpr expr) {
AffineExpr lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
return lhs - rhs;
}
AffineExpr visitDim(SDBMDimExpr expr) {
return getAffineDimExpr(expr.getPosition(), expr.getContext());
}
AffineExpr visitSymbol(SDBMSymbolExpr expr) {
return getAffineSymbolExpr(expr.getPosition(), expr.getContext());
}
AffineExpr visitNeg(SDBMNegExpr expr) {
return getAffineBinaryOpExpr(AffineExprKind::Mul,
getAffineConstantExpr(-1, expr.getContext()),
visit(expr.getVar()));
}
AffineExpr visitConstant(SDBMConstantExpr expr) {
return getAffineConstantExpr(expr.getValue(), expr.getContext());
}
} converter;
return converter.visit(*this);
}
// Given a direct expression `expr`, add the given constant to it and pass the
// resulting expression to `builder` before returning its result. If the
// expression is already a sum expression, update its constant and extract the
// LHS if the constant becomes zero. Otherwise, construct a sum expression.
template <typename Result>
static Result addConstantAndSink(SDBMDirectExpr expr, int64_t constant,
bool negated,
function_ref<Result(SDBMDirectExpr)> builder) {
SDBMDialect *dialect = expr.getDialect();
if (auto sumExpr = expr.dyn_cast<SDBMSumExpr>()) {
if (negated)
constant = sumExpr.getRHS().getValue() - constant;
else
constant += sumExpr.getRHS().getValue();
if (constant != 0) {
auto sum = SDBMSumExpr::get(sumExpr.getLHS(),
SDBMConstantExpr::get(dialect, constant));
return builder(sum);
} else {
return builder(sumExpr.getLHS());
}
}
if (constant != 0)
return builder(SDBMSumExpr::get(
expr.cast<SDBMTermExpr>(),
SDBMConstantExpr::get(dialect, negated ? -constant : constant)));
return expr;
}
// Construct an expression lhs + constant while maintaining the canonical form
// of the SDBM expressions, in particular sink the constant expression to the
// nearest sum expression in the left subtree of the expression tree.
static SDBMExpr addConstant(SDBMVaryingExpr lhs, int64_t constant) {
if (auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>())
return addConstantAndSink<SDBMExpr>(
lhsDiff.getLHS(), constant, /*negated=*/false,
[lhsDiff](SDBMDirectExpr e) {
return SDBMDiffExpr::get(e, lhsDiff.getRHS());
});
if (auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>())
return addConstantAndSink<SDBMExpr>(
lhsNeg.getVar(), constant, /*negated=*/true,
[](SDBMDirectExpr e) { return SDBMNegExpr::get(e); });
if (auto lhsSum = lhs.dyn_cast<SDBMSumExpr>())
return addConstantAndSink<SDBMExpr>(lhsSum, constant, /*negated=*/false,
[](SDBMDirectExpr e) { return e; });
if (constant != 0)
return SDBMSumExpr::get(lhs.cast<SDBMTermExpr>(),
SDBMConstantExpr::get(lhs.getDialect(), constant));
return lhs;
}
// Build a difference expression given a direct expression and a negation
// expression.
static SDBMExpr buildDiffExpr(SDBMDirectExpr lhs, SDBMNegExpr rhs) {
// Fold (x + C) - (x + D) = C - D.
if (lhs.getTerm() == rhs.getVar().getTerm())
return SDBMConstantExpr::get(
lhs.getDialect(), lhs.getConstant() - rhs.getVar().getConstant());
return SDBMDiffExpr::get(
addConstantAndSink<SDBMDirectExpr>(lhs, -rhs.getVar().getConstant(),
/*negated=*/false,
[](SDBMDirectExpr e) { return e; }),
rhs.getVar().getTerm());
}
// Try folding an expression (lhs + rhs) where at least one of the operands
// contains a negated variable, i.e. is a negation or a difference expression.
static SDBMExpr foldSumDiff(SDBMExpr lhs, SDBMExpr rhs) {
// If exactly one of LHS, RHS is a negation expression, we can construct
// a difference expression, which is a special kind in SDBM.
auto lhsDirect = lhs.dyn_cast<SDBMDirectExpr>();
auto rhsDirect = rhs.dyn_cast<SDBMDirectExpr>();
auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>();
auto rhsNeg = rhs.dyn_cast<SDBMNegExpr>();
if (lhsDirect && rhsNeg)
return buildDiffExpr(lhsDirect, rhsNeg);
if (lhsNeg && rhsDirect)
return buildDiffExpr(rhsDirect, lhsNeg);
// If a subexpression appears in a diff expression on the LHS(RHS) of a
// sum expression where it also appears on the RHS(LHS) with the opposite
// sign, we can simplify it away and obtain the SDBM form.
auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>();
auto rhsDiff = rhs.dyn_cast<SDBMDiffExpr>();
// -(x + A) + ((x + B) - y) = -(y + (A - B))
if (lhsNeg && rhsDiff &&
lhsNeg.getVar().getTerm() == rhsDiff.getLHS().getTerm()) {
int64_t constant =
lhsNeg.getVar().getConstant() - rhsDiff.getLHS().getConstant();
// RHS of the diff is a term expression, its sum with a constant is a direct
// expression.
return SDBMNegExpr::get(
addConstant(rhsDiff.getRHS(), constant).cast<SDBMDirectExpr>());
}
// (x + A) + ((y + B) - x) = (y + B) + A.
if (lhsDirect && rhsDiff && lhsDirect.getTerm() == rhsDiff.getRHS())
return addConstant(rhsDiff.getLHS(), lhsDirect.getConstant());
// ((x + A) - y) + (-(x + B)) = -(y + (B - A)).
if (lhsDiff && rhsNeg &&
lhsDiff.getLHS().getTerm() == rhsNeg.getVar().getTerm()) {
int64_t constant =
rhsNeg.getVar().getConstant() - lhsDiff.getLHS().getConstant();
// RHS of the diff is a term expression, its sum with a constant is a direct
// expression.
return SDBMNegExpr::get(
addConstant(lhsDiff.getRHS(), constant).cast<SDBMDirectExpr>());
}
// ((x + A) - y) + (y + B) = (x + A) + B.
if (rhsDirect && lhsDiff && rhsDirect.getTerm() == lhsDiff.getRHS())
return addConstant(lhsDiff.getLHS(), rhsDirect.getConstant());
return {};
}
Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
struct Converter : public AffineExprVisitor<Converter, SDBMExpr> {
SDBMExpr visitAddExpr(AffineBinaryOpExpr expr) {
auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
if (!lhs || !rhs)
return {};
// In a "add" AffineExpr, the constant always appears on the right. If
// there were two constants, they would have been folded away.
assert(!lhs.isa<SDBMConstantExpr>() && "non-canonical affine expression");
// If RHS is a constant, we can always extend the SDBM expression to
// include it by sinking the constant into the nearest sum expression.
if (auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>()) {
int64_t constant = rhsConstant.getValue();
auto varying = lhs.dyn_cast<SDBMVaryingExpr>();
assert(varying && "unexpected uncanonicalized sum of constants");
return addConstant(varying, constant);
}
// Try building a difference expression if one of the values is negated,
// or check if a difference on either hand side cancels out the outer term
// so as to remain correct within SDBM. Return null otherwise.
return foldSumDiff(lhs, rhs);
}
SDBMExpr visitMulExpr(AffineBinaryOpExpr expr) {
// Attempt to recover a stripe expression "x # C = (x floordiv C) * C".
AffineExprMatcher x, C;
AffineExprMatcher pattern = (x.floorDiv(C)) * C;
if (pattern.match(expr)) {
if (SDBMExpr converted = visit(x.matched())) {
if (auto varConverted = converted.dyn_cast<SDBMTermExpr>())
// TODO: return varConverted.stripe(C.getConstantValue());
return SDBMStripeExpr::get(
varConverted,
SDBMConstantExpr::get(dialect,
C.getMatchedConstantValue().getValue()));
}
}
auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
if (!lhs || !rhs)
return {};
// In a "mul" AffineExpr, the constant always appears on the right. If
// there were two constants, they would have been folded away.
assert(!lhs.isa<SDBMConstantExpr>() && "non-canonical affine expression");
auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
if (!rhsConstant)
return {};
// The only supported "multiplication" expression is an SDBM is dimension
// negation, that is a product of dimension and constant -1.
if (rhsConstant.getValue() != -1)
return {};
if (auto lhsVar = lhs.dyn_cast<SDBMTermExpr>())
return SDBMNegExpr::get(lhsVar);
if (auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>())
return SDBMNegator().visitDiff(lhsDiff);
// Other multiplications are not allowed in SDBM.
return {};
}
SDBMExpr visitModExpr(AffineBinaryOpExpr expr) {
auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
if (!lhs || !rhs)
return {};
// 'mod' can only be converted to SDBM if its LHS is a direct expression
// and its RHS is a constant. Then it `x mod c = x - x stripe c`.
auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
auto lhsVar = lhs.dyn_cast<SDBMDirectExpr>();
if (!lhsVar || !rhsConstant)
return {};
return SDBMDiffExpr::get(lhsVar,
SDBMStripeExpr::get(lhsVar, rhsConstant));
}
// `a floordiv b = (a stripe b) / b`, but we have no division in SDBM
SDBMExpr visitFloorDivExpr(AffineBinaryOpExpr expr) { return {}; }
SDBMExpr visitCeilDivExpr(AffineBinaryOpExpr expr) { return {}; }
// Dimensions, symbols and constants are converted trivially.
SDBMExpr visitConstantExpr(AffineConstantExpr expr) {
return SDBMConstantExpr::get(dialect, expr.getValue());
}
SDBMExpr visitDimExpr(AffineDimExpr expr) {
return SDBMDimExpr::get(dialect, expr.getPosition());
}
SDBMExpr visitSymbolExpr(AffineSymbolExpr expr) {
return SDBMSymbolExpr::get(dialect, expr.getPosition());
}
SDBMDialect *dialect;
} converter;
converter.dialect = affine.getContext()->getOrLoadDialect<SDBMDialect>();
if (auto result = converter.visit(affine))
return result;
return None;
}
//===----------------------------------------------------------------------===//
// SDBMDiffExpr
//===----------------------------------------------------------------------===//
SDBMDiffExpr SDBMDiffExpr::get(SDBMDirectExpr lhs, SDBMTermExpr rhs) {
assert(lhs && "expected SDBM dimension");
assert(rhs && "expected SDBM dimension");
StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
return uniquer.get<detail::SDBMDiffExprStorage>(/*initFn=*/{}, lhs, rhs);
}
SDBMDirectExpr SDBMDiffExpr::getLHS() const {
return static_cast<ImplType *>(impl)->lhs;
}
SDBMTermExpr SDBMDiffExpr::getRHS() const {
return static_cast<ImplType *>(impl)->rhs;
}
//===----------------------------------------------------------------------===//
// SDBMDirectExpr
//===----------------------------------------------------------------------===//
SDBMTermExpr SDBMDirectExpr::getTerm() {
if (auto sum = dyn_cast<SDBMSumExpr>())
return sum.getLHS();
return cast<SDBMTermExpr>();
}
int64_t SDBMDirectExpr::getConstant() {
if (auto sum = dyn_cast<SDBMSumExpr>())
return sum.getRHS().getValue();
return 0;
}
//===----------------------------------------------------------------------===//
// SDBMStripeExpr
//===----------------------------------------------------------------------===//
SDBMStripeExpr SDBMStripeExpr::get(SDBMDirectExpr var,
SDBMConstantExpr stripeFactor) {
assert(var && "expected SDBM variable expression");
assert(stripeFactor && "expected non-null stripe factor");
if (stripeFactor.getValue() <= 0)
llvm::report_fatal_error("non-positive stripe factor");
StorageUniquer &uniquer = var.getDialect()->getUniquer();
return uniquer.get<detail::SDBMBinaryExprStorage>(
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Stripe), var,
stripeFactor);
}
SDBMDirectExpr SDBMStripeExpr::getLHS() const {
if (SDBMVaryingExpr lhs = static_cast<ImplType *>(impl)->lhs)
return lhs.cast<SDBMDirectExpr>();
return {};
}
SDBMConstantExpr SDBMStripeExpr::getStripeFactor() const {
return static_cast<ImplType *>(impl)->rhs;
}
//===----------------------------------------------------------------------===//
// SDBMInputExpr
//===----------------------------------------------------------------------===//
unsigned SDBMInputExpr::getPosition() const {
return static_cast<ImplType *>(impl)->position;
}
//===----------------------------------------------------------------------===//
// SDBMDimExpr
//===----------------------------------------------------------------------===//
SDBMDimExpr SDBMDimExpr::get(SDBMDialect *dialect, unsigned position) {
assert(dialect && "expected non-null dialect");
auto assignDialect = [dialect](detail::SDBMTermExprStorage *storage) {
storage->dialect = dialect;
};
StorageUniquer &uniquer = dialect->getUniquer();
return uniquer.get<detail::SDBMTermExprStorage>(
assignDialect, static_cast<unsigned>(SDBMExprKind::DimId), position);
}
//===----------------------------------------------------------------------===//
// SDBMSymbolExpr
//===----------------------------------------------------------------------===//
SDBMSymbolExpr SDBMSymbolExpr::get(SDBMDialect *dialect, unsigned position) {
assert(dialect && "expected non-null dialect");
auto assignDialect = [dialect](detail::SDBMTermExprStorage *storage) {
storage->dialect = dialect;
};
StorageUniquer &uniquer = dialect->getUniquer();
return uniquer.get<detail::SDBMTermExprStorage>(
assignDialect, static_cast<unsigned>(SDBMExprKind::SymbolId), position);
}
//===----------------------------------------------------------------------===//
// SDBMConstantExpr
//===----------------------------------------------------------------------===//
SDBMConstantExpr SDBMConstantExpr::get(SDBMDialect *dialect, int64_t value) {
assert(dialect && "expected non-null dialect");
auto assignCtx = [dialect](detail::SDBMConstantExprStorage *storage) {
storage->dialect = dialect;
};
StorageUniquer &uniquer = dialect->getUniquer();
return uniquer.get<detail::SDBMConstantExprStorage>(assignCtx, value);
}
int64_t SDBMConstantExpr::getValue() const {
return static_cast<ImplType *>(impl)->constant;
}
//===----------------------------------------------------------------------===//
// SDBMNegExpr
//===----------------------------------------------------------------------===//
SDBMNegExpr SDBMNegExpr::get(SDBMDirectExpr var) {
assert(var && "expected non-null SDBM direct expression");
StorageUniquer &uniquer = var.getDialect()->getUniquer();
return uniquer.get<detail::SDBMNegExprStorage>(/*initFn=*/{}, var);
}
SDBMDirectExpr SDBMNegExpr::getVar() const {
return static_cast<ImplType *>(impl)->expr;
}
SDBMExpr mlir::ops_assertions::operator+(SDBMExpr lhs, SDBMExpr rhs) {
if (auto folded = foldSumDiff(lhs, rhs))
return folded;
assert(!(lhs.isa<SDBMNegExpr>() && rhs.isa<SDBMNegExpr>()) &&
"a sum of negated expressions is a negation of a sum of variables and "
"not a correct SDBM");
// Fold (x - y) + (y - x) = 0.
auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>();
auto rhsDiff = rhs.dyn_cast<SDBMDiffExpr>();
if (lhsDiff && rhsDiff) {
if (lhsDiff.getLHS() == rhsDiff.getRHS() &&
lhsDiff.getRHS() == rhsDiff.getLHS())
return SDBMConstantExpr::get(lhs.getDialect(), 0);
}
// If LHS is a constant and RHS is not, swap the order to get into a supported
// sum case. From now on, RHS must be a constant.
auto lhsConstant = lhs.dyn_cast<SDBMConstantExpr>();
auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
if (!rhsConstant && lhsConstant) {
std::swap(lhs, rhs);
std::swap(lhsConstant, rhsConstant);
}
assert(rhsConstant && "at least one operand must be a constant");
// Constant-fold if LHS is also a constant.
if (lhsConstant)
return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() +
rhsConstant.getValue());
return addConstant(lhs.cast<SDBMVaryingExpr>(), rhsConstant.getValue());
}
SDBMExpr mlir::ops_assertions::operator-(SDBMExpr lhs, SDBMExpr rhs) {
// Fold x - x == 0.
if (lhs == rhs)
return SDBMConstantExpr::get(lhs.getDialect(), 0);
// LHS and RHS may be constants.
auto lhsConstant = lhs.dyn_cast<SDBMConstantExpr>();
auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
// Constant fold if both LHS and RHS are constants.
if (lhsConstant && rhsConstant)
return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() -
rhsConstant.getValue());
// Replace a difference with a sum with a negated value if one of LHS and RHS
// is a constant:
// x - C == x + (-C);
// C - x == -x + C.
// This calls into operator+ for further simplification.
if (rhsConstant)
return lhs + (-rhsConstant);
if (lhsConstant)
return -rhs + lhsConstant;
return buildDiffExpr(lhs.cast<SDBMDirectExpr>(), (-rhs).cast<SDBMNegExpr>());
}
SDBMExpr mlir::ops_assertions::stripe(SDBMExpr expr, SDBMExpr factor) {
auto constantFactor = factor.cast<SDBMConstantExpr>();
assert(constantFactor.getValue() > 0 && "non-positive stripe");
// Fold x # 1 = x.
if (constantFactor.getValue() == 1)
return expr;
return SDBMStripeExpr::get(expr.cast<SDBMDirectExpr>(), constantFactor);
}

View File

@ -1,137 +0,0 @@
//===- SDBMExprDetail.h - MLIR SDBM Expression storage details --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This holds implementation details of SDBMExpr, in particular underlying
// storage types.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_SDBMEXPRDETAIL_H
#define MLIR_IR_SDBMEXPRDETAIL_H
#include "mlir/Dialect/SDBM/SDBMExpr.h"
#include "mlir/Support/StorageUniquer.h"
namespace mlir {
class SDBMDialect;
namespace detail {
// Base storage class for SDBMExpr.
struct SDBMExprStorage : public StorageUniquer::BaseStorage {
SDBMExprKind getKind() { return kind; }
SDBMDialect *dialect;
SDBMExprKind kind;
};
// Storage class for SDBM sum and stripe expressions.
struct SDBMBinaryExprStorage : public SDBMExprStorage {
using KeyTy = std::tuple<unsigned, SDBMDirectExpr, SDBMConstantExpr>;
bool operator==(const KeyTy &key) const {
return static_cast<SDBMExprKind>(std::get<0>(key)) == kind &&
std::get<1>(key) == lhs && std::get<2>(key) == rhs;
}
static SDBMBinaryExprStorage *
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
auto *result = allocator.allocate<SDBMBinaryExprStorage>();
result->lhs = std::get<1>(key);
result->rhs = std::get<2>(key);
result->dialect = result->lhs.getDialect();
result->kind = static_cast<SDBMExprKind>(std::get<0>(key));
return result;
}
SDBMDirectExpr lhs;
SDBMConstantExpr rhs;
};
// Storage class for SDBM difference expressions.
struct SDBMDiffExprStorage : public SDBMExprStorage {
using KeyTy = std::pair<SDBMDirectExpr, SDBMTermExpr>;
bool operator==(const KeyTy &key) const {
return std::get<0>(key) == lhs && std::get<1>(key) == rhs;
}
static SDBMDiffExprStorage *
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
auto *result = allocator.allocate<SDBMDiffExprStorage>();
result->lhs = std::get<0>(key);
result->rhs = std::get<1>(key);
result->dialect = result->lhs.getDialect();
result->kind = SDBMExprKind::Diff;
return result;
}
SDBMDirectExpr lhs;
SDBMTermExpr rhs;
};
// Storage class for SDBM constant expressions.
struct SDBMConstantExprStorage : public SDBMExprStorage {
using KeyTy = int64_t;
bool operator==(const KeyTy &key) const { return constant == key; }
static SDBMConstantExprStorage *
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
auto *result = allocator.allocate<SDBMConstantExprStorage>();
result->constant = key;
result->kind = SDBMExprKind::Constant;
return result;
}
int64_t constant;
};
// Storage class for SDBM dimension and symbol expressions.
struct SDBMTermExprStorage : public SDBMExprStorage {
using KeyTy = std::pair<unsigned, unsigned>;
bool operator==(const KeyTy &key) const {
return kind == static_cast<SDBMExprKind>(key.first) &&
position == key.second;
}
static SDBMTermExprStorage *
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
auto *result = allocator.allocate<SDBMTermExprStorage>();
result->kind = static_cast<SDBMExprKind>(key.first);
result->position = key.second;
return result;
}
unsigned position;
};
// Storage class for SDBM negation expressions.
struct SDBMNegExprStorage : public SDBMExprStorage {
using KeyTy = SDBMDirectExpr;
bool operator==(const KeyTy &key) const { return key == expr; }
static SDBMNegExprStorage *
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
auto *result = allocator.allocate<SDBMNegExprStorage>();
result->expr = key;
result->dialect = key.getDialect();
result->kind = SDBMExprKind::Neg;
return result;
}
SDBMDirectExpr expr;
};
} // end namespace detail
} // end namespace mlir
#endif // MLIR_IR_SDBMEXPRDETAIL_H

View File

@ -1,5 +1,4 @@
add_subdirectory(CAPI)
add_subdirectory(SDBM)
add_subdirectory(lib)
if(MLIR_ENABLE_BINDINGS_PYTHON)
@ -75,7 +74,6 @@ set(MLIR_TEST_DEPENDS
mlir-lsp-server
mlir-opt
mlir-reduce
mlir-sdbm-api-test
mlir-tblgen
mlir-translate
mlir_runner_utils

View File

@ -1,19 +0,0 @@
set(LLVM_LINK_COMPONENTS
Core
Support
)
add_llvm_executable(mlir-sdbm-api-test
sdbm-api-test.cpp
)
llvm_update_compile_flags(mlir-sdbm-api-test)
target_link_libraries(mlir-sdbm-api-test
PRIVATE
MLIRIR
MLIRSDBM
MLIRSupport
)
target_include_directories(mlir-sdbm-api-test PRIVATE ..)

View File

@ -1 +0,0 @@
config.suffixes.add('.cpp')

View File

@ -1,201 +0,0 @@
//===- sdbm-api-test.cpp - Tests for SDBM expression APIs -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// RUN: mlir-sdbm-api-test | FileCheck %s
#include "mlir/Dialect/SDBM/SDBM.h"
#include "mlir/Dialect/SDBM/SDBMDialect.h"
#include "mlir/Dialect/SDBM/SDBMExpr.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"
#include "APITest.h"
using namespace mlir;
static MLIRContext *ctx() {
static thread_local MLIRContext context;
static thread_local bool once =
(context.getOrLoadDialect<SDBMDialect>(), true);
(void)once;
return &context;
}
static SDBMDialect *dialect() {
static thread_local SDBMDialect *d = nullptr;
if (!d) {
d = ctx()->getOrLoadDialect<SDBMDialect>();
}
return d;
}
static SDBMExpr dim(unsigned pos) { return SDBMDimExpr::get(dialect(), pos); }
static SDBMExpr symb(unsigned pos) {
return SDBMSymbolExpr::get(dialect(), pos);
}
namespace {
using namespace mlir::ops_assertions;
TEST_FUNC(SDBM_SingleConstraint) {
// Build an SDBM defined by
// d0 - 3 <= 0 <=> d0 <= 3.
auto sdbm = SDBM::get(dim(0) - 3, llvm::None);
// CHECK: cst d0
// CHECK-NEXT: cst inf 3
// CHECK-NEXT: d0 inf inf
sdbm.print(llvm::outs());
}
TEST_FUNC(SDBM_Equality) {
// Build an SDBM defined by
//
// d0 - d1 - 3 = 0
// <=> {d0 - d1 - 3 <= 0 and d0 - d1 - 3 >= 0}
// <=> {d0 - d1 <= 3 and d1 - d0 <= -3}.
auto sdbm = SDBM::get(llvm::None, dim(0) - dim(1) - 3);
// CHECK: cst d0 d1
// CHECK-NEXT: cst inf inf inf
// CHECK-NEXT: d0 inf inf -3
// CHECK-NEXT: d1 inf 3 inf
sdbm.print(llvm::outs());
}
TEST_FUNC(SDBM_TrivialSimplification) {
// Build an SDBM defined by
//
// d0 - 3 <= 0 <=> d0 <= 3
// d0 - 5 <= 0 <=> d0 <= 5
//
// which should get simplified on construction to only the former.
auto sdbm = SDBM::get({dim(0) - 3, dim(0) - 5}, llvm::None);
// CHECK: cst d0
// CHECK-NEXT: cst inf 3
// CHECK-NEXT: d0 inf inf
sdbm.print(llvm::outs());
}
TEST_FUNC(SDBM_StripeInducedIneqs) {
// Build an SDBM defined by d1 = d0 # 3, which induces the constraints
//
// d1 - d0 <= 0
// d0 - d1 <= 3 - 1 = 2
auto sdbm = SDBM::get(llvm::None, dim(1) - stripe(dim(0), 3));
// CHECK: cst d0 d1
// CHECK-NEXT: cst inf inf inf
// CHECK-NEXT: d0 inf inf 0
// CHECK-NEXT: d1 inf 2 0
// CHECK-NEXT: d1 = d0 # 3
sdbm.print(llvm::outs());
}
TEST_FUNC(SDBM_StripeTemporaries) {
// Build an SDBM defined by d0 # 3 <= 0, which creates a temporary
// t0 = d0 # 3 leading to a constraint t0 <= 0 and the stripe-induced
// constraints
//
// t0 - d0 <= 0
// d0 - t0 <= 3 - 1 = 2
auto sdbm = SDBM::get(stripe(dim(0), 3), llvm::None);
// CHECK: cst d0 t0
// CHECK-NEXT: cst inf inf 0
// CHECK-NEXT: d0 inf inf 0
// CHECK-NEXT: t0 inf 2 inf
// CHECK-NEXT: t0 = d0 # 3
sdbm.print(llvm::outs());
}
TEST_FUNC(SDBM_ElideInducedInequalities) {
// Build an SDBM defined by a single stripe equality d0 = s0 # 3 and make sure
// the induced inequalities are not present after converting the SDBM back
// into lists of expressions.
auto sdbm = SDBM::get(llvm::None, {dim(0) - stripe(symb(0), 3)});
SmallVector<SDBMExpr, 4> eqs, ineqs;
sdbm.getSDBMExpressions(dialect(), ineqs, eqs);
// CHECK-EMPTY:
for (auto ineq : ineqs)
ineq.print(llvm::outs() << '\n');
llvm::outs() << "\n";
// CHECK: d0 - s0 # 3
// CHECK-EMPTY:
for (auto eq : eqs)
eq.print(llvm::outs() << '\n');
llvm::outs() << "\n\n";
}
TEST_FUNC(SDBM_StripeTightening) {
// Build an SDBM defined by
//
// d0 = s0 # 3 # 5
// s0 # 3 # 5 - d1 + 42 = 0
// s0 # 3 - d0 <= 2
//
// where the last inequality is tighter than that induced by the first stripe
// equality (s0 # 3 - d0 <= 5 - 1 = 4). Check that the conversion from SDBM
// back to the lists of constraints conserves both the stripe equality and the
// tighter inequality.
auto s = stripe(stripe(symb(0), 3), 5);
auto tight = stripe(symb(0), 3) - dim(0) - 2;
auto sdbm = SDBM::get({tight}, {s - dim(0), s - dim(1) + 42});
SmallVector<SDBMExpr, 4> eqs, ineqs;
sdbm.getSDBMExpressions(dialect(), ineqs, eqs);
// CHECK: s0 # 3 + -2 - d0
// CHECK-EMPTY:
for (auto ineq : ineqs)
ineq.print(llvm::outs() << '\n');
llvm::outs() << "\n";
// CHECK-DAG: d1 + -42 - d0
// CHECK-DAG: d0 - s0 # 3 # 5
for (auto eq : eqs)
eq.print(llvm::outs() << '\n');
llvm::outs() << "\n\n";
}
TEST_FUNC(SDBM_StripeTransitive) {
// Build an SDBM defined by
//
// d0 = d1 # 3
// d0 = d2 # 7
//
// where the same dimension is declared equal to two stripe expressions over
// different variables. This is practically handled by introducing a
// temporary variable for the second stripe expression and adding an equality
// constraint between this variable and the original dimension variable.
auto sdbm = SDBM::get(
llvm::None, {stripe(dim(1), 3) - dim(0), stripe(dim(2), 7) - dim(0)});
// CHECK: cst d0 d1 d2 t0
// CHECK-NEXT: cst inf inf inf inf inf
// CHECK-NEXT: d0 inf 0 2 inf 0
// CHECK-NEXT: d1 inf 0 inf inf inf
// CHECK-NEXT: d2 inf inf inf inf 0
// CHECK-NEXT: t0 inf 0 inf 6 inf
// CHECK-NEXT: t0 = d2 # 7
// CHECK-NEXT: d0 = d1 # 3
sdbm.print(llvm::outs());
}
} // end namespace
int main() {
RUN_TESTS();
return 0;
}

View File

@ -65,7 +65,6 @@ tools = [
'mlir-linalg-ods-gen',
'mlir-linalg-ods-yaml-gen',
'mlir-reduce',
'mlir-sdbm-api-test',
]
# The following tools are optional

View File

@ -21,7 +21,6 @@
// CHECK-NEXT: quant
// CHECK-NEXT: rocdl
// CHECK-NEXT: scf
// CHECK-NEXT: sdbm
// CHECK-NEXT: shape
// CHECK-NEXT: sparse_tensor
// CHECK-NEXT: spv

View File

@ -11,5 +11,4 @@ add_subdirectory(Interfaces)
add_subdirectory(IR)
add_subdirectory(Pass)
add_subdirectory(Rewrite)
add_subdirectory(SDBM)
add_subdirectory(TableGen)

View File

@ -1,7 +0,0 @@
add_mlir_unittest(MLIRSDBMTests
SDBMTest.cpp
)
target_link_libraries(MLIRSDBMTests
PRIVATE
MLIRSDBM
)

View File

@ -1,449 +0,0 @@
//===- SDBMTest.cpp - SDBM expression unit tests --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SDBM/SDBM.h"
#include "mlir/Dialect/SDBM/SDBMDialect.h"
#include "mlir/Dialect/SDBM/SDBMExpr.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/MLIRContext.h"
#include "gtest/gtest.h"
#include "llvm/ADT/DenseSet.h"
using namespace mlir;
static MLIRContext *ctx() {
static thread_local MLIRContext context;
context.getOrLoadDialect<SDBMDialect>();
return &context;
}
static SDBMDialect *dialect() {
static thread_local SDBMDialect *d = nullptr;
if (!d) {
d = ctx()->getOrLoadDialect<SDBMDialect>();
}
return d;
}
static SDBMExpr dim(unsigned pos) { return SDBMDimExpr::get(dialect(), pos); }
static SDBMExpr symb(unsigned pos) {
return SDBMSymbolExpr::get(dialect(), pos);
}
namespace {
using namespace mlir::ops_assertions;
TEST(SDBMOperators, Add) {
auto expr = dim(0) + 42;
auto sumExpr = expr.dyn_cast<SDBMSumExpr>();
ASSERT_TRUE(sumExpr);
EXPECT_EQ(sumExpr.getLHS(), dim(0));
EXPECT_EQ(sumExpr.getRHS().getValue(), 42);
}
TEST(SDBMOperators, AddFolding) {
auto constant = SDBMConstantExpr::get(dialect(), 2) + 42;
auto constantExpr = constant.dyn_cast<SDBMConstantExpr>();
ASSERT_TRUE(constantExpr);
EXPECT_EQ(constantExpr.getValue(), 44);
auto expr = (dim(0) + 10) + 32;
auto sumExpr = expr.dyn_cast<SDBMSumExpr>();
ASSERT_TRUE(sumExpr);
EXPECT_EQ(sumExpr.getRHS().getValue(), 42);
expr = dim(0) + SDBMNegExpr::get(SDBMDimExpr::get(dialect(), 1));
auto diffExpr = expr.dyn_cast<SDBMDiffExpr>();
ASSERT_TRUE(diffExpr);
EXPECT_EQ(diffExpr.getLHS(), dim(0));
EXPECT_EQ(diffExpr.getRHS(), dim(1));
auto inverted = SDBMNegExpr::get(SDBMDimExpr::get(dialect(), 1)) + dim(0);
EXPECT_EQ(inverted, expr);
// Check that opposite values cancel each other, and that we elide the zero
// constant.
expr = dim(0) + 42;
auto onlyDim = expr - 42;
EXPECT_EQ(onlyDim, dim(0));
// Check that we can sink a constant under a negation.
expr = -(dim(0) + 2);
auto negatedSum = (expr + 10).dyn_cast<SDBMNegExpr>();
ASSERT_TRUE(negatedSum);
auto sum = negatedSum.getVar().dyn_cast<SDBMSumExpr>();
ASSERT_TRUE(sum);
EXPECT_EQ(sum.getRHS().getValue(), -8);
// Sum with zero is the same as the original expression.
EXPECT_EQ(dim(0) + 0, dim(0));
// Sum of opposite differences is zero.
auto diffOfDiffs =
((dim(0) - dim(1)) + (dim(1) - dim(0))).dyn_cast<SDBMConstantExpr>();
EXPECT_EQ(diffOfDiffs.getValue(), 0);
}
TEST(SDBMOperators, AddNegativeTerms) {
const int64_t A = 7;
const int64_t B = -5;
auto x = SDBMDimExpr::get(dialect(), 0);
auto y = SDBMDimExpr::get(dialect(), 1);
// Check the simplification patterns in addition where one of the variables is
// cancelled out and the result remains an SDBM.
EXPECT_EQ(-(x + A) + ((x + B) - y), -(y + (A - B)));
EXPECT_EQ((x + A) + ((y + B) - x), (y + B) + A);
EXPECT_EQ(((x + A) - y) + (-(x + B)), -(y + (B - A)));
EXPECT_EQ(((x + A) - y) + (y + B), (x + A) + B);
}
TEST(SDBMOperators, Diff) {
auto expr = dim(0) - dim(1);
auto diffExpr = expr.dyn_cast<SDBMDiffExpr>();
ASSERT_TRUE(diffExpr);
EXPECT_EQ(diffExpr.getLHS(), dim(0));
EXPECT_EQ(diffExpr.getRHS(), dim(1));
}
TEST(SDBMOperators, DiffFolding) {
auto constant = SDBMConstantExpr::get(dialect(), 10) - 3;
auto constantExpr = constant.dyn_cast<SDBMConstantExpr>();
ASSERT_TRUE(constantExpr);
EXPECT_EQ(constantExpr.getValue(), 7);
auto expr = dim(0) - 3;
auto sumExpr = expr.dyn_cast<SDBMSumExpr>();
ASSERT_TRUE(sumExpr);
EXPECT_EQ(sumExpr.getRHS().getValue(), -3);
auto zero = dim(0) - dim(0);
constantExpr = zero.dyn_cast<SDBMConstantExpr>();
ASSERT_TRUE(constantExpr);
EXPECT_EQ(constantExpr.getValue(), 0);
// Check that the constant terms in difference-of-sums are folded.
// (d0 - 3) - (d1 - 5) = (d0 + 2) - d1
auto diffOfSums = ((dim(0) - 3) - (dim(1) - 5)).dyn_cast<SDBMDiffExpr>();
ASSERT_TRUE(diffOfSums);
auto lhs = diffOfSums.getLHS().dyn_cast<SDBMSumExpr>();
ASSERT_TRUE(lhs);
EXPECT_EQ(lhs.getLHS(), dim(0));
EXPECT_EQ(lhs.getRHS().getValue(), 2);
EXPECT_EQ(diffOfSums.getRHS(), dim(1));
// Check that identical dimensions with opposite signs cancel each other.
auto cstOnly = ((dim(0) + 42) - dim(0)).dyn_cast<SDBMConstantExpr>();
ASSERT_TRUE(cstOnly);
EXPECT_EQ(cstOnly.getValue(), 42);
// Check that identical terms in sum of diffs cancel out.
auto dimOnly = (-dim(0) + (dim(0) - dim(1)));
EXPECT_EQ(dimOnly, -dim(1));
dimOnly = (dim(0) - dim(1)) + (-dim(0));
EXPECT_EQ(dimOnly, -dim(1));
dimOnly = (dim(0) - dim(1)) + dim(1);
EXPECT_EQ(dimOnly, dim(0));
dimOnly = dim(0) + (dim(1) - dim(0));
EXPECT_EQ(dimOnly, dim(1));
// Top-level zero constant is fine.
cstOnly = (-symb(1) + symb(1)).dyn_cast<SDBMConstantExpr>();
ASSERT_TRUE(cstOnly);
EXPECT_EQ(cstOnly.getValue(), 0);
}
TEST(SDBMOperators, Negate) {
auto sum = dim(0) + 3;
auto negated = (-sum).dyn_cast<SDBMNegExpr>();
ASSERT_TRUE(negated);
EXPECT_EQ(negated.getVar(), sum);
}
TEST(SDBMOperators, Stripe) {
auto expr = stripe(dim(0), 3);
auto stripeExpr = expr.dyn_cast<SDBMStripeExpr>();
ASSERT_TRUE(stripeExpr);
EXPECT_EQ(stripeExpr.getLHS(), dim(0));
EXPECT_EQ(stripeExpr.getStripeFactor().getValue(), 3);
}
TEST(SDBM, RoundTripEqs) {
// Build an SDBM defined by
//
// d0 = s0 # 3 # 5
// s0 # 3 # 5 - d1 + 42 = 0
//
// and perform a double round-trip between the "list of equalities" and SDBM
// representation. After the first round-trip, the equalities may be
// different due to simplification or equivalent substitutions (e.g., the
// second equality may become d0 - d1 + 42 = 0). However, there should not
// be any further simplification after the second round-trip,
// Build the SDBM from a pair of equalities and extract back the lists of
// inequalities and equalities. Check that all equalities are properly
// detected and none of them decayed into inequalities.
auto s = stripe(stripe(symb(0), 3), 5);
auto sdbm = SDBM::get(llvm::None, {s - dim(0), s - dim(1) + 42});
SmallVector<SDBMExpr, 4> eqs, ineqs;
sdbm.getSDBMExpressions(dialect(), ineqs, eqs);
ASSERT_TRUE(ineqs.empty());
// Do the second round-trip.
auto sdbm2 = SDBM::get(llvm::None, eqs);
SmallVector<SDBMExpr, 4> eqs2, ineqs2;
sdbm2.getSDBMExpressions(dialect(), ineqs2, eqs2);
ASSERT_EQ(eqs.size(), eqs2.size());
// Check that the sets of equalities are equal, their order is not relevant.
llvm::DenseSet<SDBMExpr> eqSet, eq2Set;
eqSet.insert(eqs.begin(), eqs.end());
eq2Set.insert(eqs2.begin(), eqs2.end());
EXPECT_EQ(eqSet, eq2Set);
}
TEST(SDBMExpr, Constant) {
// We can create constants and query them.
auto expr = SDBMConstantExpr::get(dialect(), 42);
EXPECT_EQ(expr.getValue(), 42);
// Two separately created constants with identical values are trivially equal.
auto expr2 = SDBMConstantExpr::get(dialect(), 42);
EXPECT_EQ(expr, expr2);
// Hierarchy is okay.
auto generic = static_cast<SDBMExpr>(expr);
EXPECT_TRUE(generic.isa<SDBMConstantExpr>());
}
TEST(SDBMExpr, Dim) {
// We can create dimension expressions and query them.
auto expr = SDBMDimExpr::get(dialect(), 0);
EXPECT_EQ(expr.getPosition(), 0u);
// Two separately created dimensions with the same position are trivially
// equal.
auto expr2 = SDBMDimExpr::get(dialect(), 0);
EXPECT_EQ(expr, expr2);
// Hierarchy is okay.
auto generic = static_cast<SDBMExpr>(expr);
EXPECT_TRUE(generic.isa<SDBMDimExpr>());
EXPECT_TRUE(generic.isa<SDBMInputExpr>());
EXPECT_TRUE(generic.isa<SDBMTermExpr>());
EXPECT_TRUE(generic.isa<SDBMDirectExpr>());
EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
// Dimensions are not Symbols.
auto symbol = SDBMSymbolExpr::get(dialect(), 0);
EXPECT_NE(expr, symbol);
EXPECT_FALSE(expr.isa<SDBMSymbolExpr>());
}
TEST(SDBMExpr, Symbol) {
// We can create symbol expressions and query them.
auto expr = SDBMSymbolExpr::get(dialect(), 0);
EXPECT_EQ(expr.getPosition(), 0u);
// Two separately created symbols with the same position are trivially equal.
auto expr2 = SDBMSymbolExpr::get(dialect(), 0);
EXPECT_EQ(expr, expr2);
// Hierarchy is okay.
auto generic = static_cast<SDBMExpr>(expr);
EXPECT_TRUE(generic.isa<SDBMSymbolExpr>());
EXPECT_TRUE(generic.isa<SDBMInputExpr>());
EXPECT_TRUE(generic.isa<SDBMTermExpr>());
EXPECT_TRUE(generic.isa<SDBMDirectExpr>());
EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
// Dimensions are not Symbols.
auto symbol = SDBMDimExpr::get(dialect(), 0);
EXPECT_NE(expr, symbol);
EXPECT_FALSE(expr.isa<SDBMDimExpr>());
}
TEST(SDBMExpr, Stripe) {
auto cst2 = SDBMConstantExpr::get(dialect(), 2);
auto cst0 = SDBMConstantExpr::get(dialect(), 0);
auto var = SDBMSymbolExpr::get(dialect(), 0);
// We can create stripe expressions and query them.
auto expr = SDBMStripeExpr::get(var, cst2);
EXPECT_EQ(expr.getLHS(), var);
EXPECT_EQ(expr.getStripeFactor(), cst2);
// Two separately created stripe expressions with the same LHS and RHS are
// trivially equal.
auto expr2 = SDBMStripeExpr::get(SDBMSymbolExpr::get(dialect(), 0), cst2);
EXPECT_EQ(expr, expr2);
// Stripes can be nested.
SDBMStripeExpr::get(expr, SDBMConstantExpr::get(dialect(), 4));
// Non-positive stripe factors are not allowed.
EXPECT_DEATH(SDBMStripeExpr::get(var, cst0), "non-positive");
// Stripes can have sums on the LHS.
SDBMStripeExpr::get(SDBMSumExpr::get(var, cst2), cst2);
// Hierarchy is okay.
auto generic = static_cast<SDBMExpr>(expr);
EXPECT_TRUE(generic.isa<SDBMStripeExpr>());
EXPECT_TRUE(generic.isa<SDBMTermExpr>());
EXPECT_TRUE(generic.isa<SDBMDirectExpr>());
EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
}
TEST(SDBMExpr, Neg) {
auto cst2 = SDBMConstantExpr::get(dialect(), 2);
auto var = SDBMSymbolExpr::get(dialect(), 0);
auto stripe = SDBMStripeExpr::get(var, cst2);
// We can create negation expressions and query them.
auto expr = SDBMNegExpr::get(var);
EXPECT_EQ(expr.getVar(), var);
auto expr2 = SDBMNegExpr::get(stripe);
EXPECT_EQ(expr2.getVar(), stripe);
// Neg expressions are trivially comparable.
EXPECT_EQ(expr, SDBMNegExpr::get(var));
// Hierarchy is okay.
auto generic = static_cast<SDBMExpr>(expr);
EXPECT_TRUE(generic.isa<SDBMNegExpr>());
EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
}
TEST(SDBMExpr, Sum) {
auto cst2 = SDBMConstantExpr::get(dialect(), 2);
auto var = SDBMSymbolExpr::get(dialect(), 0);
auto stripe = SDBMStripeExpr::get(var, cst2);
// We can create sum expressions and query them.
auto expr = SDBMSumExpr::get(var, cst2);
EXPECT_EQ(expr.getLHS(), var);
EXPECT_EQ(expr.getRHS(), cst2);
auto expr2 = SDBMSumExpr::get(stripe, cst2);
EXPECT_EQ(expr2.getLHS(), stripe);
EXPECT_EQ(expr2.getRHS(), cst2);
// Sum expressions are trivially comparable.
EXPECT_EQ(expr, SDBMSumExpr::get(var, cst2));
// Hierarchy is okay.
auto generic = static_cast<SDBMExpr>(expr);
EXPECT_TRUE(generic.isa<SDBMSumExpr>());
EXPECT_TRUE(generic.isa<SDBMDirectExpr>());
EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
}
TEST(SDBMExpr, Diff) {
auto cst2 = SDBMConstantExpr::get(dialect(), 2);
auto var = SDBMSymbolExpr::get(dialect(), 0);
auto stripe = SDBMStripeExpr::get(var, cst2);
// We can create sum expressions and query them.
auto expr = SDBMDiffExpr::get(var, stripe);
EXPECT_EQ(expr.getLHS(), var);
EXPECT_EQ(expr.getRHS(), stripe);
auto expr2 = SDBMDiffExpr::get(stripe, var);
EXPECT_EQ(expr2.getLHS(), stripe);
EXPECT_EQ(expr2.getRHS(), var);
// Sum expressions are trivially comparable.
EXPECT_EQ(expr, SDBMDiffExpr::get(var, stripe));
// Hierarchy is okay.
auto generic = static_cast<SDBMExpr>(expr);
EXPECT_TRUE(generic.isa<SDBMDiffExpr>());
EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
}
TEST(SDBMExpr, AffineRoundTrip) {
// Build an expression (s0 - s0 # 2)
auto cst2 = SDBMConstantExpr::get(dialect(), 2);
auto var = SDBMSymbolExpr::get(dialect(), 0);
auto stripe = SDBMStripeExpr::get(var, cst2);
auto expr = SDBMDiffExpr::get(var, stripe);
// Check that it can be converted to AffineExpr and back, i.e. stripe
// detection works correctly.
Optional<SDBMExpr> roundtripped =
SDBMExpr::tryConvertAffineExpr(expr.getAsAffineExpr());
ASSERT_TRUE(roundtripped.hasValue());
EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(expr));
// Check that (s0 # 2 # 5) can be converted to AffineExpr, i.e. stripe
// detection supports nested expressions.
auto cst5 = SDBMConstantExpr::get(dialect(), 5);
auto outerStripe = SDBMStripeExpr::get(stripe, cst5);
roundtripped = SDBMExpr::tryConvertAffineExpr(outerStripe.getAsAffineExpr());
ASSERT_TRUE(roundtripped.hasValue());
EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(outerStripe));
// Check that ((s0 + 2) # 5) can be round-tripped through AffineExpr, i.e.
// stripe detection supports sum expressions.
auto inner = SDBMSumExpr::get(var, cst2);
auto stripeSum = SDBMStripeExpr::get(inner, cst5);
roundtripped = SDBMExpr::tryConvertAffineExpr(stripeSum.getAsAffineExpr());
ASSERT_TRUE(roundtripped.hasValue());
EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(stripeSum));
// Check that (s0 # 2 # 5 - s0 # 2) + 2 can be converted as an example of a
// deeper expression tree.
auto sum = SDBMSumExpr::get(outerStripe, cst2);
auto diff = SDBMDiffExpr::get(sum, stripe);
roundtripped = SDBMExpr::tryConvertAffineExpr(diff.getAsAffineExpr());
ASSERT_TRUE(roundtripped.hasValue());
EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(diff));
// Check a nested stripe-sum combination.
auto cst7 = SDBMConstantExpr::get(dialect(), 7);
auto nestedStripe =
SDBMStripeExpr::get(SDBMSumExpr::get(stripeSum, cst2), cst7);
diff = SDBMDiffExpr::get(nestedStripe, stripe);
roundtripped = SDBMExpr::tryConvertAffineExpr(diff.getAsAffineExpr());
ASSERT_TRUE(roundtripped.hasValue());
EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(diff));
}
TEST(SDBMExpr, MatchStripeMulPattern) {
// Make sure conversion from AffineExpr recognizes multiplicative stripe
// pattern (x floordiv B) * B == x # B.
auto cst = getAffineConstantExpr(42, ctx());
auto dim = getAffineDimExpr(0, ctx());
auto floor = dim.floorDiv(cst);
auto mul = cst * floor;
Optional<SDBMExpr> converted = SDBMStripeExpr::tryConvertAffineExpr(mul);
ASSERT_TRUE(converted.hasValue());
EXPECT_TRUE(converted->isa<SDBMStripeExpr>());
}
TEST(SDBMExpr, NonSDBM) {
auto d0 = getAffineDimExpr(0, ctx());
auto d1 = getAffineDimExpr(1, ctx());
auto sum = d0 + d1;
auto c2 = getAffineConstantExpr(2, ctx());
auto prod = d0 * c2;
auto ceildiv = d1.ceilDiv(c2);
// The following are not valid SDBM expressions:
// - a sum of two variables
EXPECT_FALSE(SDBMExpr::tryConvertAffineExpr(sum).hasValue());
// - a variable with coefficient other than 1 or -1
EXPECT_FALSE(SDBMExpr::tryConvertAffineExpr(prod).hasValue());
// - a ceildiv expression
EXPECT_FALSE(SDBMExpr::tryConvertAffineExpr(ceildiv).hasValue());
}
} // end namespace