forked from OSchip/llvm-project
[mlir] Remove SDBM
This data structure and algorithm collection is no longer in use. Reviewed By: bondhugula Differential Revision: https://reviews.llvm.org/D105102
This commit is contained in:
parent
47215e1c62
commit
355216380b
|
@ -1,197 +0,0 @@
|
|||
//===- SDBM.h - MLIR SDBM declaration ---------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// A striped difference-bound matrix (SDBM) is a set in Z^N (or R^N) defined
|
||||
// as {(x_1, ... x_n) | f(x_1, ... x_n) >= 0} where f is an SDBM expression.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_SDBM_SDBM_H
|
||||
#define MLIR_DIALECT_SDBM_SDBM_H
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class MLIRContext;
|
||||
class SDBMDialect;
|
||||
class SDBMExpr;
|
||||
class SDBMTermExpr;
|
||||
|
||||
/// A utility class for SDBM to represent an integer with potentially infinite
|
||||
/// positive value. This uses the largest value of int64_t to represent infinity
|
||||
/// and redefines the arithmetic operators so that the infinity "saturates":
|
||||
/// inf + x = inf,
|
||||
/// inf - x = inf.
|
||||
/// If a sum of two finite values reaches the largest value of int64_t, the
|
||||
/// behavior of IntInfty is undefined (in practice, it asserts), similarly to
|
||||
/// regular signed integer overflow.
|
||||
class IntInfty {
|
||||
public:
|
||||
constexpr static int64_t infty = std::numeric_limits<int64_t>::max();
|
||||
|
||||
/*implicit*/ IntInfty(int64_t v) : value(v) {}
|
||||
|
||||
IntInfty &operator=(int64_t v) {
|
||||
value = v;
|
||||
return *this;
|
||||
}
|
||||
|
||||
static IntInfty infinity() { return IntInfty(infty); }
|
||||
|
||||
int64_t getValue() const { return value; }
|
||||
explicit operator int64_t() const { return value; }
|
||||
|
||||
bool isFinite() { return value != infty; }
|
||||
|
||||
private:
|
||||
int64_t value;
|
||||
};
|
||||
|
||||
inline IntInfty operator+(IntInfty lhs, IntInfty rhs) {
|
||||
if (!lhs.isFinite() || !rhs.isFinite())
|
||||
return IntInfty::infty;
|
||||
|
||||
// Check for overflows, treating the sum of two values adding up to INT_MAX as
|
||||
// overflow. Convert values to unsigned to get an extra bit and avoid the
|
||||
// undefined behavior of signed integer overflows.
|
||||
assert((lhs.getValue() <= 0 || rhs.getValue() <= 0 ||
|
||||
static_cast<uint64_t>(lhs.getValue()) +
|
||||
static_cast<uint64_t>(rhs.getValue()) <
|
||||
static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) &&
|
||||
"IntInfty overflow");
|
||||
// Check for underflows by converting values to unsigned to avoid undefined
|
||||
// behavior of signed integers perform the addition (bitwise result is same
|
||||
// because numbers are required to be two's complement in C++) and check if
|
||||
// the sign bit remains negative.
|
||||
assert((lhs.getValue() >= 0 || rhs.getValue() >= 0 ||
|
||||
((static_cast<uint64_t>(lhs.getValue()) +
|
||||
static_cast<uint64_t>(rhs.getValue())) >>
|
||||
63) == 1) &&
|
||||
"IntInfty underflow");
|
||||
|
||||
return lhs.getValue() + rhs.getValue();
|
||||
}
|
||||
|
||||
inline bool operator<(IntInfty lhs, IntInfty rhs) {
|
||||
return lhs.getValue() < rhs.getValue();
|
||||
}
|
||||
|
||||
inline bool operator<=(IntInfty lhs, IntInfty rhs) {
|
||||
return lhs.getValue() <= rhs.getValue();
|
||||
}
|
||||
|
||||
inline bool operator==(IntInfty lhs, IntInfty rhs) {
|
||||
return lhs.getValue() == rhs.getValue();
|
||||
}
|
||||
|
||||
inline bool operator!=(IntInfty lhs, IntInfty rhs) { return !(lhs == rhs); }
|
||||
|
||||
/// Striped difference-bound matrix is a representation of an integer set bound
|
||||
/// by a system of SDBMExprs interpreted as inequalities "expr <= 0".
|
||||
class SDBM {
|
||||
public:
|
||||
/// Obtain an SDBM from a list of SDBM expressions treated as inequalities and
|
||||
/// equalities with zero.
|
||||
static SDBM get(ArrayRef<SDBMExpr> inequalities,
|
||||
ArrayRef<SDBMExpr> equalities);
|
||||
|
||||
void getSDBMExpressions(SDBMDialect *dialect,
|
||||
SmallVectorImpl<SDBMExpr> &inequalities,
|
||||
SmallVectorImpl<SDBMExpr> &equalities);
|
||||
|
||||
void print(raw_ostream &os);
|
||||
void dump();
|
||||
|
||||
IntInfty operator()(int i, int j) { return at(i, j); }
|
||||
|
||||
private:
|
||||
/// Get the given element of the difference bounds matrix. First index
|
||||
/// corresponds to the negative term of the difference, second index
|
||||
/// corresponds to the positive term of the difference.
|
||||
IntInfty &at(int i, int j) { return matrix[i * getNumVariables() + j]; }
|
||||
|
||||
/// Populate `inequalities` and `equalities` based on the values at(row,col)
|
||||
/// and at(col,row) of the DBM. Depending on the values being finite and
|
||||
/// being subsumed by stripe expressions, this may or may not add elements to
|
||||
/// the lists of equalities and inequalities.
|
||||
void convertDBMElement(unsigned row, unsigned col, SDBMTermExpr rowExpr,
|
||||
SDBMTermExpr colExpr,
|
||||
SmallVectorImpl<SDBMExpr> &inequalities,
|
||||
SmallVectorImpl<SDBMExpr> &equalities);
|
||||
|
||||
/// Populate `inequalities` based on the value at(pos,pos) of the DBM. Only
|
||||
/// adds new inequalities if the inequality is not trivially true.
|
||||
void convertDBMDiagonalElement(unsigned pos, SDBMTermExpr expr,
|
||||
SmallVectorImpl<SDBMExpr> &inequalities);
|
||||
|
||||
/// Get the total number of elements in the matrix.
|
||||
unsigned getNumVariables() const {
|
||||
return 1 + numDims + numSymbols + numTemporaries;
|
||||
}
|
||||
|
||||
/// Get the position in the matrix that corresponds to the given dimension.
|
||||
unsigned getDimPosition(unsigned position) const { return 1 + position; }
|
||||
|
||||
/// Get the position in the matrix that corresponds to the given symbol.
|
||||
unsigned getSymbolPosition(unsigned position) const {
|
||||
return 1 + numDims + position;
|
||||
}
|
||||
|
||||
/// Get the position in the matrix that corresponds to the given temporary.
|
||||
unsigned getTemporaryPosition(unsigned position) const {
|
||||
return 1 + numDims + numSymbols + position;
|
||||
}
|
||||
|
||||
/// Number of dimensions in the system,
|
||||
unsigned numDims;
|
||||
/// Number of symbols in the system.
|
||||
unsigned numSymbols;
|
||||
/// Number of temporary variables in the system.
|
||||
unsigned numTemporaries;
|
||||
|
||||
/// Difference bounds matrix, stored as a linearized row-major vector.
|
||||
/// Each value in this matrix corresponds to an inequality
|
||||
///
|
||||
/// v@col - v@row <= at(row, col)
|
||||
///
|
||||
/// where v@col and v@row are the variables that correspond to the linearized
|
||||
/// position in the matrix. The positions correspond to
|
||||
///
|
||||
/// - constant 0 (producing constraints v@col <= X and -v@row <= Y);
|
||||
/// - SDBM expression dimensions (d0, d1, ...);
|
||||
/// - SDBM expression symbols (s0, s1, ...);
|
||||
/// - temporary variables (t0, t1, ...).
|
||||
///
|
||||
/// Temporary variables are introduced to represent expressions that are not
|
||||
/// trivially a difference between two variables. For example, if one side of
|
||||
/// a difference expression is itself a stripe expression, it will be replaced
|
||||
/// with a temporary variable assigned equal to this expression.
|
||||
///
|
||||
/// Infinite entries in the matrix correspond correspond to an absence of a
|
||||
/// constraint:
|
||||
///
|
||||
/// v@col - v@row <= infinity
|
||||
///
|
||||
/// is trivially true. Negated values at symmetric positions in the matrix
|
||||
/// allow one to couple two inequalities into a single equality.
|
||||
std::vector<IntInfty> matrix;
|
||||
|
||||
/// The mapping between the indices of variables in the DBM and the stripe
|
||||
/// expressions they are equal to. These expressions are stored as they
|
||||
/// appeared when constructing an SDBM from a SDBMExprs, in particular no
|
||||
/// temporaries can appear in these expressions. This removes the need to
|
||||
/// iteratively substitute definitions of the temporaries in the reverse
|
||||
/// conversion.
|
||||
DenseMap<unsigned, SDBMExpr> stripeToPoint;
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_SDBM_SDBM_H
|
|
@ -1,37 +0,0 @@
|
|||
//===- SDBMDialect.h - Dialect for striped DBMs -----------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_SDBM_SDBMDIALECT_H
|
||||
#define MLIR_DIALECT_SDBM_SDBMDIALECT_H
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/Support/StorageUniquer.h"
|
||||
|
||||
namespace mlir {
|
||||
class MLIRContext;
|
||||
|
||||
class SDBMDialect : public Dialect {
|
||||
public:
|
||||
SDBMDialect(MLIRContext *context);
|
||||
|
||||
/// Since there are no other virtual methods in this derived class, override
|
||||
/// the destructor so that key methods get defined in the corresponding
|
||||
/// module.
|
||||
~SDBMDialect() override;
|
||||
|
||||
static StringRef getDialectNamespace() { return "sdbm"; }
|
||||
|
||||
/// Get the uniquer for SDBM expressions. This should not be used directly.
|
||||
StorageUniquer &getUniquer() { return uniquer; }
|
||||
|
||||
private:
|
||||
StorageUniquer uniquer;
|
||||
};
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_SDBM_SDBMDIALECT_H
|
|
@ -1,576 +0,0 @@
|
|||
//===- SDBMExpr.h - MLIR SDBM Expression ------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// A striped difference-bound matrix (SDBM) expression is a constant expression,
|
||||
// an identifier, a binary expression with constant RHS and +, stripe operators
|
||||
// or a difference expression between two identifiers.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_SDBM_SDBMEXPR_H
|
||||
#define MLIR_DIALECT_SDBM_SDBMEXPR_H
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/DenseMapInfo.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class AffineExpr;
|
||||
class MLIRContext;
|
||||
|
||||
enum class SDBMExprKind { Add, Stripe, Diff, Constant, DimId, SymbolId, Neg };
|
||||
|
||||
namespace detail {
|
||||
struct SDBMExprStorage;
|
||||
struct SDBMBinaryExprStorage;
|
||||
struct SDBMDiffExprStorage;
|
||||
struct SDBMTermExprStorage;
|
||||
struct SDBMConstantExprStorage;
|
||||
struct SDBMNegExprStorage;
|
||||
} // namespace detail
|
||||
|
||||
class SDBMConstantExpr;
|
||||
class SDBMDialect;
|
||||
class SDBMDimExpr;
|
||||
class SDBMSymbolExpr;
|
||||
class SDBMTermExpr;
|
||||
|
||||
/// Striped Difference-Bounded Matrix (SDBM) expression is a base left-hand side
|
||||
/// expression for the SDBM framework. SDBM expressions are a subset of affine
|
||||
/// expressions supporting low-complexity algorithms for the operations used in
|
||||
/// loop transformations. In particular, are supported:
|
||||
/// - constant expressions;
|
||||
/// - single variables (dimensions and symbols) with +1 or -1 coefficient;
|
||||
/// - stripe expressions: "x # C", where "x" is a single variable or another
|
||||
/// stripe expression, "#" is the stripe operator, and "C" is a constant
|
||||
/// expression; "#" is defined as x - x mod C.
|
||||
/// - sum expressions between single variable/stripe expressions and constant
|
||||
/// expressions;
|
||||
/// - difference expressions between single variable/stripe expressions.
|
||||
/// `SDBMExpr` class hierarchy provides a type-safe interface to constructing
|
||||
/// and operating on SDBM expressions. For example, it requires the LHS of a
|
||||
/// sum expression to be a single variable or a stripe expression. These
|
||||
/// restrictions are intended to force the caller to perform the necessary
|
||||
/// simplifications to stay within the SDBM domain, because SDBM expressions do
|
||||
/// not combine in more cases than they do. This choice may be reconsidered in
|
||||
/// the future.
|
||||
///
|
||||
/// SDBM expressions are grouped into the following structure
|
||||
/// - expression
|
||||
/// - varying
|
||||
/// - direct
|
||||
/// - sum <- (term, constant)
|
||||
/// - term
|
||||
/// - symbol
|
||||
/// - dimension
|
||||
/// - stripe <- (direct, constant)
|
||||
/// - negation <- (direct)
|
||||
/// - difference <- (direct, term)
|
||||
/// - constant
|
||||
/// The notation <- (...) denotes the types of subexpressions a compound
|
||||
/// expression can combine. The tree of subexpressions essentially imposes the
|
||||
/// following canonicalization rules:
|
||||
/// - constants are always folded;
|
||||
/// - constants can only appear on the RHS of an expression;
|
||||
/// - double negation must be elided;
|
||||
/// - an additive constant term is only allowed in a sum expression, and
|
||||
/// should be sunk into the nearest such expression in the tree;
|
||||
/// - zero constant expression can only appear at the top level.
|
||||
///
|
||||
/// `SDBMExpr` and derived classes are thin wrappers around a pointer owned by
|
||||
/// an MLIRContext, and should be used by-value. They are uniqued in the
|
||||
/// MLIRContext and immortal.
|
||||
class SDBMExpr {
|
||||
public:
|
||||
using ImplType = detail::SDBMExprStorage;
|
||||
SDBMExpr() : impl(nullptr) {}
|
||||
/* implicit */ SDBMExpr(ImplType *expr) : impl(expr) {}
|
||||
|
||||
/// SDBM expressions are thin wrappers around a unique'ed immutable pointer,
|
||||
/// which makes them trivially assignable and trivially copyable.
|
||||
SDBMExpr(const SDBMExpr &) = default;
|
||||
SDBMExpr &operator=(const SDBMExpr &) = default;
|
||||
|
||||
/// SDBM expressions can be compared straight-forwardly.
|
||||
bool operator==(const SDBMExpr &other) const { return impl == other.impl; }
|
||||
bool operator!=(const SDBMExpr &other) const { return !(*this == other); }
|
||||
|
||||
/// SDBM expressions are convertible to `bool`: null expressions are converted
|
||||
/// to false, non-null expressions are converted to true.
|
||||
explicit operator bool() const { return impl != nullptr; }
|
||||
bool operator!() const { return !static_cast<bool>(*this); }
|
||||
|
||||
/// Negate the given SDBM expression.
|
||||
SDBMExpr operator-();
|
||||
|
||||
/// Prints the SDBM expression.
|
||||
void print(raw_ostream &os) const;
|
||||
void dump() const;
|
||||
|
||||
/// LLVM-style casts.
|
||||
template <typename U> bool isa() const { return U::isClassFor(*this); }
|
||||
template <typename U> U dyn_cast() const {
|
||||
if (!isa<U>())
|
||||
return {};
|
||||
return U(const_cast<SDBMExpr *>(this)->impl);
|
||||
}
|
||||
template <typename U> U cast() const {
|
||||
assert(isa<U>() && "cast to incorrect subtype");
|
||||
return U(const_cast<SDBMExpr *>(this)->impl);
|
||||
}
|
||||
|
||||
/// Support for LLVM hashing.
|
||||
::llvm::hash_code hash_value() const { return ::llvm::hash_value(impl); }
|
||||
|
||||
/// Returns the kind of the SDBM expression.
|
||||
SDBMExprKind getKind() const;
|
||||
|
||||
/// Returns the MLIR context in which this expression lives.
|
||||
MLIRContext *getContext() const;
|
||||
|
||||
/// Returns the SDBM dialect instance.
|
||||
SDBMDialect *getDialect() const;
|
||||
|
||||
/// Convert the SDBM expression into an Affine expression. This always
|
||||
/// succeeds because SDBM are a subset of affine.
|
||||
AffineExpr getAsAffineExpr() const;
|
||||
|
||||
/// Try constructing an SDBM expression from the given affine expression.
|
||||
/// This may fail if the affine expression is not representable as SDBM, in
|
||||
/// which case llvm::None is returned. The conversion procedure recognizes
|
||||
/// (nested) multiplicative ((x floordiv B) * B) and additive (x - x mod B)
|
||||
/// patterns for the stripe expression.
|
||||
static Optional<SDBMExpr> tryConvertAffineExpr(AffineExpr affine);
|
||||
|
||||
protected:
|
||||
ImplType *impl;
|
||||
};
|
||||
|
||||
/// SDBM constant expression, wraps a 64-bit integer.
|
||||
class SDBMConstantExpr : public SDBMExpr {
|
||||
public:
|
||||
using ImplType = detail::SDBMConstantExprStorage;
|
||||
|
||||
using SDBMExpr::SDBMExpr;
|
||||
|
||||
/// Obtain or create a constant expression unique'ed in the given dialect
|
||||
/// (which belongs to a context).
|
||||
static SDBMConstantExpr get(SDBMDialect *dialect, int64_t value);
|
||||
|
||||
static bool isClassFor(const SDBMExpr &expr) {
|
||||
return expr.getKind() == SDBMExprKind::Constant;
|
||||
}
|
||||
|
||||
int64_t getValue() const;
|
||||
};
|
||||
|
||||
/// SDBM varying expression can be one of:
|
||||
/// - input variable expression;
|
||||
/// - stripe expression;
|
||||
/// - negation (product with -1) of either of the above.
|
||||
/// - sum of a varying and a constant expression
|
||||
/// - difference between varying expressions
|
||||
class SDBMVaryingExpr : public SDBMExpr {
|
||||
public:
|
||||
using ImplType = detail::SDBMExprStorage;
|
||||
using SDBMExpr::SDBMExpr;
|
||||
|
||||
static bool isClassFor(const SDBMExpr &expr) {
|
||||
return expr.getKind() == SDBMExprKind::DimId ||
|
||||
expr.getKind() == SDBMExprKind::SymbolId ||
|
||||
expr.getKind() == SDBMExprKind::Neg ||
|
||||
expr.getKind() == SDBMExprKind::Stripe ||
|
||||
expr.getKind() == SDBMExprKind::Add ||
|
||||
expr.getKind() == SDBMExprKind::Diff;
|
||||
}
|
||||
};
|
||||
|
||||
/// SDBM direct expression includes exactly one variable (symbol or dimension),
|
||||
/// which is not negated in the expression. It can be one of:
|
||||
/// - term expression;
|
||||
/// - sum expression.
|
||||
class SDBMDirectExpr : public SDBMVaryingExpr {
|
||||
public:
|
||||
using SDBMVaryingExpr::SDBMVaryingExpr;
|
||||
|
||||
/// If this is a sum expression, return its variable part, otherwise return
|
||||
/// self.
|
||||
SDBMTermExpr getTerm();
|
||||
|
||||
/// If this is a sum expression, return its constant part, otherwise return 0.
|
||||
int64_t getConstant();
|
||||
|
||||
static bool isClassFor(const SDBMExpr &expr) {
|
||||
return expr.getKind() == SDBMExprKind::DimId ||
|
||||
expr.getKind() == SDBMExprKind::SymbolId ||
|
||||
expr.getKind() == SDBMExprKind::Stripe ||
|
||||
expr.getKind() == SDBMExprKind::Add;
|
||||
}
|
||||
};
|
||||
|
||||
/// SDBM term expression can be one of:
|
||||
/// - single variable expression;
|
||||
/// - stripe expression.
|
||||
/// Stripe expressions are treated as terms since, in the SDBM domain, they are
|
||||
/// attached to temporary variables and can appear anywhere a variable can.
|
||||
class SDBMTermExpr : public SDBMDirectExpr {
|
||||
public:
|
||||
using SDBMDirectExpr::SDBMDirectExpr;
|
||||
|
||||
static bool isClassFor(const SDBMExpr &expr) {
|
||||
return expr.getKind() == SDBMExprKind::DimId ||
|
||||
expr.getKind() == SDBMExprKind::SymbolId ||
|
||||
expr.getKind() == SDBMExprKind::Stripe;
|
||||
}
|
||||
};
|
||||
|
||||
/// SDBM sum expression. LHS is a term expression and RHS is a constant.
|
||||
class SDBMSumExpr : public SDBMDirectExpr {
|
||||
public:
|
||||
using ImplType = detail::SDBMBinaryExprStorage;
|
||||
using SDBMDirectExpr::SDBMDirectExpr;
|
||||
|
||||
/// Obtain or create a sum expression unique'ed in the given context.
|
||||
static SDBMSumExpr get(SDBMTermExpr lhs, SDBMConstantExpr rhs);
|
||||
|
||||
static bool isClassFor(const SDBMExpr &expr) {
|
||||
SDBMExprKind kind = expr.getKind();
|
||||
return kind == SDBMExprKind::Add;
|
||||
}
|
||||
|
||||
SDBMTermExpr getLHS() const;
|
||||
SDBMConstantExpr getRHS() const;
|
||||
};
|
||||
|
||||
/// SDBM difference expression. LHS is a direct expression, i.e. it may be a
|
||||
/// sum of a term and a constant. RHS is a term expression. Thus the
|
||||
/// expression (t1 - t2 + C) with term expressions t1,t2 is represented as
|
||||
/// diff(sum(t1, C), t2)
|
||||
/// and it is possible to extract the constant factor without negating it.
|
||||
class SDBMDiffExpr : public SDBMVaryingExpr {
|
||||
public:
|
||||
using ImplType = detail::SDBMDiffExprStorage;
|
||||
using SDBMVaryingExpr::SDBMVaryingExpr;
|
||||
|
||||
/// Obtain or create a difference expression unique'ed in the given context.
|
||||
static SDBMDiffExpr get(SDBMDirectExpr lhs, SDBMTermExpr rhs);
|
||||
|
||||
static bool isClassFor(const SDBMExpr &expr) {
|
||||
return expr.getKind() == SDBMExprKind::Diff;
|
||||
}
|
||||
|
||||
SDBMDirectExpr getLHS() const;
|
||||
SDBMTermExpr getRHS() const;
|
||||
};
|
||||
|
||||
/// SDBM stripe expression "x # C" where "x" is a term expression, "C" is a
|
||||
/// constant expression and "#" is the stripe operator defined as:
|
||||
/// x # C = x - x mod C.
|
||||
class SDBMStripeExpr : public SDBMTermExpr {
|
||||
public:
|
||||
using ImplType = detail::SDBMBinaryExprStorage;
|
||||
using SDBMTermExpr::SDBMTermExpr;
|
||||
|
||||
static bool isClassFor(const SDBMExpr &expr) {
|
||||
return expr.getKind() == SDBMExprKind::Stripe;
|
||||
}
|
||||
|
||||
static SDBMStripeExpr get(SDBMDirectExpr var, SDBMConstantExpr stripeFactor);
|
||||
|
||||
SDBMDirectExpr getLHS() const;
|
||||
SDBMConstantExpr getStripeFactor() const;
|
||||
};
|
||||
|
||||
/// SDBM "input" variable expression can be either a dimension identifier or
|
||||
/// a symbol identifier. When used to define SDBM functions, dimensions are
|
||||
/// interpreted as function arguments while symbols are treated as unknown but
|
||||
/// constant values, hence the name.
|
||||
class SDBMInputExpr : public SDBMTermExpr {
|
||||
public:
|
||||
using ImplType = detail::SDBMTermExprStorage;
|
||||
using SDBMTermExpr::SDBMTermExpr;
|
||||
|
||||
static bool isClassFor(const SDBMExpr &expr) {
|
||||
return expr.getKind() == SDBMExprKind::DimId ||
|
||||
expr.getKind() == SDBMExprKind::SymbolId;
|
||||
}
|
||||
|
||||
unsigned getPosition() const;
|
||||
};
|
||||
|
||||
/// SDBM dimension expression. Dimensions correspond to function arguments
|
||||
/// when defining functions using SDBM expressions.
|
||||
class SDBMDimExpr : public SDBMInputExpr {
|
||||
public:
|
||||
using ImplType = detail::SDBMTermExprStorage;
|
||||
using SDBMInputExpr::SDBMInputExpr;
|
||||
|
||||
/// Obtain or create a dimension expression unique'ed in the given dialect
|
||||
/// (which belongs to a context).
|
||||
static SDBMDimExpr get(SDBMDialect *dialect, unsigned position);
|
||||
|
||||
static bool isClassFor(const SDBMExpr &expr) {
|
||||
return expr.getKind() == SDBMExprKind::DimId;
|
||||
}
|
||||
};
|
||||
|
||||
/// SDBM symbol expression. Symbols correspond to symbolic constants when
|
||||
/// defining functions using SDBM expressions.
|
||||
class SDBMSymbolExpr : public SDBMInputExpr {
|
||||
public:
|
||||
using ImplType = detail::SDBMTermExprStorage;
|
||||
using SDBMInputExpr::SDBMInputExpr;
|
||||
|
||||
/// Obtain or create a symbol expression unique'ed in the given dialect (which
|
||||
/// belongs to a context).
|
||||
static SDBMSymbolExpr get(SDBMDialect *dialect, unsigned position);
|
||||
|
||||
static bool isClassFor(const SDBMExpr &expr) {
|
||||
return expr.getKind() == SDBMExprKind::SymbolId;
|
||||
}
|
||||
};
|
||||
|
||||
/// Negation of an SDBM variable expression. Equivalent to multiplying the
|
||||
/// expression with -1 (SDBM does not support other coefficients that 1 and -1).
|
||||
class SDBMNegExpr : public SDBMVaryingExpr {
|
||||
public:
|
||||
using ImplType = detail::SDBMNegExprStorage;
|
||||
using SDBMVaryingExpr::SDBMVaryingExpr;
|
||||
|
||||
/// Obtain or create a negation expression unique'ed in the given context.
|
||||
static SDBMNegExpr get(SDBMDirectExpr var);
|
||||
|
||||
static bool isClassFor(const SDBMExpr &expr) {
|
||||
return expr.getKind() == SDBMExprKind::Neg;
|
||||
}
|
||||
|
||||
SDBMDirectExpr getVar() const;
|
||||
};
|
||||
|
||||
/// A visitor class for SDBM expressions. Calls the kind-specific function
|
||||
/// depending on the kind of expression it visits.
|
||||
template <typename Derived, typename Result = void> class SDBMVisitor {
|
||||
public:
|
||||
/// Visit the given SDBM expression, dispatching to kind-specific functions.
|
||||
Result visit(SDBMExpr expr) {
|
||||
auto *derived = static_cast<Derived *>(this);
|
||||
switch (expr.getKind()) {
|
||||
case SDBMExprKind::Add:
|
||||
case SDBMExprKind::Diff:
|
||||
case SDBMExprKind::DimId:
|
||||
case SDBMExprKind::SymbolId:
|
||||
case SDBMExprKind::Neg:
|
||||
case SDBMExprKind::Stripe:
|
||||
return derived->visitVarying(expr.cast<SDBMVaryingExpr>());
|
||||
case SDBMExprKind::Constant:
|
||||
return derived->visitConstant(expr.cast<SDBMConstantExpr>());
|
||||
}
|
||||
|
||||
llvm_unreachable("unsupported SDBM expression kind");
|
||||
}
|
||||
|
||||
/// Traverse the SDBM expression tree calling `visit` on each node
|
||||
/// in depth-first preorder.
|
||||
void walkPreorder(SDBMExpr expr) { return walk</*isPreorder=*/true>(expr); }
|
||||
|
||||
/// Traverse the SDBM expression tree calling `visit` on each node in
|
||||
/// depth-first postorder.
|
||||
void walkPostorder(SDBMExpr expr) { return walk</*isPreorder=*/false>(expr); }
|
||||
|
||||
protected:
|
||||
/// Default visitors do nothing.
|
||||
void visitSum(SDBMSumExpr) {}
|
||||
void visitDiff(SDBMDiffExpr) {}
|
||||
void visitStripe(SDBMStripeExpr) {}
|
||||
void visitDim(SDBMDimExpr) {}
|
||||
void visitSymbol(SDBMSymbolExpr) {}
|
||||
void visitNeg(SDBMNegExpr) {}
|
||||
void visitConstant(SDBMConstantExpr) {}
|
||||
|
||||
/// Default implementation of visitDirect dispatches to the dedicated for sums
|
||||
/// or delegates to visitTerm for the other expression kinds. Concrete
|
||||
/// visitors can overload it.
|
||||
Result visitDirect(SDBMDirectExpr expr) {
|
||||
auto *derived = static_cast<Derived *>(this);
|
||||
if (auto sum = expr.dyn_cast<SDBMSumExpr>())
|
||||
return derived->visitSum(sum);
|
||||
else
|
||||
return derived->visitTerm(expr.cast<SDBMTermExpr>());
|
||||
}
|
||||
|
||||
/// Default implementation of visitTerm dispatches to the special functions
|
||||
/// for stripes and other variables. Concrete visitors can override it.
|
||||
Result visitTerm(SDBMTermExpr expr) {
|
||||
auto *derived = static_cast<Derived *>(this);
|
||||
if (expr.getKind() == SDBMExprKind::Stripe)
|
||||
return derived->visitStripe(expr.cast<SDBMStripeExpr>());
|
||||
else
|
||||
return derived->visitInput(expr.cast<SDBMInputExpr>());
|
||||
}
|
||||
|
||||
/// Default implementation of visitInput dispatches to the special
|
||||
/// functions for dimensions or symbols. Concrete visitors can override it to
|
||||
/// visit all variables instead.
|
||||
Result visitInput(SDBMInputExpr expr) {
|
||||
auto *derived = static_cast<Derived *>(this);
|
||||
if (expr.getKind() == SDBMExprKind::DimId)
|
||||
return derived->visitDim(expr.cast<SDBMDimExpr>());
|
||||
else
|
||||
return derived->visitSymbol(expr.cast<SDBMSymbolExpr>());
|
||||
}
|
||||
|
||||
/// Default implementation of visitVarying dispatches to the special
|
||||
/// functions for variables and negations thereof. Concrete visitors can
|
||||
/// override it to visit all variables and negations instead.
|
||||
Result visitVarying(SDBMVaryingExpr expr) {
|
||||
auto *derived = static_cast<Derived *>(this);
|
||||
if (auto var = expr.dyn_cast<SDBMDirectExpr>())
|
||||
return derived->visitDirect(var);
|
||||
else if (auto neg = expr.dyn_cast<SDBMNegExpr>())
|
||||
return derived->visitNeg(neg);
|
||||
else if (auto diff = expr.dyn_cast<SDBMDiffExpr>())
|
||||
return derived->visitDiff(diff);
|
||||
|
||||
llvm_unreachable("unhandled subtype of varying SDBM expression");
|
||||
}
|
||||
|
||||
template <bool isPreorder> void walk(SDBMExpr expr) {
|
||||
if (isPreorder)
|
||||
visit(expr);
|
||||
if (auto sumExpr = expr.dyn_cast<SDBMSumExpr>()) {
|
||||
walk<isPreorder>(sumExpr.getLHS());
|
||||
walk<isPreorder>(sumExpr.getRHS());
|
||||
} else if (auto diffExpr = expr.dyn_cast<SDBMDiffExpr>()) {
|
||||
walk<isPreorder>(diffExpr.getLHS());
|
||||
walk<isPreorder>(diffExpr.getRHS());
|
||||
} else if (auto stripeExpr = expr.dyn_cast<SDBMStripeExpr>()) {
|
||||
walk<isPreorder>(stripeExpr.getLHS());
|
||||
walk<isPreorder>(stripeExpr.getStripeFactor());
|
||||
} else if (auto negExpr = expr.dyn_cast<SDBMNegExpr>()) {
|
||||
walk<isPreorder>(negExpr.getVar());
|
||||
}
|
||||
if (!isPreorder)
|
||||
visit(expr);
|
||||
}
|
||||
};
|
||||
|
||||
/// Overloaded arithmetic operators for SDBM expressions asserting that their
|
||||
/// arguments have the proper SDBM expression subtype. Perform canonicalization
|
||||
/// and constant folding on these expressions.
|
||||
namespace ops_assertions {
|
||||
|
||||
/// Add two SDBM expressions. At least one of the expressions must be a
|
||||
/// constant or a negation, but both expressions cannot be negations
|
||||
/// simultaneously.
|
||||
SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs);
|
||||
inline SDBMExpr operator+(SDBMExpr lhs, int64_t rhs) {
|
||||
return lhs + SDBMConstantExpr::get(lhs.getDialect(), rhs);
|
||||
}
|
||||
inline SDBMExpr operator+(int64_t lhs, SDBMExpr rhs) {
|
||||
return SDBMConstantExpr::get(rhs.getDialect(), lhs) + rhs;
|
||||
}
|
||||
|
||||
/// Subtract an SDBM expression from another SDBM expression. Both expressions
|
||||
/// must not be difference expressions.
|
||||
SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs);
|
||||
inline SDBMExpr operator-(SDBMExpr lhs, int64_t rhs) {
|
||||
return lhs - SDBMConstantExpr::get(lhs.getDialect(), rhs);
|
||||
}
|
||||
inline SDBMExpr operator-(int64_t lhs, SDBMExpr rhs) {
|
||||
return SDBMConstantExpr::get(rhs.getDialect(), lhs) - rhs;
|
||||
}
|
||||
|
||||
/// Construct a stripe expression from a positive expression and a positive
|
||||
/// constant stripe factor.
|
||||
SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor);
|
||||
inline SDBMExpr stripe(SDBMExpr expr, int64_t factor) {
|
||||
return stripe(expr, SDBMConstantExpr::get(expr.getDialect(), factor));
|
||||
}
|
||||
} // namespace ops_assertions
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
namespace llvm {
|
||||
// SDBMExpr hash just like pointers.
|
||||
template <> struct DenseMapInfo<mlir::SDBMExpr> {
|
||||
static mlir::SDBMExpr getEmptyKey() {
|
||||
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||
return mlir::SDBMExpr(static_cast<mlir::SDBMExpr::ImplType *>(pointer));
|
||||
}
|
||||
static mlir::SDBMExpr getTombstoneKey() {
|
||||
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
|
||||
return mlir::SDBMExpr(static_cast<mlir::SDBMExpr::ImplType *>(pointer));
|
||||
}
|
||||
static unsigned getHashValue(mlir::SDBMExpr expr) {
|
||||
return expr.hash_value();
|
||||
}
|
||||
static bool isEqual(mlir::SDBMExpr lhs, mlir::SDBMExpr rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
};
|
||||
|
||||
// SDBMDirectExpr hash just like pointers.
|
||||
template <> struct DenseMapInfo<mlir::SDBMDirectExpr> {
|
||||
static mlir::SDBMDirectExpr getEmptyKey() {
|
||||
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||
return mlir::SDBMDirectExpr(
|
||||
static_cast<mlir::SDBMExpr::ImplType *>(pointer));
|
||||
}
|
||||
static mlir::SDBMDirectExpr getTombstoneKey() {
|
||||
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
|
||||
return mlir::SDBMDirectExpr(
|
||||
static_cast<mlir::SDBMExpr::ImplType *>(pointer));
|
||||
}
|
||||
static unsigned getHashValue(mlir::SDBMDirectExpr expr) {
|
||||
return expr.hash_value();
|
||||
}
|
||||
static bool isEqual(mlir::SDBMDirectExpr lhs, mlir::SDBMDirectExpr rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
};
|
||||
|
||||
// SDBMTermExpr hash just like pointers.
|
||||
template <> struct DenseMapInfo<mlir::SDBMTermExpr> {
|
||||
static mlir::SDBMTermExpr getEmptyKey() {
|
||||
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||
return mlir::SDBMTermExpr(static_cast<mlir::SDBMExpr::ImplType *>(pointer));
|
||||
}
|
||||
static mlir::SDBMTermExpr getTombstoneKey() {
|
||||
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
|
||||
return mlir::SDBMTermExpr(static_cast<mlir::SDBMExpr::ImplType *>(pointer));
|
||||
}
|
||||
static unsigned getHashValue(mlir::SDBMTermExpr expr) {
|
||||
return expr.hash_value();
|
||||
}
|
||||
static bool isEqual(mlir::SDBMTermExpr lhs, mlir::SDBMTermExpr rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
};
|
||||
|
||||
// SDBMConstantExpr hash just like pointers.
|
||||
template <> struct DenseMapInfo<mlir::SDBMConstantExpr> {
|
||||
static mlir::SDBMConstantExpr getEmptyKey() {
|
||||
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||
return mlir::SDBMConstantExpr(
|
||||
static_cast<mlir::SDBMExpr::ImplType *>(pointer));
|
||||
}
|
||||
static mlir::SDBMConstantExpr getTombstoneKey() {
|
||||
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
|
||||
return mlir::SDBMConstantExpr(
|
||||
static_cast<mlir::SDBMExpr::ImplType *>(pointer));
|
||||
}
|
||||
static unsigned getHashValue(mlir::SDBMConstantExpr expr) {
|
||||
return expr.hash_value();
|
||||
}
|
||||
static bool isEqual(mlir::SDBMConstantExpr lhs, mlir::SDBMConstantExpr rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
};
|
||||
} // namespace llvm
|
||||
|
||||
#endif // MLIR_DIALECT_SDBM_SDBMEXPR_H
|
|
@ -35,7 +35,6 @@
|
|||
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
|
||||
#include "mlir/Dialect/Quant/QuantOps.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/SDBM/SDBMDialect.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
|
||||
|
@ -75,7 +74,6 @@ inline void registerAllDialects(DialectRegistry ®istry) {
|
|||
vector::VectorDialect,
|
||||
NVVM::NVVMDialect,
|
||||
ROCDL::ROCDLDialect,
|
||||
SDBMDialect,
|
||||
shape::ShapeDialect,
|
||||
sparse_tensor::SparseTensorDialect,
|
||||
tensor::TensorDialect,
|
||||
|
|
|
@ -17,7 +17,6 @@ add_subdirectory(PDL)
|
|||
add_subdirectory(PDLInterp)
|
||||
add_subdirectory(Quant)
|
||||
add_subdirectory(SCF)
|
||||
add_subdirectory(SDBM)
|
||||
add_subdirectory(Shape)
|
||||
add_subdirectory(SparseTensor)
|
||||
add_subdirectory(SPIRV)
|
||||
|
|
|
@ -1,11 +0,0 @@
|
|||
add_mlir_dialect_library(MLIRSDBM
|
||||
SDBM.cpp
|
||||
SDBMDialect.cpp
|
||||
SDBMExpr.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SDBM
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
)
|
|
@ -1,551 +0,0 @@
|
|||
//===- SDBM.cpp - MLIR SDBM implementation --------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// A striped difference-bound matrix (SDBM) is a set in Z^N (or R^N) defined
|
||||
// as {(x_1, ... x_n) | f(x_1, ... x_n) >= 0} where f is an SDBM expression.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/SDBM/SDBM.h"
|
||||
#include "mlir/Dialect/SDBM/SDBMExpr.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
// Helper function for SDBM construction that collects information necessary to
|
||||
// start building an SDBM in one sweep. In particular, it records the largest
|
||||
// position of a dimension in `dim`, that of a symbol in `symbol` as well as
|
||||
// collects all unique stripe expressions in `stripes`. Uses SetVector to
|
||||
// ensure these expressions always have the same order.
|
||||
static void collectSDBMBuildInfo(SDBMExpr expr, int &dim, int &symbol,
|
||||
llvm::SmallSetVector<SDBMExpr, 8> &stripes) {
|
||||
struct Visitor : public SDBMVisitor<Visitor> {
|
||||
void visitDim(SDBMDimExpr dimExpr) {
|
||||
int p = dimExpr.getPosition();
|
||||
if (p > maxDimPosition)
|
||||
maxDimPosition = p;
|
||||
}
|
||||
void visitSymbol(SDBMSymbolExpr symbExpr) {
|
||||
int p = symbExpr.getPosition();
|
||||
if (p > maxSymbPosition)
|
||||
maxSymbPosition = p;
|
||||
}
|
||||
void visitStripe(SDBMStripeExpr stripeExpr) { stripes.insert(stripeExpr); }
|
||||
|
||||
Visitor(llvm::SmallSetVector<SDBMExpr, 8> &stripes) : stripes(stripes) {}
|
||||
|
||||
int maxDimPosition = -1;
|
||||
int maxSymbPosition = -1;
|
||||
llvm::SmallSetVector<SDBMExpr, 8> &stripes;
|
||||
};
|
||||
|
||||
Visitor visitor(stripes);
|
||||
visitor.walkPostorder(expr);
|
||||
dim = std::max(dim, visitor.maxDimPosition);
|
||||
symbol = std::max(symbol, visitor.maxSymbPosition);
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Utility class for SDBMBuilder. Represents a value that can be inserted in
|
||||
// the SDB matrix that corresponds to "v0 - v1 + C <= 0", where v0 and v1 is
|
||||
// any combination of the positive and negative positions. Since multiple
|
||||
// variables can be declared equal to the same stripe expression, the
|
||||
// constraints on this expression must be reflected to all these variables. For
|
||||
// example, if
|
||||
// d0 = s0 # 42
|
||||
// d1 = s0 # 42
|
||||
// d2 = s1 # 2
|
||||
// d3 = s1 # 2
|
||||
// the constraint
|
||||
// s0 # 42 - s1 # 2 <= C
|
||||
// should be reflected in the DB matrix as
|
||||
// d0 - d2 <= C
|
||||
// d1 - d2 <= C
|
||||
// d0 - d3 <= C
|
||||
// d1 - d3 <= C
|
||||
// since the DB matrix has no knowledge of the transitive equality between d0,
|
||||
// d1 and s0 # 42 as well as between d2, d3 and s1 # 2. This knowledge can be
|
||||
// obtained by computing a transitive closure, which is impossible until the
|
||||
// DBM is actually built.
|
||||
struct SDBMBuilderResult {
|
||||
// Positions in the matrix of the variables taken with the "+" sign in the
|
||||
// difference expression, 0 if it is a constant rather than a variable.
|
||||
SmallVector<unsigned, 2> positivePos;
|
||||
|
||||
// Positions in the matrix of the variables taken with the "-" sign in the
|
||||
// difference expression, 0 if it is a constant rather than a variable.
|
||||
SmallVector<unsigned, 2> negativePos;
|
||||
|
||||
// Constant value in the difference expression.
|
||||
int64_t value = 0;
|
||||
};
|
||||
|
||||
// Visitor for building an SDBM from SDBM expressions. After traversing an SDBM
|
||||
// expression, produces an update to the SDB matrix specifying the positions in
|
||||
// the matrix and the negated value that should be stored. Both the positive
|
||||
// and the negative positions may be lists of indices in cases where multiple
|
||||
// variables are equal to the same stripe expression. In such cases, the update
|
||||
// applies to the cross product of positions because elements involved in the
|
||||
// update are (transitively) equal and should have the same constraints, but we
|
||||
// may not have an explicit equality for them.
|
||||
struct SDBMBuilder : public SDBMVisitor<SDBMBuilder, SDBMBuilderResult> {
|
||||
public:
|
||||
// A difference expression produces both the positive and the negative
|
||||
// coordinate in the matrix, recursively traversing the LHS and the RHS. The
|
||||
// value is the difference between values obtained from LHS and RHS.
|
||||
SDBMBuilderResult visitDiff(SDBMDiffExpr diffExpr) {
|
||||
auto lhs = visit(diffExpr.getLHS());
|
||||
auto rhs = visit(diffExpr.getRHS());
|
||||
assert(lhs.negativePos.size() == 1 && lhs.negativePos[0] == 0 &&
|
||||
"unexpected negative expression in a difference expression");
|
||||
assert(rhs.negativePos.size() == 1 && lhs.negativePos[0] == 0 &&
|
||||
"unexpected negative expression in a difference expression");
|
||||
|
||||
SDBMBuilderResult result;
|
||||
result.positivePos = lhs.positivePos;
|
||||
result.negativePos = rhs.positivePos;
|
||||
result.value = lhs.value - rhs.value;
|
||||
return result;
|
||||
}
|
||||
|
||||
// An input expression is always taken with the "+" sign and therefore
|
||||
// produces a positive coordinate keeping the negative coordinate zero for an
|
||||
// eventual constant.
|
||||
SDBMBuilderResult visitInput(SDBMInputExpr expr) {
|
||||
SDBMBuilderResult r;
|
||||
r.positivePos.push_back(linearPosition(expr));
|
||||
r.negativePos.push_back(0);
|
||||
return r;
|
||||
}
|
||||
|
||||
// A stripe expression is always equal to one or more variables, which may be
|
||||
// temporaries, and appears with a "+" sign in the SDBM expression tree. Take
|
||||
// the positions of the corresponding variables as positive coordinates.
|
||||
SDBMBuilderResult visitStripe(SDBMStripeExpr expr) {
|
||||
SDBMBuilderResult r;
|
||||
assert(pointExprToStripe.count(expr));
|
||||
r.positivePos = pointExprToStripe[expr];
|
||||
r.negativePos.push_back(0);
|
||||
return r;
|
||||
}
|
||||
|
||||
// A constant expression has both coordinates at zero.
|
||||
SDBMBuilderResult visitConstant(SDBMConstantExpr expr) {
|
||||
SDBMBuilderResult r;
|
||||
r.positivePos.push_back(0);
|
||||
r.negativePos.push_back(0);
|
||||
r.value = expr.getValue();
|
||||
return r;
|
||||
}
|
||||
|
||||
// A negation expression swaps the positive and the negative coordinates
|
||||
// and also negates the constant value.
|
||||
SDBMBuilderResult visitNeg(SDBMNegExpr expr) {
|
||||
SDBMBuilderResult result = visit(expr.getVar());
|
||||
std::swap(result.positivePos, result.negativePos);
|
||||
result.value = -result.value;
|
||||
return result;
|
||||
}
|
||||
|
||||
// The RHS of a sum expression must be a constant and therefore must have both
|
||||
// positive and negative coordinates at zero. Take the sum of the values
|
||||
// between LHS and RHS and keep LHS coordinates.
|
||||
SDBMBuilderResult visitSum(SDBMSumExpr expr) {
|
||||
auto lhs = visit(expr.getLHS());
|
||||
auto rhs = visit(expr.getRHS());
|
||||
for (auto pos : rhs.negativePos) {
|
||||
(void)pos;
|
||||
assert(pos == 0 && "unexpected variable on the RHS of SDBM sum");
|
||||
}
|
||||
for (auto pos : rhs.positivePos) {
|
||||
(void)pos;
|
||||
assert(pos == 0 && "unexpected variable on the RHS of SDBM sum");
|
||||
}
|
||||
|
||||
lhs.value += rhs.value;
|
||||
return lhs;
|
||||
}
|
||||
|
||||
SDBMBuilder(DenseMap<SDBMExpr, SmallVector<unsigned, 2>> &pointExprToStripe,
|
||||
function_ref<unsigned(SDBMInputExpr)> callback)
|
||||
: pointExprToStripe(pointExprToStripe), linearPosition(callback) {}
|
||||
|
||||
DenseMap<SDBMExpr, SmallVector<unsigned, 2>> &pointExprToStripe;
|
||||
function_ref<unsigned(SDBMInputExpr)> linearPosition;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
SDBM SDBM::get(ArrayRef<SDBMExpr> inequalities, ArrayRef<SDBMExpr> equalities) {
|
||||
SDBM result;
|
||||
|
||||
// TODO: consider detecting equalities in the list of inequalities.
|
||||
// This is potentially expensive and requires to
|
||||
// - create a list of negated inequalities (may allocate under lock);
|
||||
// - perform a pairwise comparison of direct and negated inequalities;
|
||||
// - copy the lists of equalities and inequalities, and move entries between
|
||||
// them;
|
||||
// only for the purpose of sparing a temporary variable in cases where an
|
||||
// implicit equality between a variable and a stripe expression is present in
|
||||
// the input.
|
||||
|
||||
// Do the first sweep over (in)equalities to collect the information necessary
|
||||
// to allocate the SDB matrix (number of dimensions, symbol and temporary
|
||||
// variables required for stripe expressions).
|
||||
llvm::SmallSetVector<SDBMExpr, 8> stripes;
|
||||
int maxDim = -1;
|
||||
int maxSymbol = -1;
|
||||
for (auto expr : inequalities)
|
||||
collectSDBMBuildInfo(expr, maxDim, maxSymbol, stripes);
|
||||
for (auto expr : equalities)
|
||||
collectSDBMBuildInfo(expr, maxDim, maxSymbol, stripes);
|
||||
// Indexing of dimensions starts with 0, obtain the number of dimensions by
|
||||
// incrementing the maximal position of the dimension seen in expressions.
|
||||
result.numDims = maxDim + 1;
|
||||
result.numSymbols = maxSymbol + 1;
|
||||
result.numTemporaries = 0;
|
||||
|
||||
// Helper function that returns the position of the variable represented by
|
||||
// an SDBM input expression.
|
||||
auto linearPosition = [result](SDBMInputExpr expr) {
|
||||
if (expr.isa<SDBMDimExpr>())
|
||||
return result.getDimPosition(expr.getPosition());
|
||||
return result.getSymbolPosition(expr.getPosition());
|
||||
};
|
||||
|
||||
// Check if some stripe expressions are equal to another variable. In
|
||||
// particular, look for the equalities of the form
|
||||
// d0 - stripe-expression = 0, or
|
||||
// stripe-expression - d0 = 0.
|
||||
// There may be multiple variables that are equal to the same stripe
|
||||
// expression. Keep track of those in pointExprToStripe.
|
||||
// There may also be multiple stripe expressions equal to the same variable.
|
||||
// Introduce a temporary variable for each of those.
|
||||
DenseMap<SDBMExpr, SmallVector<unsigned, 2>> pointExprToStripe;
|
||||
unsigned numTemporaries = 0;
|
||||
|
||||
auto updateStripePointMaps = [&numTemporaries, &result, &pointExprToStripe,
|
||||
linearPosition](SDBMInputExpr input,
|
||||
SDBMExpr expr) {
|
||||
unsigned position = linearPosition(input);
|
||||
if (result.stripeToPoint.count(position) &&
|
||||
result.stripeToPoint[position] != expr) {
|
||||
position = result.getNumVariables() + numTemporaries++;
|
||||
}
|
||||
pointExprToStripe[expr].push_back(position);
|
||||
result.stripeToPoint.insert(std::make_pair(position, expr));
|
||||
};
|
||||
|
||||
for (auto eq : equalities) {
|
||||
auto diffExpr = eq.dyn_cast<SDBMDiffExpr>();
|
||||
if (!diffExpr)
|
||||
continue;
|
||||
|
||||
auto lhs = diffExpr.getLHS();
|
||||
auto rhs = diffExpr.getRHS();
|
||||
auto lhsInput = lhs.dyn_cast<SDBMInputExpr>();
|
||||
auto rhsInput = rhs.dyn_cast<SDBMInputExpr>();
|
||||
|
||||
if (lhsInput && stripes.count(rhs))
|
||||
updateStripePointMaps(lhsInput, rhs);
|
||||
if (rhsInput && stripes.count(lhs))
|
||||
updateStripePointMaps(rhsInput, lhs);
|
||||
}
|
||||
|
||||
// Assign the remaining stripe expressions to temporary variables. These
|
||||
// expressions are the ones that could not be associated with an existing
|
||||
// variable in the previous step.
|
||||
for (auto expr : stripes) {
|
||||
if (pointExprToStripe.count(expr))
|
||||
continue;
|
||||
unsigned position = result.getNumVariables() + numTemporaries++;
|
||||
pointExprToStripe[expr].push_back(position);
|
||||
result.stripeToPoint.insert(std::make_pair(position, expr));
|
||||
}
|
||||
|
||||
// Create the DBM matrix, initialized to infinity values for the least tight
|
||||
// possible bound (x - y <= infinity is always true).
|
||||
result.numTemporaries = numTemporaries;
|
||||
result.matrix.resize(result.getNumVariables() * result.getNumVariables(),
|
||||
IntInfty::infinity());
|
||||
|
||||
SDBMBuilder builder(pointExprToStripe, linearPosition);
|
||||
|
||||
// Only keep the tightest constraint. Since we transform everything into
|
||||
// less-than-or-equals-to inequalities, keep the smallest constant. For
|
||||
// example, if we have d0 - d1 <= 42 and d0 - d1 <= 2, we keep the latter.
|
||||
// Note that the input expressions are in the shape of d0 - d1 + -42 <= 0
|
||||
// so we negate the value before storing it.
|
||||
// In case where the positive and the negative positions are equal, the
|
||||
// corresponding expression has the form d0 - d0 + -42 <= 0. If the constant
|
||||
// value is positive, the set defined by SDBM is trivially empty. We store
|
||||
// this value anyway and continue processing to maintain the correspondence
|
||||
// between the matrix form and the list-of-SDBMExpr form.
|
||||
// TODO: we may want to reconsider this once we have canonicalization
|
||||
// or simplification in place
|
||||
auto updateMatrix = [](SDBM &sdbm, const SDBMBuilderResult &r) {
|
||||
for (auto positivePos : r.positivePos) {
|
||||
for (auto negativePos : r.negativePos) {
|
||||
auto &m = sdbm.at(negativePos, positivePos);
|
||||
m = m < -r.value ? m : -r.value;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Do the second sweep on (in)equalities, updating the SDB matrix to reflect
|
||||
// the constraints.
|
||||
for (auto ineq : inequalities)
|
||||
updateMatrix(result, builder.visit(ineq));
|
||||
|
||||
// An equality f(x) = 0 is represented as a pair of inequalities {f(x) >= 0;
|
||||
// f(x) <= 0} or, alternatively, {-f(x) <= 0 and f(x) <= 0}.
|
||||
for (auto eq : equalities) {
|
||||
updateMatrix(result, builder.visit(eq));
|
||||
updateMatrix(result, builder.visit(-eq));
|
||||
}
|
||||
|
||||
// Add the inequalities induced by stripe equalities.
|
||||
// t = x # C => t <= x <= t + C - 1
|
||||
// which is equivalent to
|
||||
// {t - x <= 0;
|
||||
// x - t - (C - 1) <= 0}.
|
||||
for (const auto &pair : result.stripeToPoint) {
|
||||
auto stripe = pair.second.cast<SDBMStripeExpr>();
|
||||
SDBMBuilderResult update = builder.visit(stripe.getLHS());
|
||||
assert(update.negativePos.size() == 1 && update.negativePos[0] == 0 &&
|
||||
"unexpected negated variable in stripe expression");
|
||||
assert(update.value == 0 &&
|
||||
"unexpected non-zero value in stripe expression");
|
||||
update.negativePos.clear();
|
||||
update.negativePos.push_back(pair.first);
|
||||
update.value = -(stripe.getStripeFactor().getValue() - 1);
|
||||
updateMatrix(result, update);
|
||||
|
||||
std::swap(update.negativePos, update.positivePos);
|
||||
update.value = 0;
|
||||
updateMatrix(result, update);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Given a row and a column position in the square DBM, insert one equality
|
||||
// or up to two inequalities that correspond the entries (col, row) and (row,
|
||||
// col) in the DBM. `rowExpr` and `colExpr` contain the expressions such that
|
||||
// colExpr - rowExpr <= V where V is the value at (row, col) in the DBM.
|
||||
// If one of the expressions is derived from another using a stripe operation,
|
||||
// check if the inequalities induced by the stripe operation subsume the
|
||||
// inequalities defined in the DBM and if so, elide these inequalities.
|
||||
void SDBM::convertDBMElement(unsigned row, unsigned col, SDBMTermExpr rowExpr,
|
||||
SDBMTermExpr colExpr,
|
||||
SmallVectorImpl<SDBMExpr> &inequalities,
|
||||
SmallVectorImpl<SDBMExpr> &equalities) {
|
||||
using ops_assertions::operator+;
|
||||
using ops_assertions::operator-;
|
||||
|
||||
auto diffIJValue = at(col, row);
|
||||
auto diffJIValue = at(row, col);
|
||||
|
||||
// If symmetric entries are opposite, the corresponding expressions are equal.
|
||||
if (diffIJValue.isFinite() &&
|
||||
diffIJValue.getValue() == -diffJIValue.getValue()) {
|
||||
equalities.push_back(rowExpr - colExpr - diffIJValue.getValue());
|
||||
return;
|
||||
}
|
||||
|
||||
// Given an inequality x0 - x1 <= A, check if x0 is a stripe variable derived
|
||||
// from x1: x0 = x1 # B. If so, it would imply the constraints
|
||||
// x0 <= x1 <= x0 + (B - 1) <=> x0 - x1 <= 0 and x1 - x0 <= (B - 1).
|
||||
// Therefore, if A >= 0, this inequality is subsumed by that implied
|
||||
// by the stripe equality and thus can be elided.
|
||||
// Similarly, check if x1 is a stripe variable derived from x0: x1 = x0 # C.
|
||||
// If so, it would imply the constraints x1 <= x0 <= x1 + (C - 1) <=>
|
||||
// <=> x1 - x0 <= 0 and x0 - x1 <= (C - 1). Therefore, if A >= (C - 1), this
|
||||
// inequality can be elided.
|
||||
//
|
||||
// Note: x0 and x1 may be a stripe expressions themselves, we rely on stripe
|
||||
// expressions being stored without temporaries on the RHS and being passed
|
||||
// into this function as is.
|
||||
auto canElide = [this](unsigned x0, unsigned x1, SDBMExpr x0Expr,
|
||||
SDBMExpr x1Expr, int64_t value) {
|
||||
if (stripeToPoint.count(x0)) {
|
||||
auto stripe = stripeToPoint[x0].cast<SDBMStripeExpr>();
|
||||
SDBMDirectExpr var = stripe.getLHS();
|
||||
if (x1Expr == var && value >= 0)
|
||||
return true;
|
||||
}
|
||||
if (stripeToPoint.count(x1)) {
|
||||
auto stripe = stripeToPoint[x1].cast<SDBMStripeExpr>();
|
||||
SDBMDirectExpr var = stripe.getLHS();
|
||||
if (x0Expr == var && value >= stripe.getStripeFactor().getValue() - 1)
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
// Check row - col.
|
||||
if (diffIJValue.isFinite() &&
|
||||
!canElide(row, col, rowExpr, colExpr, diffIJValue.getValue())) {
|
||||
inequalities.push_back(rowExpr - colExpr - diffIJValue.getValue());
|
||||
}
|
||||
// Check col - row.
|
||||
if (diffJIValue.isFinite() &&
|
||||
!canElide(col, row, colExpr, rowExpr, diffJIValue.getValue())) {
|
||||
inequalities.push_back(colExpr - rowExpr - diffJIValue.getValue());
|
||||
}
|
||||
}
|
||||
|
||||
// The values on the main diagonal correspond to the upper bound on the
|
||||
// difference between a variable and itself: d0 - d0 <= C, or alternatively
|
||||
// to -C <= 0. Only construct the inequalities when C is negative, which
|
||||
// are trivially false but necessary for the returned system of inequalities
|
||||
// to indicate that the set it defines is empty.
|
||||
void SDBM::convertDBMDiagonalElement(unsigned pos, SDBMTermExpr expr,
|
||||
SmallVectorImpl<SDBMExpr> &inequalities) {
|
||||
auto selfDifference = at(pos, pos);
|
||||
if (selfDifference.isFinite() && selfDifference < 0) {
|
||||
auto selfDifferenceValueExpr =
|
||||
SDBMConstantExpr::get(expr.getDialect(), -selfDifference.getValue());
|
||||
inequalities.push_back(selfDifferenceValueExpr);
|
||||
}
|
||||
}
|
||||
|
||||
void SDBM::getSDBMExpressions(SDBMDialect *dialect,
|
||||
SmallVectorImpl<SDBMExpr> &inequalities,
|
||||
SmallVectorImpl<SDBMExpr> &equalities) {
|
||||
using ops_assertions::operator-;
|
||||
using ops_assertions::operator+;
|
||||
|
||||
// Helper function that creates an SDBMInputExpr given the linearized position
|
||||
// of variable in the DBM.
|
||||
auto getInput = [dialect, this](unsigned matrixPos) -> SDBMInputExpr {
|
||||
if (matrixPos < numDims)
|
||||
return SDBMDimExpr::get(dialect, matrixPos);
|
||||
return SDBMSymbolExpr::get(dialect, matrixPos - numDims);
|
||||
};
|
||||
|
||||
// The top-left value corresponds to inequality 0 <= C. If C is negative, the
|
||||
// set defined by SDBM is trivially empty and we add the constraint -C <= 0 to
|
||||
// the list of inequalities. Otherwise, the constraint is trivially true and
|
||||
// we ignore it.
|
||||
auto difference = at(0, 0);
|
||||
if (difference.isFinite() && difference < 0) {
|
||||
inequalities.push_back(
|
||||
SDBMConstantExpr::get(dialect, -difference.getValue()));
|
||||
}
|
||||
|
||||
// Traverse the segment of the matrix that involves non-temporary variables.
|
||||
unsigned numTrueVariables = numDims + numSymbols;
|
||||
for (unsigned i = 0; i < numTrueVariables; ++i) {
|
||||
// The first row and column represent numerical upper and lower bound on
|
||||
// each variable. Transform them into inequalities if they are finite.
|
||||
auto upperBound = at(0, 1 + i);
|
||||
auto lowerBound = at(1 + i, 0);
|
||||
auto inputExpr = getInput(i);
|
||||
if (upperBound.isFinite() &&
|
||||
upperBound.getValue() == -lowerBound.getValue()) {
|
||||
equalities.push_back(inputExpr - upperBound.getValue());
|
||||
} else if (upperBound.isFinite()) {
|
||||
inequalities.push_back(inputExpr - upperBound.getValue());
|
||||
} else if (lowerBound.isFinite()) {
|
||||
inequalities.push_back(-inputExpr - lowerBound.getValue());
|
||||
}
|
||||
|
||||
// Introduce trivially false inequalities if required by diagonal elements.
|
||||
convertDBMDiagonalElement(1 + i, inputExpr, inequalities);
|
||||
|
||||
// Introduce equalities or inequalities between non-temporary variables.
|
||||
for (unsigned j = 0; j < i; ++j) {
|
||||
convertDBMElement(1 + i, 1 + j, getInput(i), getInput(j), inequalities,
|
||||
equalities);
|
||||
}
|
||||
}
|
||||
|
||||
// Add equalities for stripe expressions that define non-temporary
|
||||
// variables. Temporary variables will be substituted into their uses and
|
||||
// should not appear in the resulting equalities.
|
||||
for (const auto &stripePair : stripeToPoint) {
|
||||
unsigned position = stripePair.first;
|
||||
if (position < 1 + numTrueVariables) {
|
||||
equalities.push_back(getInput(position - 1) - stripePair.second);
|
||||
}
|
||||
}
|
||||
|
||||
// Add equalities / inequalities involving temporaries by replacing the
|
||||
// temporaries with stripe expressions that define them.
|
||||
for (unsigned i = 1 + numTrueVariables, e = getNumVariables(); i < e; ++i) {
|
||||
// Mixed constraints involving one temporary (j) and one non-temporary (i)
|
||||
// variable.
|
||||
for (unsigned j = 0; j < numTrueVariables; ++j) {
|
||||
convertDBMElement(i, 1 + j, stripeToPoint[i].cast<SDBMStripeExpr>(),
|
||||
getInput(j), inequalities, equalities);
|
||||
}
|
||||
|
||||
// Constraints involving only temporary variables.
|
||||
for (unsigned j = 1 + numTrueVariables; j < i; ++j) {
|
||||
convertDBMElement(i, j, stripeToPoint[i].cast<SDBMStripeExpr>(),
|
||||
stripeToPoint[j].cast<SDBMStripeExpr>(), inequalities,
|
||||
equalities);
|
||||
}
|
||||
|
||||
// Introduce trivially false inequalities if required by diagonal elements.
|
||||
convertDBMDiagonalElement(i, stripeToPoint[i].cast<SDBMStripeExpr>(),
|
||||
inequalities);
|
||||
}
|
||||
}
|
||||
|
||||
void SDBM::print(raw_ostream &os) {
|
||||
unsigned numVariables = getNumVariables();
|
||||
|
||||
// Helper function that prints the name of the variable given its linearized
|
||||
// position in the DBM.
|
||||
auto getVarName = [this](unsigned matrixPos) -> std::string {
|
||||
if (matrixPos == 0)
|
||||
return "cst";
|
||||
matrixPos -= 1;
|
||||
if (matrixPos < numDims)
|
||||
return std::string(llvm::formatv("d{0}", matrixPos));
|
||||
matrixPos -= numDims;
|
||||
if (matrixPos < numSymbols)
|
||||
return std::string(llvm::formatv("s{0}", matrixPos));
|
||||
matrixPos -= numSymbols;
|
||||
return std::string(llvm::formatv("t{0}", matrixPos));
|
||||
};
|
||||
|
||||
// Header row.
|
||||
os << " cst";
|
||||
for (unsigned i = 1; i < numVariables; ++i) {
|
||||
os << llvm::formatv(" {0,4}", getVarName(i));
|
||||
}
|
||||
os << '\n';
|
||||
|
||||
// Data rows.
|
||||
for (unsigned i = 0; i < numVariables; ++i) {
|
||||
os << llvm::formatv("{0,-4}", getVarName(i));
|
||||
for (unsigned j = 0; j < numVariables; ++j) {
|
||||
IntInfty value = operator()(i, j);
|
||||
if (!value.isFinite())
|
||||
os << " inf";
|
||||
else
|
||||
os << llvm::formatv(" {0,4}", value.getValue());
|
||||
}
|
||||
os << '\n';
|
||||
}
|
||||
|
||||
// Explanation of temporaries.
|
||||
for (const auto &pair : stripeToPoint) {
|
||||
os << getVarName(pair.first) << " = ";
|
||||
pair.second.print(os);
|
||||
os << '\n';
|
||||
}
|
||||
}
|
||||
|
||||
void SDBM::dump() { print(llvm::errs()); }
|
|
@ -1,23 +0,0 @@
|
|||
//===- SDBMDialect.cpp - MLIR SDBM Dialect --------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/SDBM/SDBMDialect.h"
|
||||
#include "SDBMExprDetail.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
SDBMDialect::SDBMDialect(MLIRContext *context)
|
||||
: Dialect(getDialectNamespace(), context, TypeID::get<SDBMDialect>()) {
|
||||
uniquer.registerParametricStorageType<detail::SDBMBinaryExprStorage>();
|
||||
uniquer.registerParametricStorageType<detail::SDBMConstantExprStorage>();
|
||||
uniquer.registerParametricStorageType<detail::SDBMDiffExprStorage>();
|
||||
uniquer.registerParametricStorageType<detail::SDBMNegExprStorage>();
|
||||
uniquer.registerParametricStorageType<detail::SDBMTermExprStorage>();
|
||||
}
|
||||
|
||||
SDBMDialect::~SDBMDialect() = default;
|
|
@ -1,732 +0,0 @@
|
|||
//===- SDBMExpr.cpp - MLIR SDBM Expression implementation -----------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// A striped difference-bound matrix (SDBM) expression is a constant expression,
|
||||
// an identifier, a binary expression with constant RHS and +, stripe operators
|
||||
// or a difference expression between two identifiers.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/SDBM/SDBMExpr.h"
|
||||
#include "SDBMExprDetail.h"
|
||||
#include "mlir/Dialect/SDBM/SDBMDialect.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineExprVisitor.h"
|
||||
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
/// A simple compositional matcher for AffineExpr
|
||||
///
|
||||
/// Example usage:
|
||||
///
|
||||
/// ```c++
|
||||
/// AffineExprMatcher x, C, m;
|
||||
/// AffineExprMatcher pattern1 = ((x % C) * m) + x;
|
||||
/// AffineExprMatcher pattern2 = x + ((x % C) * m);
|
||||
/// if (pattern1.match(expr) || pattern2.match(expr)) {
|
||||
/// ...
|
||||
/// }
|
||||
/// ```
|
||||
class AffineExprMatcherStorage;
|
||||
class AffineExprMatcher {
|
||||
public:
|
||||
AffineExprMatcher();
|
||||
AffineExprMatcher(const AffineExprMatcher &other);
|
||||
|
||||
AffineExprMatcher operator+(AffineExprMatcher other) {
|
||||
return AffineExprMatcher(AffineExprKind::Add, *this, other);
|
||||
}
|
||||
AffineExprMatcher operator*(AffineExprMatcher other) {
|
||||
return AffineExprMatcher(AffineExprKind::Mul, *this, other);
|
||||
}
|
||||
AffineExprMatcher floorDiv(AffineExprMatcher other) {
|
||||
return AffineExprMatcher(AffineExprKind::FloorDiv, *this, other);
|
||||
}
|
||||
AffineExprMatcher ceilDiv(AffineExprMatcher other) {
|
||||
return AffineExprMatcher(AffineExprKind::CeilDiv, *this, other);
|
||||
}
|
||||
AffineExprMatcher operator%(AffineExprMatcher other) {
|
||||
return AffineExprMatcher(AffineExprKind::Mod, *this, other);
|
||||
}
|
||||
|
||||
AffineExpr match(AffineExpr expr);
|
||||
AffineExpr matched();
|
||||
Optional<int> getMatchedConstantValue();
|
||||
|
||||
private:
|
||||
AffineExprMatcher(AffineExprKind k, AffineExprMatcher a, AffineExprMatcher b);
|
||||
AffineExprKind kind; // only used to match in binary op cases.
|
||||
// A shared_ptr allows multiple references to same matcher storage without
|
||||
// worrying about ownership or dealing with an arena. To be cleaned up if we
|
||||
// go with this.
|
||||
std::shared_ptr<AffineExprMatcherStorage> storage;
|
||||
};
|
||||
|
||||
class AffineExprMatcherStorage {
|
||||
public:
|
||||
AffineExprMatcherStorage() {}
|
||||
AffineExprMatcherStorage(const AffineExprMatcherStorage &other)
|
||||
: subExprs(other.subExprs.begin(), other.subExprs.end()),
|
||||
matched(other.matched) {}
|
||||
AffineExprMatcherStorage(ArrayRef<AffineExprMatcher> exprs)
|
||||
: subExprs(exprs.begin(), exprs.end()) {}
|
||||
AffineExprMatcherStorage(AffineExprMatcher &a, AffineExprMatcher &b)
|
||||
: subExprs({a, b}) {}
|
||||
SmallVector<AffineExprMatcher, 0> subExprs;
|
||||
AffineExpr matched;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
AffineExprMatcher::AffineExprMatcher()
|
||||
: kind(AffineExprKind::Constant), storage(new AffineExprMatcherStorage()) {}
|
||||
|
||||
AffineExprMatcher::AffineExprMatcher(const AffineExprMatcher &other)
|
||||
: kind(other.kind), storage(other.storage) {}
|
||||
|
||||
Optional<int> AffineExprMatcher::getMatchedConstantValue() {
|
||||
if (auto cst = storage->matched.dyn_cast<AffineConstantExpr>())
|
||||
return cst.getValue();
|
||||
return None;
|
||||
}
|
||||
|
||||
AffineExpr AffineExprMatcher::match(AffineExpr expr) {
|
||||
if (kind > AffineExprKind::LAST_AFFINE_BINARY_OP) {
|
||||
if (storage->matched)
|
||||
if (storage->matched != expr)
|
||||
return AffineExpr();
|
||||
storage->matched = expr;
|
||||
return storage->matched;
|
||||
}
|
||||
if (kind != expr.getKind()) {
|
||||
return AffineExpr();
|
||||
}
|
||||
if (auto bin = expr.dyn_cast<AffineBinaryOpExpr>()) {
|
||||
if (!storage->subExprs.empty() &&
|
||||
!storage->subExprs[0].match(bin.getLHS())) {
|
||||
return AffineExpr();
|
||||
}
|
||||
if (!storage->subExprs.empty() &&
|
||||
!storage->subExprs[1].match(bin.getRHS())) {
|
||||
return AffineExpr();
|
||||
}
|
||||
if (storage->matched)
|
||||
if (storage->matched != expr)
|
||||
return AffineExpr();
|
||||
storage->matched = expr;
|
||||
return storage->matched;
|
||||
}
|
||||
llvm_unreachable("binary expected");
|
||||
}
|
||||
|
||||
AffineExpr AffineExprMatcher::matched() { return storage->matched; }
|
||||
|
||||
AffineExprMatcher::AffineExprMatcher(AffineExprKind k, AffineExprMatcher a,
|
||||
AffineExprMatcher b)
|
||||
: kind(k), storage(new AffineExprMatcherStorage(a, b)) {
|
||||
storage->subExprs.push_back(a);
|
||||
storage->subExprs.push_back(b);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SDBMExpr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
SDBMExprKind SDBMExpr::getKind() const { return impl->getKind(); }
|
||||
|
||||
MLIRContext *SDBMExpr::getContext() const {
|
||||
return impl->dialect->getContext();
|
||||
}
|
||||
|
||||
SDBMDialect *SDBMExpr::getDialect() const { return impl->dialect; }
|
||||
|
||||
void SDBMExpr::print(raw_ostream &os) const {
|
||||
struct Printer : public SDBMVisitor<Printer> {
|
||||
Printer(raw_ostream &ostream) : prn(ostream) {}
|
||||
|
||||
void visitSum(SDBMSumExpr expr) {
|
||||
visit(expr.getLHS());
|
||||
prn << " + ";
|
||||
visit(expr.getRHS());
|
||||
}
|
||||
void visitDiff(SDBMDiffExpr expr) {
|
||||
visit(expr.getLHS());
|
||||
prn << " - ";
|
||||
visit(expr.getRHS());
|
||||
}
|
||||
void visitDim(SDBMDimExpr expr) { prn << 'd' << expr.getPosition(); }
|
||||
void visitSymbol(SDBMSymbolExpr expr) { prn << 's' << expr.getPosition(); }
|
||||
void visitStripe(SDBMStripeExpr expr) {
|
||||
SDBMDirectExpr lhs = expr.getLHS();
|
||||
bool isTerm = lhs.isa<SDBMTermExpr>();
|
||||
if (!isTerm)
|
||||
prn << '(';
|
||||
visit(lhs);
|
||||
if (!isTerm)
|
||||
prn << ')';
|
||||
prn << " # ";
|
||||
visitConstant(expr.getStripeFactor());
|
||||
}
|
||||
void visitNeg(SDBMNegExpr expr) {
|
||||
bool isSum = expr.getVar().isa<SDBMSumExpr>();
|
||||
prn << '-';
|
||||
if (isSum)
|
||||
prn << '(';
|
||||
visit(expr.getVar());
|
||||
if (isSum)
|
||||
prn << ')';
|
||||
}
|
||||
void visitConstant(SDBMConstantExpr expr) { prn << expr.getValue(); }
|
||||
|
||||
raw_ostream &prn;
|
||||
};
|
||||
Printer printer(os);
|
||||
printer.visit(*this);
|
||||
}
|
||||
|
||||
void SDBMExpr::dump() const {
|
||||
print(llvm::errs());
|
||||
llvm::errs() << '\n';
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Helper class to perform negation of an SDBM expression.
|
||||
struct SDBMNegator : public SDBMVisitor<SDBMNegator, SDBMExpr> {
|
||||
// Any term expression is wrapped into a negation expression.
|
||||
// -(x) = -x
|
||||
SDBMExpr visitDirect(SDBMDirectExpr expr) { return SDBMNegExpr::get(expr); }
|
||||
// A negation expression is unwrapped.
|
||||
// -(-x) = x
|
||||
SDBMExpr visitNeg(SDBMNegExpr expr) { return expr.getVar(); }
|
||||
// The value of the constant is negated.
|
||||
SDBMExpr visitConstant(SDBMConstantExpr expr) {
|
||||
return SDBMConstantExpr::get(expr.getDialect(), -expr.getValue());
|
||||
}
|
||||
|
||||
// Terms of a difference are interchanged. Since only the LHS of a diff
|
||||
// expression is allowed to be a sum with a constant, we need to recreate the
|
||||
// sum with the negated value:
|
||||
// -((x + C) - y) = (y - C) - x.
|
||||
SDBMExpr visitDiff(SDBMDiffExpr expr) {
|
||||
// If the LHS is just a term, we can do straightforward interchange.
|
||||
if (auto term = expr.getLHS().dyn_cast<SDBMTermExpr>())
|
||||
return SDBMDiffExpr::get(expr.getRHS(), term);
|
||||
|
||||
auto sum = expr.getLHS().cast<SDBMSumExpr>();
|
||||
auto cst = visitConstant(sum.getRHS()).cast<SDBMConstantExpr>();
|
||||
return SDBMDiffExpr::get(SDBMSumExpr::get(expr.getRHS(), cst),
|
||||
sum.getLHS());
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
SDBMExpr SDBMExpr::operator-() { return SDBMNegator().visit(*this); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SDBMSumExpr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
SDBMSumExpr SDBMSumExpr::get(SDBMTermExpr lhs, SDBMConstantExpr rhs) {
|
||||
assert(lhs && "expected SDBM variable expression");
|
||||
assert(rhs && "expected SDBM constant");
|
||||
|
||||
// If LHS of a sum is another sum, fold the constant RHS parts.
|
||||
if (auto lhsSum = lhs.dyn_cast<SDBMSumExpr>()) {
|
||||
lhs = lhsSum.getLHS();
|
||||
rhs = SDBMConstantExpr::get(rhs.getDialect(),
|
||||
rhs.getValue() + lhsSum.getRHS().getValue());
|
||||
}
|
||||
|
||||
StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
|
||||
return uniquer.get<detail::SDBMBinaryExprStorage>(
|
||||
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Add), lhs, rhs);
|
||||
}
|
||||
|
||||
SDBMTermExpr SDBMSumExpr::getLHS() const {
|
||||
return static_cast<ImplType *>(impl)->lhs.cast<SDBMTermExpr>();
|
||||
}
|
||||
|
||||
SDBMConstantExpr SDBMSumExpr::getRHS() const {
|
||||
return static_cast<ImplType *>(impl)->rhs;
|
||||
}
|
||||
|
||||
AffineExpr SDBMExpr::getAsAffineExpr() const {
|
||||
struct Converter : public SDBMVisitor<Converter, AffineExpr> {
|
||||
AffineExpr visitSum(SDBMSumExpr expr) {
|
||||
AffineExpr lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
|
||||
return lhs + rhs;
|
||||
}
|
||||
|
||||
AffineExpr visitStripe(SDBMStripeExpr expr) {
|
||||
AffineExpr lhs = visit(expr.getLHS()),
|
||||
rhs = visit(expr.getStripeFactor());
|
||||
return lhs - (lhs % rhs);
|
||||
}
|
||||
|
||||
AffineExpr visitDiff(SDBMDiffExpr expr) {
|
||||
AffineExpr lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
|
||||
return lhs - rhs;
|
||||
}
|
||||
|
||||
AffineExpr visitDim(SDBMDimExpr expr) {
|
||||
return getAffineDimExpr(expr.getPosition(), expr.getContext());
|
||||
}
|
||||
|
||||
AffineExpr visitSymbol(SDBMSymbolExpr expr) {
|
||||
return getAffineSymbolExpr(expr.getPosition(), expr.getContext());
|
||||
}
|
||||
|
||||
AffineExpr visitNeg(SDBMNegExpr expr) {
|
||||
return getAffineBinaryOpExpr(AffineExprKind::Mul,
|
||||
getAffineConstantExpr(-1, expr.getContext()),
|
||||
visit(expr.getVar()));
|
||||
}
|
||||
|
||||
AffineExpr visitConstant(SDBMConstantExpr expr) {
|
||||
return getAffineConstantExpr(expr.getValue(), expr.getContext());
|
||||
}
|
||||
} converter;
|
||||
return converter.visit(*this);
|
||||
}
|
||||
|
||||
// Given a direct expression `expr`, add the given constant to it and pass the
|
||||
// resulting expression to `builder` before returning its result. If the
|
||||
// expression is already a sum expression, update its constant and extract the
|
||||
// LHS if the constant becomes zero. Otherwise, construct a sum expression.
|
||||
template <typename Result>
|
||||
static Result addConstantAndSink(SDBMDirectExpr expr, int64_t constant,
|
||||
bool negated,
|
||||
function_ref<Result(SDBMDirectExpr)> builder) {
|
||||
SDBMDialect *dialect = expr.getDialect();
|
||||
if (auto sumExpr = expr.dyn_cast<SDBMSumExpr>()) {
|
||||
if (negated)
|
||||
constant = sumExpr.getRHS().getValue() - constant;
|
||||
else
|
||||
constant += sumExpr.getRHS().getValue();
|
||||
|
||||
if (constant != 0) {
|
||||
auto sum = SDBMSumExpr::get(sumExpr.getLHS(),
|
||||
SDBMConstantExpr::get(dialect, constant));
|
||||
return builder(sum);
|
||||
} else {
|
||||
return builder(sumExpr.getLHS());
|
||||
}
|
||||
}
|
||||
if (constant != 0)
|
||||
return builder(SDBMSumExpr::get(
|
||||
expr.cast<SDBMTermExpr>(),
|
||||
SDBMConstantExpr::get(dialect, negated ? -constant : constant)));
|
||||
return expr;
|
||||
}
|
||||
|
||||
// Construct an expression lhs + constant while maintaining the canonical form
|
||||
// of the SDBM expressions, in particular sink the constant expression to the
|
||||
// nearest sum expression in the left subtree of the expression tree.
|
||||
static SDBMExpr addConstant(SDBMVaryingExpr lhs, int64_t constant) {
|
||||
if (auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>())
|
||||
return addConstantAndSink<SDBMExpr>(
|
||||
lhsDiff.getLHS(), constant, /*negated=*/false,
|
||||
[lhsDiff](SDBMDirectExpr e) {
|
||||
return SDBMDiffExpr::get(e, lhsDiff.getRHS());
|
||||
});
|
||||
if (auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>())
|
||||
return addConstantAndSink<SDBMExpr>(
|
||||
lhsNeg.getVar(), constant, /*negated=*/true,
|
||||
[](SDBMDirectExpr e) { return SDBMNegExpr::get(e); });
|
||||
if (auto lhsSum = lhs.dyn_cast<SDBMSumExpr>())
|
||||
return addConstantAndSink<SDBMExpr>(lhsSum, constant, /*negated=*/false,
|
||||
[](SDBMDirectExpr e) { return e; });
|
||||
if (constant != 0)
|
||||
return SDBMSumExpr::get(lhs.cast<SDBMTermExpr>(),
|
||||
SDBMConstantExpr::get(lhs.getDialect(), constant));
|
||||
return lhs;
|
||||
}
|
||||
|
||||
// Build a difference expression given a direct expression and a negation
|
||||
// expression.
|
||||
static SDBMExpr buildDiffExpr(SDBMDirectExpr lhs, SDBMNegExpr rhs) {
|
||||
// Fold (x + C) - (x + D) = C - D.
|
||||
if (lhs.getTerm() == rhs.getVar().getTerm())
|
||||
return SDBMConstantExpr::get(
|
||||
lhs.getDialect(), lhs.getConstant() - rhs.getVar().getConstant());
|
||||
|
||||
return SDBMDiffExpr::get(
|
||||
addConstantAndSink<SDBMDirectExpr>(lhs, -rhs.getVar().getConstant(),
|
||||
/*negated=*/false,
|
||||
[](SDBMDirectExpr e) { return e; }),
|
||||
rhs.getVar().getTerm());
|
||||
}
|
||||
|
||||
// Try folding an expression (lhs + rhs) where at least one of the operands
|
||||
// contains a negated variable, i.e. is a negation or a difference expression.
|
||||
static SDBMExpr foldSumDiff(SDBMExpr lhs, SDBMExpr rhs) {
|
||||
// If exactly one of LHS, RHS is a negation expression, we can construct
|
||||
// a difference expression, which is a special kind in SDBM.
|
||||
auto lhsDirect = lhs.dyn_cast<SDBMDirectExpr>();
|
||||
auto rhsDirect = rhs.dyn_cast<SDBMDirectExpr>();
|
||||
auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>();
|
||||
auto rhsNeg = rhs.dyn_cast<SDBMNegExpr>();
|
||||
|
||||
if (lhsDirect && rhsNeg)
|
||||
return buildDiffExpr(lhsDirect, rhsNeg);
|
||||
if (lhsNeg && rhsDirect)
|
||||
return buildDiffExpr(rhsDirect, lhsNeg);
|
||||
|
||||
// If a subexpression appears in a diff expression on the LHS(RHS) of a
|
||||
// sum expression where it also appears on the RHS(LHS) with the opposite
|
||||
// sign, we can simplify it away and obtain the SDBM form.
|
||||
auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>();
|
||||
auto rhsDiff = rhs.dyn_cast<SDBMDiffExpr>();
|
||||
|
||||
// -(x + A) + ((x + B) - y) = -(y + (A - B))
|
||||
if (lhsNeg && rhsDiff &&
|
||||
lhsNeg.getVar().getTerm() == rhsDiff.getLHS().getTerm()) {
|
||||
int64_t constant =
|
||||
lhsNeg.getVar().getConstant() - rhsDiff.getLHS().getConstant();
|
||||
// RHS of the diff is a term expression, its sum with a constant is a direct
|
||||
// expression.
|
||||
return SDBMNegExpr::get(
|
||||
addConstant(rhsDiff.getRHS(), constant).cast<SDBMDirectExpr>());
|
||||
}
|
||||
|
||||
// (x + A) + ((y + B) - x) = (y + B) + A.
|
||||
if (lhsDirect && rhsDiff && lhsDirect.getTerm() == rhsDiff.getRHS())
|
||||
return addConstant(rhsDiff.getLHS(), lhsDirect.getConstant());
|
||||
|
||||
// ((x + A) - y) + (-(x + B)) = -(y + (B - A)).
|
||||
if (lhsDiff && rhsNeg &&
|
||||
lhsDiff.getLHS().getTerm() == rhsNeg.getVar().getTerm()) {
|
||||
int64_t constant =
|
||||
rhsNeg.getVar().getConstant() - lhsDiff.getLHS().getConstant();
|
||||
// RHS of the diff is a term expression, its sum with a constant is a direct
|
||||
// expression.
|
||||
return SDBMNegExpr::get(
|
||||
addConstant(lhsDiff.getRHS(), constant).cast<SDBMDirectExpr>());
|
||||
}
|
||||
|
||||
// ((x + A) - y) + (y + B) = (x + A) + B.
|
||||
if (rhsDirect && lhsDiff && rhsDirect.getTerm() == lhsDiff.getRHS())
|
||||
return addConstant(lhsDiff.getLHS(), rhsDirect.getConstant());
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
|
||||
struct Converter : public AffineExprVisitor<Converter, SDBMExpr> {
|
||||
SDBMExpr visitAddExpr(AffineBinaryOpExpr expr) {
|
||||
auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
|
||||
if (!lhs || !rhs)
|
||||
return {};
|
||||
|
||||
// In a "add" AffineExpr, the constant always appears on the right. If
|
||||
// there were two constants, they would have been folded away.
|
||||
assert(!lhs.isa<SDBMConstantExpr>() && "non-canonical affine expression");
|
||||
|
||||
// If RHS is a constant, we can always extend the SDBM expression to
|
||||
// include it by sinking the constant into the nearest sum expression.
|
||||
if (auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>()) {
|
||||
int64_t constant = rhsConstant.getValue();
|
||||
auto varying = lhs.dyn_cast<SDBMVaryingExpr>();
|
||||
assert(varying && "unexpected uncanonicalized sum of constants");
|
||||
return addConstant(varying, constant);
|
||||
}
|
||||
|
||||
// Try building a difference expression if one of the values is negated,
|
||||
// or check if a difference on either hand side cancels out the outer term
|
||||
// so as to remain correct within SDBM. Return null otherwise.
|
||||
return foldSumDiff(lhs, rhs);
|
||||
}
|
||||
|
||||
SDBMExpr visitMulExpr(AffineBinaryOpExpr expr) {
|
||||
// Attempt to recover a stripe expression "x # C = (x floordiv C) * C".
|
||||
AffineExprMatcher x, C;
|
||||
AffineExprMatcher pattern = (x.floorDiv(C)) * C;
|
||||
if (pattern.match(expr)) {
|
||||
if (SDBMExpr converted = visit(x.matched())) {
|
||||
if (auto varConverted = converted.dyn_cast<SDBMTermExpr>())
|
||||
// TODO: return varConverted.stripe(C.getConstantValue());
|
||||
return SDBMStripeExpr::get(
|
||||
varConverted,
|
||||
SDBMConstantExpr::get(dialect,
|
||||
C.getMatchedConstantValue().getValue()));
|
||||
}
|
||||
}
|
||||
|
||||
auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
|
||||
if (!lhs || !rhs)
|
||||
return {};
|
||||
|
||||
// In a "mul" AffineExpr, the constant always appears on the right. If
|
||||
// there were two constants, they would have been folded away.
|
||||
assert(!lhs.isa<SDBMConstantExpr>() && "non-canonical affine expression");
|
||||
auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
|
||||
if (!rhsConstant)
|
||||
return {};
|
||||
|
||||
// The only supported "multiplication" expression is an SDBM is dimension
|
||||
// negation, that is a product of dimension and constant -1.
|
||||
if (rhsConstant.getValue() != -1)
|
||||
return {};
|
||||
|
||||
if (auto lhsVar = lhs.dyn_cast<SDBMTermExpr>())
|
||||
return SDBMNegExpr::get(lhsVar);
|
||||
if (auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>())
|
||||
return SDBMNegator().visitDiff(lhsDiff);
|
||||
|
||||
// Other multiplications are not allowed in SDBM.
|
||||
return {};
|
||||
}
|
||||
|
||||
SDBMExpr visitModExpr(AffineBinaryOpExpr expr) {
|
||||
auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
|
||||
if (!lhs || !rhs)
|
||||
return {};
|
||||
|
||||
// 'mod' can only be converted to SDBM if its LHS is a direct expression
|
||||
// and its RHS is a constant. Then it `x mod c = x - x stripe c`.
|
||||
auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
|
||||
auto lhsVar = lhs.dyn_cast<SDBMDirectExpr>();
|
||||
if (!lhsVar || !rhsConstant)
|
||||
return {};
|
||||
return SDBMDiffExpr::get(lhsVar,
|
||||
SDBMStripeExpr::get(lhsVar, rhsConstant));
|
||||
}
|
||||
|
||||
// `a floordiv b = (a stripe b) / b`, but we have no division in SDBM
|
||||
SDBMExpr visitFloorDivExpr(AffineBinaryOpExpr expr) { return {}; }
|
||||
SDBMExpr visitCeilDivExpr(AffineBinaryOpExpr expr) { return {}; }
|
||||
|
||||
// Dimensions, symbols and constants are converted trivially.
|
||||
SDBMExpr visitConstantExpr(AffineConstantExpr expr) {
|
||||
return SDBMConstantExpr::get(dialect, expr.getValue());
|
||||
}
|
||||
SDBMExpr visitDimExpr(AffineDimExpr expr) {
|
||||
return SDBMDimExpr::get(dialect, expr.getPosition());
|
||||
}
|
||||
SDBMExpr visitSymbolExpr(AffineSymbolExpr expr) {
|
||||
return SDBMSymbolExpr::get(dialect, expr.getPosition());
|
||||
}
|
||||
|
||||
SDBMDialect *dialect;
|
||||
} converter;
|
||||
converter.dialect = affine.getContext()->getOrLoadDialect<SDBMDialect>();
|
||||
|
||||
if (auto result = converter.visit(affine))
|
||||
return result;
|
||||
return None;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SDBMDiffExpr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
SDBMDiffExpr SDBMDiffExpr::get(SDBMDirectExpr lhs, SDBMTermExpr rhs) {
|
||||
assert(lhs && "expected SDBM dimension");
|
||||
assert(rhs && "expected SDBM dimension");
|
||||
|
||||
StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
|
||||
return uniquer.get<detail::SDBMDiffExprStorage>(/*initFn=*/{}, lhs, rhs);
|
||||
}
|
||||
|
||||
SDBMDirectExpr SDBMDiffExpr::getLHS() const {
|
||||
return static_cast<ImplType *>(impl)->lhs;
|
||||
}
|
||||
|
||||
SDBMTermExpr SDBMDiffExpr::getRHS() const {
|
||||
return static_cast<ImplType *>(impl)->rhs;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SDBMDirectExpr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
SDBMTermExpr SDBMDirectExpr::getTerm() {
|
||||
if (auto sum = dyn_cast<SDBMSumExpr>())
|
||||
return sum.getLHS();
|
||||
return cast<SDBMTermExpr>();
|
||||
}
|
||||
|
||||
int64_t SDBMDirectExpr::getConstant() {
|
||||
if (auto sum = dyn_cast<SDBMSumExpr>())
|
||||
return sum.getRHS().getValue();
|
||||
return 0;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SDBMStripeExpr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
SDBMStripeExpr SDBMStripeExpr::get(SDBMDirectExpr var,
|
||||
SDBMConstantExpr stripeFactor) {
|
||||
assert(var && "expected SDBM variable expression");
|
||||
assert(stripeFactor && "expected non-null stripe factor");
|
||||
if (stripeFactor.getValue() <= 0)
|
||||
llvm::report_fatal_error("non-positive stripe factor");
|
||||
|
||||
StorageUniquer &uniquer = var.getDialect()->getUniquer();
|
||||
return uniquer.get<detail::SDBMBinaryExprStorage>(
|
||||
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Stripe), var,
|
||||
stripeFactor);
|
||||
}
|
||||
|
||||
SDBMDirectExpr SDBMStripeExpr::getLHS() const {
|
||||
if (SDBMVaryingExpr lhs = static_cast<ImplType *>(impl)->lhs)
|
||||
return lhs.cast<SDBMDirectExpr>();
|
||||
return {};
|
||||
}
|
||||
|
||||
SDBMConstantExpr SDBMStripeExpr::getStripeFactor() const {
|
||||
return static_cast<ImplType *>(impl)->rhs;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SDBMInputExpr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
unsigned SDBMInputExpr::getPosition() const {
|
||||
return static_cast<ImplType *>(impl)->position;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SDBMDimExpr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
SDBMDimExpr SDBMDimExpr::get(SDBMDialect *dialect, unsigned position) {
|
||||
assert(dialect && "expected non-null dialect");
|
||||
|
||||
auto assignDialect = [dialect](detail::SDBMTermExprStorage *storage) {
|
||||
storage->dialect = dialect;
|
||||
};
|
||||
|
||||
StorageUniquer &uniquer = dialect->getUniquer();
|
||||
return uniquer.get<detail::SDBMTermExprStorage>(
|
||||
assignDialect, static_cast<unsigned>(SDBMExprKind::DimId), position);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SDBMSymbolExpr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
SDBMSymbolExpr SDBMSymbolExpr::get(SDBMDialect *dialect, unsigned position) {
|
||||
assert(dialect && "expected non-null dialect");
|
||||
|
||||
auto assignDialect = [dialect](detail::SDBMTermExprStorage *storage) {
|
||||
storage->dialect = dialect;
|
||||
};
|
||||
|
||||
StorageUniquer &uniquer = dialect->getUniquer();
|
||||
return uniquer.get<detail::SDBMTermExprStorage>(
|
||||
assignDialect, static_cast<unsigned>(SDBMExprKind::SymbolId), position);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SDBMConstantExpr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
SDBMConstantExpr SDBMConstantExpr::get(SDBMDialect *dialect, int64_t value) {
|
||||
assert(dialect && "expected non-null dialect");
|
||||
|
||||
auto assignCtx = [dialect](detail::SDBMConstantExprStorage *storage) {
|
||||
storage->dialect = dialect;
|
||||
};
|
||||
|
||||
StorageUniquer &uniquer = dialect->getUniquer();
|
||||
return uniquer.get<detail::SDBMConstantExprStorage>(assignCtx, value);
|
||||
}
|
||||
|
||||
int64_t SDBMConstantExpr::getValue() const {
|
||||
return static_cast<ImplType *>(impl)->constant;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SDBMNegExpr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
SDBMNegExpr SDBMNegExpr::get(SDBMDirectExpr var) {
|
||||
assert(var && "expected non-null SDBM direct expression");
|
||||
|
||||
StorageUniquer &uniquer = var.getDialect()->getUniquer();
|
||||
return uniquer.get<detail::SDBMNegExprStorage>(/*initFn=*/{}, var);
|
||||
}
|
||||
|
||||
SDBMDirectExpr SDBMNegExpr::getVar() const {
|
||||
return static_cast<ImplType *>(impl)->expr;
|
||||
}
|
||||
|
||||
SDBMExpr mlir::ops_assertions::operator+(SDBMExpr lhs, SDBMExpr rhs) {
|
||||
if (auto folded = foldSumDiff(lhs, rhs))
|
||||
return folded;
|
||||
assert(!(lhs.isa<SDBMNegExpr>() && rhs.isa<SDBMNegExpr>()) &&
|
||||
"a sum of negated expressions is a negation of a sum of variables and "
|
||||
"not a correct SDBM");
|
||||
|
||||
// Fold (x - y) + (y - x) = 0.
|
||||
auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>();
|
||||
auto rhsDiff = rhs.dyn_cast<SDBMDiffExpr>();
|
||||
if (lhsDiff && rhsDiff) {
|
||||
if (lhsDiff.getLHS() == rhsDiff.getRHS() &&
|
||||
lhsDiff.getRHS() == rhsDiff.getLHS())
|
||||
return SDBMConstantExpr::get(lhs.getDialect(), 0);
|
||||
}
|
||||
|
||||
// If LHS is a constant and RHS is not, swap the order to get into a supported
|
||||
// sum case. From now on, RHS must be a constant.
|
||||
auto lhsConstant = lhs.dyn_cast<SDBMConstantExpr>();
|
||||
auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
|
||||
if (!rhsConstant && lhsConstant) {
|
||||
std::swap(lhs, rhs);
|
||||
std::swap(lhsConstant, rhsConstant);
|
||||
}
|
||||
assert(rhsConstant && "at least one operand must be a constant");
|
||||
|
||||
// Constant-fold if LHS is also a constant.
|
||||
if (lhsConstant)
|
||||
return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() +
|
||||
rhsConstant.getValue());
|
||||
return addConstant(lhs.cast<SDBMVaryingExpr>(), rhsConstant.getValue());
|
||||
}
|
||||
|
||||
SDBMExpr mlir::ops_assertions::operator-(SDBMExpr lhs, SDBMExpr rhs) {
|
||||
// Fold x - x == 0.
|
||||
if (lhs == rhs)
|
||||
return SDBMConstantExpr::get(lhs.getDialect(), 0);
|
||||
|
||||
// LHS and RHS may be constants.
|
||||
auto lhsConstant = lhs.dyn_cast<SDBMConstantExpr>();
|
||||
auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
|
||||
|
||||
// Constant fold if both LHS and RHS are constants.
|
||||
if (lhsConstant && rhsConstant)
|
||||
return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() -
|
||||
rhsConstant.getValue());
|
||||
|
||||
// Replace a difference with a sum with a negated value if one of LHS and RHS
|
||||
// is a constant:
|
||||
// x - C == x + (-C);
|
||||
// C - x == -x + C.
|
||||
// This calls into operator+ for further simplification.
|
||||
if (rhsConstant)
|
||||
return lhs + (-rhsConstant);
|
||||
if (lhsConstant)
|
||||
return -rhs + lhsConstant;
|
||||
|
||||
return buildDiffExpr(lhs.cast<SDBMDirectExpr>(), (-rhs).cast<SDBMNegExpr>());
|
||||
}
|
||||
|
||||
SDBMExpr mlir::ops_assertions::stripe(SDBMExpr expr, SDBMExpr factor) {
|
||||
auto constantFactor = factor.cast<SDBMConstantExpr>();
|
||||
assert(constantFactor.getValue() > 0 && "non-positive stripe");
|
||||
|
||||
// Fold x # 1 = x.
|
||||
if (constantFactor.getValue() == 1)
|
||||
return expr;
|
||||
|
||||
return SDBMStripeExpr::get(expr.cast<SDBMDirectExpr>(), constantFactor);
|
||||
}
|
|
@ -1,137 +0,0 @@
|
|||
//===- SDBMExprDetail.h - MLIR SDBM Expression storage details --*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This holds implementation details of SDBMExpr, in particular underlying
|
||||
// storage types.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_IR_SDBMEXPRDETAIL_H
|
||||
#define MLIR_IR_SDBMEXPRDETAIL_H
|
||||
|
||||
#include "mlir/Dialect/SDBM/SDBMExpr.h"
|
||||
#include "mlir/Support/StorageUniquer.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class SDBMDialect;
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Base storage class for SDBMExpr.
|
||||
struct SDBMExprStorage : public StorageUniquer::BaseStorage {
|
||||
SDBMExprKind getKind() { return kind; }
|
||||
|
||||
SDBMDialect *dialect;
|
||||
SDBMExprKind kind;
|
||||
};
|
||||
|
||||
// Storage class for SDBM sum and stripe expressions.
|
||||
struct SDBMBinaryExprStorage : public SDBMExprStorage {
|
||||
using KeyTy = std::tuple<unsigned, SDBMDirectExpr, SDBMConstantExpr>;
|
||||
|
||||
bool operator==(const KeyTy &key) const {
|
||||
return static_cast<SDBMExprKind>(std::get<0>(key)) == kind &&
|
||||
std::get<1>(key) == lhs && std::get<2>(key) == rhs;
|
||||
}
|
||||
|
||||
static SDBMBinaryExprStorage *
|
||||
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
|
||||
auto *result = allocator.allocate<SDBMBinaryExprStorage>();
|
||||
result->lhs = std::get<1>(key);
|
||||
result->rhs = std::get<2>(key);
|
||||
result->dialect = result->lhs.getDialect();
|
||||
result->kind = static_cast<SDBMExprKind>(std::get<0>(key));
|
||||
return result;
|
||||
}
|
||||
|
||||
SDBMDirectExpr lhs;
|
||||
SDBMConstantExpr rhs;
|
||||
};
|
||||
|
||||
// Storage class for SDBM difference expressions.
|
||||
struct SDBMDiffExprStorage : public SDBMExprStorage {
|
||||
using KeyTy = std::pair<SDBMDirectExpr, SDBMTermExpr>;
|
||||
|
||||
bool operator==(const KeyTy &key) const {
|
||||
return std::get<0>(key) == lhs && std::get<1>(key) == rhs;
|
||||
}
|
||||
|
||||
static SDBMDiffExprStorage *
|
||||
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
|
||||
auto *result = allocator.allocate<SDBMDiffExprStorage>();
|
||||
result->lhs = std::get<0>(key);
|
||||
result->rhs = std::get<1>(key);
|
||||
result->dialect = result->lhs.getDialect();
|
||||
result->kind = SDBMExprKind::Diff;
|
||||
return result;
|
||||
}
|
||||
|
||||
SDBMDirectExpr lhs;
|
||||
SDBMTermExpr rhs;
|
||||
};
|
||||
|
||||
// Storage class for SDBM constant expressions.
|
||||
struct SDBMConstantExprStorage : public SDBMExprStorage {
|
||||
using KeyTy = int64_t;
|
||||
|
||||
bool operator==(const KeyTy &key) const { return constant == key; }
|
||||
|
||||
static SDBMConstantExprStorage *
|
||||
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
|
||||
auto *result = allocator.allocate<SDBMConstantExprStorage>();
|
||||
result->constant = key;
|
||||
result->kind = SDBMExprKind::Constant;
|
||||
return result;
|
||||
}
|
||||
|
||||
int64_t constant;
|
||||
};
|
||||
|
||||
// Storage class for SDBM dimension and symbol expressions.
|
||||
struct SDBMTermExprStorage : public SDBMExprStorage {
|
||||
using KeyTy = std::pair<unsigned, unsigned>;
|
||||
|
||||
bool operator==(const KeyTy &key) const {
|
||||
return kind == static_cast<SDBMExprKind>(key.first) &&
|
||||
position == key.second;
|
||||
}
|
||||
|
||||
static SDBMTermExprStorage *
|
||||
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
|
||||
auto *result = allocator.allocate<SDBMTermExprStorage>();
|
||||
result->kind = static_cast<SDBMExprKind>(key.first);
|
||||
result->position = key.second;
|
||||
return result;
|
||||
}
|
||||
|
||||
unsigned position;
|
||||
};
|
||||
|
||||
// Storage class for SDBM negation expressions.
|
||||
struct SDBMNegExprStorage : public SDBMExprStorage {
|
||||
using KeyTy = SDBMDirectExpr;
|
||||
|
||||
bool operator==(const KeyTy &key) const { return key == expr; }
|
||||
|
||||
static SDBMNegExprStorage *
|
||||
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
|
||||
auto *result = allocator.allocate<SDBMNegExprStorage>();
|
||||
result->expr = key;
|
||||
result->dialect = key.getDialect();
|
||||
result->kind = SDBMExprKind::Neg;
|
||||
return result;
|
||||
}
|
||||
|
||||
SDBMDirectExpr expr;
|
||||
};
|
||||
|
||||
} // end namespace detail
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_IR_SDBMEXPRDETAIL_H
|
|
@ -1,5 +1,4 @@
|
|||
add_subdirectory(CAPI)
|
||||
add_subdirectory(SDBM)
|
||||
add_subdirectory(lib)
|
||||
|
||||
if(MLIR_ENABLE_BINDINGS_PYTHON)
|
||||
|
@ -75,7 +74,6 @@ set(MLIR_TEST_DEPENDS
|
|||
mlir-lsp-server
|
||||
mlir-opt
|
||||
mlir-reduce
|
||||
mlir-sdbm-api-test
|
||||
mlir-tblgen
|
||||
mlir-translate
|
||||
mlir_runner_utils
|
||||
|
|
|
@ -1,19 +0,0 @@
|
|||
set(LLVM_LINK_COMPONENTS
|
||||
Core
|
||||
Support
|
||||
)
|
||||
|
||||
add_llvm_executable(mlir-sdbm-api-test
|
||||
sdbm-api-test.cpp
|
||||
)
|
||||
|
||||
llvm_update_compile_flags(mlir-sdbm-api-test)
|
||||
|
||||
target_link_libraries(mlir-sdbm-api-test
|
||||
PRIVATE
|
||||
MLIRIR
|
||||
MLIRSDBM
|
||||
MLIRSupport
|
||||
)
|
||||
|
||||
target_include_directories(mlir-sdbm-api-test PRIVATE ..)
|
|
@ -1 +0,0 @@
|
|||
config.suffixes.add('.cpp')
|
|
@ -1,201 +0,0 @@
|
|||
//===- sdbm-api-test.cpp - Tests for SDBM expression APIs -----------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// RUN: mlir-sdbm-api-test | FileCheck %s
|
||||
|
||||
#include "mlir/Dialect/SDBM/SDBM.h"
|
||||
#include "mlir/Dialect/SDBM/SDBMDialect.h"
|
||||
#include "mlir/Dialect/SDBM/SDBMExpr.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include "APITest.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
|
||||
static MLIRContext *ctx() {
|
||||
static thread_local MLIRContext context;
|
||||
static thread_local bool once =
|
||||
(context.getOrLoadDialect<SDBMDialect>(), true);
|
||||
(void)once;
|
||||
return &context;
|
||||
}
|
||||
|
||||
static SDBMDialect *dialect() {
|
||||
static thread_local SDBMDialect *d = nullptr;
|
||||
if (!d) {
|
||||
d = ctx()->getOrLoadDialect<SDBMDialect>();
|
||||
}
|
||||
return d;
|
||||
}
|
||||
|
||||
static SDBMExpr dim(unsigned pos) { return SDBMDimExpr::get(dialect(), pos); }
|
||||
|
||||
static SDBMExpr symb(unsigned pos) {
|
||||
return SDBMSymbolExpr::get(dialect(), pos);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace mlir::ops_assertions;
|
||||
|
||||
TEST_FUNC(SDBM_SingleConstraint) {
|
||||
// Build an SDBM defined by
|
||||
// d0 - 3 <= 0 <=> d0 <= 3.
|
||||
auto sdbm = SDBM::get(dim(0) - 3, llvm::None);
|
||||
|
||||
// CHECK: cst d0
|
||||
// CHECK-NEXT: cst inf 3
|
||||
// CHECK-NEXT: d0 inf inf
|
||||
sdbm.print(llvm::outs());
|
||||
}
|
||||
|
||||
TEST_FUNC(SDBM_Equality) {
|
||||
// Build an SDBM defined by
|
||||
//
|
||||
// d0 - d1 - 3 = 0
|
||||
// <=> {d0 - d1 - 3 <= 0 and d0 - d1 - 3 >= 0}
|
||||
// <=> {d0 - d1 <= 3 and d1 - d0 <= -3}.
|
||||
auto sdbm = SDBM::get(llvm::None, dim(0) - dim(1) - 3);
|
||||
|
||||
// CHECK: cst d0 d1
|
||||
// CHECK-NEXT: cst inf inf inf
|
||||
// CHECK-NEXT: d0 inf inf -3
|
||||
// CHECK-NEXT: d1 inf 3 inf
|
||||
sdbm.print(llvm::outs());
|
||||
}
|
||||
|
||||
TEST_FUNC(SDBM_TrivialSimplification) {
|
||||
// Build an SDBM defined by
|
||||
//
|
||||
// d0 - 3 <= 0 <=> d0 <= 3
|
||||
// d0 - 5 <= 0 <=> d0 <= 5
|
||||
//
|
||||
// which should get simplified on construction to only the former.
|
||||
auto sdbm = SDBM::get({dim(0) - 3, dim(0) - 5}, llvm::None);
|
||||
|
||||
// CHECK: cst d0
|
||||
// CHECK-NEXT: cst inf 3
|
||||
// CHECK-NEXT: d0 inf inf
|
||||
sdbm.print(llvm::outs());
|
||||
}
|
||||
|
||||
TEST_FUNC(SDBM_StripeInducedIneqs) {
|
||||
// Build an SDBM defined by d1 = d0 # 3, which induces the constraints
|
||||
//
|
||||
// d1 - d0 <= 0
|
||||
// d0 - d1 <= 3 - 1 = 2
|
||||
auto sdbm = SDBM::get(llvm::None, dim(1) - stripe(dim(0), 3));
|
||||
|
||||
// CHECK: cst d0 d1
|
||||
// CHECK-NEXT: cst inf inf inf
|
||||
// CHECK-NEXT: d0 inf inf 0
|
||||
// CHECK-NEXT: d1 inf 2 0
|
||||
// CHECK-NEXT: d1 = d0 # 3
|
||||
sdbm.print(llvm::outs());
|
||||
}
|
||||
|
||||
TEST_FUNC(SDBM_StripeTemporaries) {
|
||||
// Build an SDBM defined by d0 # 3 <= 0, which creates a temporary
|
||||
// t0 = d0 # 3 leading to a constraint t0 <= 0 and the stripe-induced
|
||||
// constraints
|
||||
//
|
||||
// t0 - d0 <= 0
|
||||
// d0 - t0 <= 3 - 1 = 2
|
||||
auto sdbm = SDBM::get(stripe(dim(0), 3), llvm::None);
|
||||
|
||||
// CHECK: cst d0 t0
|
||||
// CHECK-NEXT: cst inf inf 0
|
||||
// CHECK-NEXT: d0 inf inf 0
|
||||
// CHECK-NEXT: t0 inf 2 inf
|
||||
// CHECK-NEXT: t0 = d0 # 3
|
||||
sdbm.print(llvm::outs());
|
||||
}
|
||||
|
||||
TEST_FUNC(SDBM_ElideInducedInequalities) {
|
||||
// Build an SDBM defined by a single stripe equality d0 = s0 # 3 and make sure
|
||||
// the induced inequalities are not present after converting the SDBM back
|
||||
// into lists of expressions.
|
||||
auto sdbm = SDBM::get(llvm::None, {dim(0) - stripe(symb(0), 3)});
|
||||
|
||||
SmallVector<SDBMExpr, 4> eqs, ineqs;
|
||||
sdbm.getSDBMExpressions(dialect(), ineqs, eqs);
|
||||
// CHECK-EMPTY:
|
||||
for (auto ineq : ineqs)
|
||||
ineq.print(llvm::outs() << '\n');
|
||||
llvm::outs() << "\n";
|
||||
|
||||
// CHECK: d0 - s0 # 3
|
||||
// CHECK-EMPTY:
|
||||
for (auto eq : eqs)
|
||||
eq.print(llvm::outs() << '\n');
|
||||
llvm::outs() << "\n\n";
|
||||
}
|
||||
|
||||
TEST_FUNC(SDBM_StripeTightening) {
|
||||
// Build an SDBM defined by
|
||||
//
|
||||
// d0 = s0 # 3 # 5
|
||||
// s0 # 3 # 5 - d1 + 42 = 0
|
||||
// s0 # 3 - d0 <= 2
|
||||
//
|
||||
// where the last inequality is tighter than that induced by the first stripe
|
||||
// equality (s0 # 3 - d0 <= 5 - 1 = 4). Check that the conversion from SDBM
|
||||
// back to the lists of constraints conserves both the stripe equality and the
|
||||
// tighter inequality.
|
||||
auto s = stripe(stripe(symb(0), 3), 5);
|
||||
auto tight = stripe(symb(0), 3) - dim(0) - 2;
|
||||
auto sdbm = SDBM::get({tight}, {s - dim(0), s - dim(1) + 42});
|
||||
|
||||
SmallVector<SDBMExpr, 4> eqs, ineqs;
|
||||
sdbm.getSDBMExpressions(dialect(), ineqs, eqs);
|
||||
// CHECK: s0 # 3 + -2 - d0
|
||||
// CHECK-EMPTY:
|
||||
for (auto ineq : ineqs)
|
||||
ineq.print(llvm::outs() << '\n');
|
||||
llvm::outs() << "\n";
|
||||
|
||||
// CHECK-DAG: d1 + -42 - d0
|
||||
// CHECK-DAG: d0 - s0 # 3 # 5
|
||||
for (auto eq : eqs)
|
||||
eq.print(llvm::outs() << '\n');
|
||||
llvm::outs() << "\n\n";
|
||||
}
|
||||
|
||||
TEST_FUNC(SDBM_StripeTransitive) {
|
||||
// Build an SDBM defined by
|
||||
//
|
||||
// d0 = d1 # 3
|
||||
// d0 = d2 # 7
|
||||
//
|
||||
// where the same dimension is declared equal to two stripe expressions over
|
||||
// different variables. This is practically handled by introducing a
|
||||
// temporary variable for the second stripe expression and adding an equality
|
||||
// constraint between this variable and the original dimension variable.
|
||||
auto sdbm = SDBM::get(
|
||||
llvm::None, {stripe(dim(1), 3) - dim(0), stripe(dim(2), 7) - dim(0)});
|
||||
|
||||
// CHECK: cst d0 d1 d2 t0
|
||||
// CHECK-NEXT: cst inf inf inf inf inf
|
||||
// CHECK-NEXT: d0 inf 0 2 inf 0
|
||||
// CHECK-NEXT: d1 inf 0 inf inf inf
|
||||
// CHECK-NEXT: d2 inf inf inf inf 0
|
||||
// CHECK-NEXT: t0 inf 0 inf 6 inf
|
||||
// CHECK-NEXT: t0 = d2 # 7
|
||||
// CHECK-NEXT: d0 = d1 # 3
|
||||
sdbm.print(llvm::outs());
|
||||
}
|
||||
|
||||
} // end namespace
|
||||
|
||||
int main() {
|
||||
RUN_TESTS();
|
||||
return 0;
|
||||
}
|
|
@ -65,7 +65,6 @@ tools = [
|
|||
'mlir-linalg-ods-gen',
|
||||
'mlir-linalg-ods-yaml-gen',
|
||||
'mlir-reduce',
|
||||
'mlir-sdbm-api-test',
|
||||
]
|
||||
|
||||
# The following tools are optional
|
||||
|
|
|
@ -21,7 +21,6 @@
|
|||
// CHECK-NEXT: quant
|
||||
// CHECK-NEXT: rocdl
|
||||
// CHECK-NEXT: scf
|
||||
// CHECK-NEXT: sdbm
|
||||
// CHECK-NEXT: shape
|
||||
// CHECK-NEXT: sparse_tensor
|
||||
// CHECK-NEXT: spv
|
||||
|
|
|
@ -11,5 +11,4 @@ add_subdirectory(Interfaces)
|
|||
add_subdirectory(IR)
|
||||
add_subdirectory(Pass)
|
||||
add_subdirectory(Rewrite)
|
||||
add_subdirectory(SDBM)
|
||||
add_subdirectory(TableGen)
|
||||
|
|
|
@ -1,7 +0,0 @@
|
|||
add_mlir_unittest(MLIRSDBMTests
|
||||
SDBMTest.cpp
|
||||
)
|
||||
target_link_libraries(MLIRSDBMTests
|
||||
PRIVATE
|
||||
MLIRSDBM
|
||||
)
|
|
@ -1,449 +0,0 @@
|
|||
//===- SDBMTest.cpp - SDBM expression unit tests --------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/SDBM/SDBM.h"
|
||||
#include "mlir/Dialect/SDBM/SDBMDialect.h"
|
||||
#include "mlir/Dialect/SDBM/SDBMExpr.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
|
||||
static MLIRContext *ctx() {
|
||||
static thread_local MLIRContext context;
|
||||
context.getOrLoadDialect<SDBMDialect>();
|
||||
return &context;
|
||||
}
|
||||
|
||||
static SDBMDialect *dialect() {
|
||||
static thread_local SDBMDialect *d = nullptr;
|
||||
if (!d) {
|
||||
d = ctx()->getOrLoadDialect<SDBMDialect>();
|
||||
}
|
||||
return d;
|
||||
}
|
||||
|
||||
static SDBMExpr dim(unsigned pos) { return SDBMDimExpr::get(dialect(), pos); }
|
||||
|
||||
static SDBMExpr symb(unsigned pos) {
|
||||
return SDBMSymbolExpr::get(dialect(), pos);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace mlir::ops_assertions;
|
||||
|
||||
TEST(SDBMOperators, Add) {
|
||||
auto expr = dim(0) + 42;
|
||||
auto sumExpr = expr.dyn_cast<SDBMSumExpr>();
|
||||
ASSERT_TRUE(sumExpr);
|
||||
EXPECT_EQ(sumExpr.getLHS(), dim(0));
|
||||
EXPECT_EQ(sumExpr.getRHS().getValue(), 42);
|
||||
}
|
||||
|
||||
TEST(SDBMOperators, AddFolding) {
|
||||
auto constant = SDBMConstantExpr::get(dialect(), 2) + 42;
|
||||
auto constantExpr = constant.dyn_cast<SDBMConstantExpr>();
|
||||
ASSERT_TRUE(constantExpr);
|
||||
EXPECT_EQ(constantExpr.getValue(), 44);
|
||||
|
||||
auto expr = (dim(0) + 10) + 32;
|
||||
auto sumExpr = expr.dyn_cast<SDBMSumExpr>();
|
||||
ASSERT_TRUE(sumExpr);
|
||||
EXPECT_EQ(sumExpr.getRHS().getValue(), 42);
|
||||
|
||||
expr = dim(0) + SDBMNegExpr::get(SDBMDimExpr::get(dialect(), 1));
|
||||
auto diffExpr = expr.dyn_cast<SDBMDiffExpr>();
|
||||
ASSERT_TRUE(diffExpr);
|
||||
EXPECT_EQ(diffExpr.getLHS(), dim(0));
|
||||
EXPECT_EQ(diffExpr.getRHS(), dim(1));
|
||||
|
||||
auto inverted = SDBMNegExpr::get(SDBMDimExpr::get(dialect(), 1)) + dim(0);
|
||||
EXPECT_EQ(inverted, expr);
|
||||
|
||||
// Check that opposite values cancel each other, and that we elide the zero
|
||||
// constant.
|
||||
expr = dim(0) + 42;
|
||||
auto onlyDim = expr - 42;
|
||||
EXPECT_EQ(onlyDim, dim(0));
|
||||
|
||||
// Check that we can sink a constant under a negation.
|
||||
expr = -(dim(0) + 2);
|
||||
auto negatedSum = (expr + 10).dyn_cast<SDBMNegExpr>();
|
||||
ASSERT_TRUE(negatedSum);
|
||||
auto sum = negatedSum.getVar().dyn_cast<SDBMSumExpr>();
|
||||
ASSERT_TRUE(sum);
|
||||
EXPECT_EQ(sum.getRHS().getValue(), -8);
|
||||
|
||||
// Sum with zero is the same as the original expression.
|
||||
EXPECT_EQ(dim(0) + 0, dim(0));
|
||||
|
||||
// Sum of opposite differences is zero.
|
||||
auto diffOfDiffs =
|
||||
((dim(0) - dim(1)) + (dim(1) - dim(0))).dyn_cast<SDBMConstantExpr>();
|
||||
EXPECT_EQ(diffOfDiffs.getValue(), 0);
|
||||
}
|
||||
|
||||
TEST(SDBMOperators, AddNegativeTerms) {
|
||||
const int64_t A = 7;
|
||||
const int64_t B = -5;
|
||||
auto x = SDBMDimExpr::get(dialect(), 0);
|
||||
auto y = SDBMDimExpr::get(dialect(), 1);
|
||||
|
||||
// Check the simplification patterns in addition where one of the variables is
|
||||
// cancelled out and the result remains an SDBM.
|
||||
EXPECT_EQ(-(x + A) + ((x + B) - y), -(y + (A - B)));
|
||||
EXPECT_EQ((x + A) + ((y + B) - x), (y + B) + A);
|
||||
EXPECT_EQ(((x + A) - y) + (-(x + B)), -(y + (B - A)));
|
||||
EXPECT_EQ(((x + A) - y) + (y + B), (x + A) + B);
|
||||
}
|
||||
|
||||
TEST(SDBMOperators, Diff) {
|
||||
auto expr = dim(0) - dim(1);
|
||||
auto diffExpr = expr.dyn_cast<SDBMDiffExpr>();
|
||||
ASSERT_TRUE(diffExpr);
|
||||
EXPECT_EQ(diffExpr.getLHS(), dim(0));
|
||||
EXPECT_EQ(diffExpr.getRHS(), dim(1));
|
||||
}
|
||||
|
||||
TEST(SDBMOperators, DiffFolding) {
|
||||
auto constant = SDBMConstantExpr::get(dialect(), 10) - 3;
|
||||
auto constantExpr = constant.dyn_cast<SDBMConstantExpr>();
|
||||
ASSERT_TRUE(constantExpr);
|
||||
EXPECT_EQ(constantExpr.getValue(), 7);
|
||||
|
||||
auto expr = dim(0) - 3;
|
||||
auto sumExpr = expr.dyn_cast<SDBMSumExpr>();
|
||||
ASSERT_TRUE(sumExpr);
|
||||
EXPECT_EQ(sumExpr.getRHS().getValue(), -3);
|
||||
|
||||
auto zero = dim(0) - dim(0);
|
||||
constantExpr = zero.dyn_cast<SDBMConstantExpr>();
|
||||
ASSERT_TRUE(constantExpr);
|
||||
EXPECT_EQ(constantExpr.getValue(), 0);
|
||||
|
||||
// Check that the constant terms in difference-of-sums are folded.
|
||||
// (d0 - 3) - (d1 - 5) = (d0 + 2) - d1
|
||||
auto diffOfSums = ((dim(0) - 3) - (dim(1) - 5)).dyn_cast<SDBMDiffExpr>();
|
||||
ASSERT_TRUE(diffOfSums);
|
||||
auto lhs = diffOfSums.getLHS().dyn_cast<SDBMSumExpr>();
|
||||
ASSERT_TRUE(lhs);
|
||||
EXPECT_EQ(lhs.getLHS(), dim(0));
|
||||
EXPECT_EQ(lhs.getRHS().getValue(), 2);
|
||||
EXPECT_EQ(diffOfSums.getRHS(), dim(1));
|
||||
|
||||
// Check that identical dimensions with opposite signs cancel each other.
|
||||
auto cstOnly = ((dim(0) + 42) - dim(0)).dyn_cast<SDBMConstantExpr>();
|
||||
ASSERT_TRUE(cstOnly);
|
||||
EXPECT_EQ(cstOnly.getValue(), 42);
|
||||
|
||||
// Check that identical terms in sum of diffs cancel out.
|
||||
auto dimOnly = (-dim(0) + (dim(0) - dim(1)));
|
||||
EXPECT_EQ(dimOnly, -dim(1));
|
||||
dimOnly = (dim(0) - dim(1)) + (-dim(0));
|
||||
EXPECT_EQ(dimOnly, -dim(1));
|
||||
dimOnly = (dim(0) - dim(1)) + dim(1);
|
||||
EXPECT_EQ(dimOnly, dim(0));
|
||||
dimOnly = dim(0) + (dim(1) - dim(0));
|
||||
EXPECT_EQ(dimOnly, dim(1));
|
||||
|
||||
// Top-level zero constant is fine.
|
||||
cstOnly = (-symb(1) + symb(1)).dyn_cast<SDBMConstantExpr>();
|
||||
ASSERT_TRUE(cstOnly);
|
||||
EXPECT_EQ(cstOnly.getValue(), 0);
|
||||
}
|
||||
|
||||
TEST(SDBMOperators, Negate) {
|
||||
auto sum = dim(0) + 3;
|
||||
auto negated = (-sum).dyn_cast<SDBMNegExpr>();
|
||||
ASSERT_TRUE(negated);
|
||||
EXPECT_EQ(negated.getVar(), sum);
|
||||
}
|
||||
|
||||
TEST(SDBMOperators, Stripe) {
|
||||
auto expr = stripe(dim(0), 3);
|
||||
auto stripeExpr = expr.dyn_cast<SDBMStripeExpr>();
|
||||
ASSERT_TRUE(stripeExpr);
|
||||
EXPECT_EQ(stripeExpr.getLHS(), dim(0));
|
||||
EXPECT_EQ(stripeExpr.getStripeFactor().getValue(), 3);
|
||||
}
|
||||
|
||||
TEST(SDBM, RoundTripEqs) {
|
||||
// Build an SDBM defined by
|
||||
//
|
||||
// d0 = s0 # 3 # 5
|
||||
// s0 # 3 # 5 - d1 + 42 = 0
|
||||
//
|
||||
// and perform a double round-trip between the "list of equalities" and SDBM
|
||||
// representation. After the first round-trip, the equalities may be
|
||||
// different due to simplification or equivalent substitutions (e.g., the
|
||||
// second equality may become d0 - d1 + 42 = 0). However, there should not
|
||||
// be any further simplification after the second round-trip,
|
||||
|
||||
// Build the SDBM from a pair of equalities and extract back the lists of
|
||||
// inequalities and equalities. Check that all equalities are properly
|
||||
// detected and none of them decayed into inequalities.
|
||||
auto s = stripe(stripe(symb(0), 3), 5);
|
||||
auto sdbm = SDBM::get(llvm::None, {s - dim(0), s - dim(1) + 42});
|
||||
SmallVector<SDBMExpr, 4> eqs, ineqs;
|
||||
sdbm.getSDBMExpressions(dialect(), ineqs, eqs);
|
||||
ASSERT_TRUE(ineqs.empty());
|
||||
|
||||
// Do the second round-trip.
|
||||
auto sdbm2 = SDBM::get(llvm::None, eqs);
|
||||
SmallVector<SDBMExpr, 4> eqs2, ineqs2;
|
||||
sdbm2.getSDBMExpressions(dialect(), ineqs2, eqs2);
|
||||
ASSERT_EQ(eqs.size(), eqs2.size());
|
||||
|
||||
// Check that the sets of equalities are equal, their order is not relevant.
|
||||
llvm::DenseSet<SDBMExpr> eqSet, eq2Set;
|
||||
eqSet.insert(eqs.begin(), eqs.end());
|
||||
eq2Set.insert(eqs2.begin(), eqs2.end());
|
||||
EXPECT_EQ(eqSet, eq2Set);
|
||||
}
|
||||
|
||||
TEST(SDBMExpr, Constant) {
|
||||
// We can create constants and query them.
|
||||
auto expr = SDBMConstantExpr::get(dialect(), 42);
|
||||
EXPECT_EQ(expr.getValue(), 42);
|
||||
|
||||
// Two separately created constants with identical values are trivially equal.
|
||||
auto expr2 = SDBMConstantExpr::get(dialect(), 42);
|
||||
EXPECT_EQ(expr, expr2);
|
||||
|
||||
// Hierarchy is okay.
|
||||
auto generic = static_cast<SDBMExpr>(expr);
|
||||
EXPECT_TRUE(generic.isa<SDBMConstantExpr>());
|
||||
}
|
||||
|
||||
TEST(SDBMExpr, Dim) {
|
||||
// We can create dimension expressions and query them.
|
||||
auto expr = SDBMDimExpr::get(dialect(), 0);
|
||||
EXPECT_EQ(expr.getPosition(), 0u);
|
||||
|
||||
// Two separately created dimensions with the same position are trivially
|
||||
// equal.
|
||||
auto expr2 = SDBMDimExpr::get(dialect(), 0);
|
||||
EXPECT_EQ(expr, expr2);
|
||||
|
||||
// Hierarchy is okay.
|
||||
auto generic = static_cast<SDBMExpr>(expr);
|
||||
EXPECT_TRUE(generic.isa<SDBMDimExpr>());
|
||||
EXPECT_TRUE(generic.isa<SDBMInputExpr>());
|
||||
EXPECT_TRUE(generic.isa<SDBMTermExpr>());
|
||||
EXPECT_TRUE(generic.isa<SDBMDirectExpr>());
|
||||
EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
|
||||
|
||||
// Dimensions are not Symbols.
|
||||
auto symbol = SDBMSymbolExpr::get(dialect(), 0);
|
||||
EXPECT_NE(expr, symbol);
|
||||
EXPECT_FALSE(expr.isa<SDBMSymbolExpr>());
|
||||
}
|
||||
|
||||
TEST(SDBMExpr, Symbol) {
|
||||
// We can create symbol expressions and query them.
|
||||
auto expr = SDBMSymbolExpr::get(dialect(), 0);
|
||||
EXPECT_EQ(expr.getPosition(), 0u);
|
||||
|
||||
// Two separately created symbols with the same position are trivially equal.
|
||||
auto expr2 = SDBMSymbolExpr::get(dialect(), 0);
|
||||
EXPECT_EQ(expr, expr2);
|
||||
|
||||
// Hierarchy is okay.
|
||||
auto generic = static_cast<SDBMExpr>(expr);
|
||||
EXPECT_TRUE(generic.isa<SDBMSymbolExpr>());
|
||||
EXPECT_TRUE(generic.isa<SDBMInputExpr>());
|
||||
EXPECT_TRUE(generic.isa<SDBMTermExpr>());
|
||||
EXPECT_TRUE(generic.isa<SDBMDirectExpr>());
|
||||
EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
|
||||
|
||||
// Dimensions are not Symbols.
|
||||
auto symbol = SDBMDimExpr::get(dialect(), 0);
|
||||
EXPECT_NE(expr, symbol);
|
||||
EXPECT_FALSE(expr.isa<SDBMDimExpr>());
|
||||
}
|
||||
|
||||
TEST(SDBMExpr, Stripe) {
|
||||
auto cst2 = SDBMConstantExpr::get(dialect(), 2);
|
||||
auto cst0 = SDBMConstantExpr::get(dialect(), 0);
|
||||
auto var = SDBMSymbolExpr::get(dialect(), 0);
|
||||
|
||||
// We can create stripe expressions and query them.
|
||||
auto expr = SDBMStripeExpr::get(var, cst2);
|
||||
EXPECT_EQ(expr.getLHS(), var);
|
||||
EXPECT_EQ(expr.getStripeFactor(), cst2);
|
||||
|
||||
// Two separately created stripe expressions with the same LHS and RHS are
|
||||
// trivially equal.
|
||||
auto expr2 = SDBMStripeExpr::get(SDBMSymbolExpr::get(dialect(), 0), cst2);
|
||||
EXPECT_EQ(expr, expr2);
|
||||
|
||||
// Stripes can be nested.
|
||||
SDBMStripeExpr::get(expr, SDBMConstantExpr::get(dialect(), 4));
|
||||
|
||||
// Non-positive stripe factors are not allowed.
|
||||
EXPECT_DEATH(SDBMStripeExpr::get(var, cst0), "non-positive");
|
||||
|
||||
// Stripes can have sums on the LHS.
|
||||
SDBMStripeExpr::get(SDBMSumExpr::get(var, cst2), cst2);
|
||||
|
||||
// Hierarchy is okay.
|
||||
auto generic = static_cast<SDBMExpr>(expr);
|
||||
EXPECT_TRUE(generic.isa<SDBMStripeExpr>());
|
||||
EXPECT_TRUE(generic.isa<SDBMTermExpr>());
|
||||
EXPECT_TRUE(generic.isa<SDBMDirectExpr>());
|
||||
EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
|
||||
}
|
||||
|
||||
TEST(SDBMExpr, Neg) {
|
||||
auto cst2 = SDBMConstantExpr::get(dialect(), 2);
|
||||
auto var = SDBMSymbolExpr::get(dialect(), 0);
|
||||
auto stripe = SDBMStripeExpr::get(var, cst2);
|
||||
|
||||
// We can create negation expressions and query them.
|
||||
auto expr = SDBMNegExpr::get(var);
|
||||
EXPECT_EQ(expr.getVar(), var);
|
||||
auto expr2 = SDBMNegExpr::get(stripe);
|
||||
EXPECT_EQ(expr2.getVar(), stripe);
|
||||
|
||||
// Neg expressions are trivially comparable.
|
||||
EXPECT_EQ(expr, SDBMNegExpr::get(var));
|
||||
|
||||
// Hierarchy is okay.
|
||||
auto generic = static_cast<SDBMExpr>(expr);
|
||||
EXPECT_TRUE(generic.isa<SDBMNegExpr>());
|
||||
EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
|
||||
}
|
||||
|
||||
TEST(SDBMExpr, Sum) {
|
||||
auto cst2 = SDBMConstantExpr::get(dialect(), 2);
|
||||
auto var = SDBMSymbolExpr::get(dialect(), 0);
|
||||
auto stripe = SDBMStripeExpr::get(var, cst2);
|
||||
|
||||
// We can create sum expressions and query them.
|
||||
auto expr = SDBMSumExpr::get(var, cst2);
|
||||
EXPECT_EQ(expr.getLHS(), var);
|
||||
EXPECT_EQ(expr.getRHS(), cst2);
|
||||
auto expr2 = SDBMSumExpr::get(stripe, cst2);
|
||||
EXPECT_EQ(expr2.getLHS(), stripe);
|
||||
EXPECT_EQ(expr2.getRHS(), cst2);
|
||||
|
||||
// Sum expressions are trivially comparable.
|
||||
EXPECT_EQ(expr, SDBMSumExpr::get(var, cst2));
|
||||
|
||||
// Hierarchy is okay.
|
||||
auto generic = static_cast<SDBMExpr>(expr);
|
||||
EXPECT_TRUE(generic.isa<SDBMSumExpr>());
|
||||
EXPECT_TRUE(generic.isa<SDBMDirectExpr>());
|
||||
EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
|
||||
}
|
||||
|
||||
TEST(SDBMExpr, Diff) {
|
||||
auto cst2 = SDBMConstantExpr::get(dialect(), 2);
|
||||
auto var = SDBMSymbolExpr::get(dialect(), 0);
|
||||
auto stripe = SDBMStripeExpr::get(var, cst2);
|
||||
|
||||
// We can create sum expressions and query them.
|
||||
auto expr = SDBMDiffExpr::get(var, stripe);
|
||||
EXPECT_EQ(expr.getLHS(), var);
|
||||
EXPECT_EQ(expr.getRHS(), stripe);
|
||||
auto expr2 = SDBMDiffExpr::get(stripe, var);
|
||||
EXPECT_EQ(expr2.getLHS(), stripe);
|
||||
EXPECT_EQ(expr2.getRHS(), var);
|
||||
|
||||
// Sum expressions are trivially comparable.
|
||||
EXPECT_EQ(expr, SDBMDiffExpr::get(var, stripe));
|
||||
|
||||
// Hierarchy is okay.
|
||||
auto generic = static_cast<SDBMExpr>(expr);
|
||||
EXPECT_TRUE(generic.isa<SDBMDiffExpr>());
|
||||
EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
|
||||
}
|
||||
|
||||
TEST(SDBMExpr, AffineRoundTrip) {
|
||||
// Build an expression (s0 - s0 # 2)
|
||||
auto cst2 = SDBMConstantExpr::get(dialect(), 2);
|
||||
auto var = SDBMSymbolExpr::get(dialect(), 0);
|
||||
auto stripe = SDBMStripeExpr::get(var, cst2);
|
||||
auto expr = SDBMDiffExpr::get(var, stripe);
|
||||
|
||||
// Check that it can be converted to AffineExpr and back, i.e. stripe
|
||||
// detection works correctly.
|
||||
Optional<SDBMExpr> roundtripped =
|
||||
SDBMExpr::tryConvertAffineExpr(expr.getAsAffineExpr());
|
||||
ASSERT_TRUE(roundtripped.hasValue());
|
||||
EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(expr));
|
||||
|
||||
// Check that (s0 # 2 # 5) can be converted to AffineExpr, i.e. stripe
|
||||
// detection supports nested expressions.
|
||||
auto cst5 = SDBMConstantExpr::get(dialect(), 5);
|
||||
auto outerStripe = SDBMStripeExpr::get(stripe, cst5);
|
||||
roundtripped = SDBMExpr::tryConvertAffineExpr(outerStripe.getAsAffineExpr());
|
||||
ASSERT_TRUE(roundtripped.hasValue());
|
||||
EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(outerStripe));
|
||||
|
||||
// Check that ((s0 + 2) # 5) can be round-tripped through AffineExpr, i.e.
|
||||
// stripe detection supports sum expressions.
|
||||
auto inner = SDBMSumExpr::get(var, cst2);
|
||||
auto stripeSum = SDBMStripeExpr::get(inner, cst5);
|
||||
roundtripped = SDBMExpr::tryConvertAffineExpr(stripeSum.getAsAffineExpr());
|
||||
ASSERT_TRUE(roundtripped.hasValue());
|
||||
EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(stripeSum));
|
||||
|
||||
// Check that (s0 # 2 # 5 - s0 # 2) + 2 can be converted as an example of a
|
||||
// deeper expression tree.
|
||||
auto sum = SDBMSumExpr::get(outerStripe, cst2);
|
||||
auto diff = SDBMDiffExpr::get(sum, stripe);
|
||||
roundtripped = SDBMExpr::tryConvertAffineExpr(diff.getAsAffineExpr());
|
||||
ASSERT_TRUE(roundtripped.hasValue());
|
||||
EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(diff));
|
||||
|
||||
// Check a nested stripe-sum combination.
|
||||
auto cst7 = SDBMConstantExpr::get(dialect(), 7);
|
||||
auto nestedStripe =
|
||||
SDBMStripeExpr::get(SDBMSumExpr::get(stripeSum, cst2), cst7);
|
||||
diff = SDBMDiffExpr::get(nestedStripe, stripe);
|
||||
roundtripped = SDBMExpr::tryConvertAffineExpr(diff.getAsAffineExpr());
|
||||
ASSERT_TRUE(roundtripped.hasValue());
|
||||
EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(diff));
|
||||
}
|
||||
|
||||
TEST(SDBMExpr, MatchStripeMulPattern) {
|
||||
// Make sure conversion from AffineExpr recognizes multiplicative stripe
|
||||
// pattern (x floordiv B) * B == x # B.
|
||||
auto cst = getAffineConstantExpr(42, ctx());
|
||||
auto dim = getAffineDimExpr(0, ctx());
|
||||
auto floor = dim.floorDiv(cst);
|
||||
auto mul = cst * floor;
|
||||
Optional<SDBMExpr> converted = SDBMStripeExpr::tryConvertAffineExpr(mul);
|
||||
ASSERT_TRUE(converted.hasValue());
|
||||
EXPECT_TRUE(converted->isa<SDBMStripeExpr>());
|
||||
}
|
||||
|
||||
TEST(SDBMExpr, NonSDBM) {
|
||||
auto d0 = getAffineDimExpr(0, ctx());
|
||||
auto d1 = getAffineDimExpr(1, ctx());
|
||||
auto sum = d0 + d1;
|
||||
auto c2 = getAffineConstantExpr(2, ctx());
|
||||
auto prod = d0 * c2;
|
||||
auto ceildiv = d1.ceilDiv(c2);
|
||||
|
||||
// The following are not valid SDBM expressions:
|
||||
// - a sum of two variables
|
||||
EXPECT_FALSE(SDBMExpr::tryConvertAffineExpr(sum).hasValue());
|
||||
// - a variable with coefficient other than 1 or -1
|
||||
EXPECT_FALSE(SDBMExpr::tryConvertAffineExpr(prod).hasValue());
|
||||
// - a ceildiv expression
|
||||
EXPECT_FALSE(SDBMExpr::tryConvertAffineExpr(ceildiv).hasValue());
|
||||
}
|
||||
|
||||
} // end namespace
|
Loading…
Reference in New Issue