[RFC][MLIR] Use AffineExprRef in place of AffineExpr* in IR

This CL starts by replacing AffineExpr* with value-type AffineExprRef in a few
places in the IR. By a domino effect that is pretty telling of the
inconsistencies in the codebase, const is removed where it makes sense.

The rationale is that the decision was concisously made that unique'd types
have pointer semantics without const specifier. This is fine but we should be
consistent. In the end, the only logical invariant is that there should never
be such a thing as a const AffineExpr*, const AffineMap* or const IntegerSet*
in our codebase.

This CL takes a number of shortcuts to killing const with fire, in particular
forcing const AffineExprRef to return the underlying non-const
AffineExpr*. This will be removed once AffineExpr* has disappeared in
containers but for now such shortcuts allow a bit of sanity in this long quest
for cleanups.

The **only** places where const AffineExpr*, const AffineMap* or const
IntegerSet* may still appear is by transitive needs from containers,
comparison operators etc.

There is still one major thing remaining here: figure out why cast/dyn_cast
return me a const AffineXXX*, which in turn requires a bunch of ugly
const_casts. I suspect this is due to the classof
taking const AffineXXXExpr*. I wonder whether this is a side effect of 1., if
it is coming from llvm itself (I'd doubt it) or something else (clattner@?)

In light of this, the whole discussion about const makes total sense to me now
and I would systematically apply the rule that in the end, we should never
have any const XXX in our codebase for unique'd types (assuming we can remove
them all in containers and no additional constness constraint is added on us
from the outside world).

PiperOrigin-RevId: 215811554
This commit is contained in:
Nicolas Vasilache 2018-10-04 15:10:33 -07:00 committed by jpienaar
parent 5b8017db18
commit b55b407601
17 changed files with 246 additions and 226 deletions

View File

@ -24,6 +24,7 @@
#define MLIR_IR_AFFINE_EXPR_H
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/Support/Casting.h"
namespace mlir {
@ -58,26 +59,26 @@ public:
};
/// Return the classification for this type.
Kind getKind() const { return kind; }
Kind getKind() { return kind; }
void print(raw_ostream &os) const;
void dump() const;
void print(raw_ostream &os);
void dump();
/// Returns true if this expression is made out of only symbols and
/// constants, i.e., it does not involve dimensional identifiers.
bool isSymbolicOrConstant() const;
bool isSymbolicOrConstant();
/// Returns true if this is a pure affine expression, i.e., multiplication,
/// floordiv, ceildiv, and mod is only allowed w.r.t constants.
bool isPureAffine() const;
bool isPureAffine();
/// Returns the greatest known integral divisor of this affine expression.
uint64_t getLargestKnownDivisor() const;
uint64_t getLargestKnownDivisor();
/// Return true if the affine expression is a multiple of 'factor'.
bool isMultipleOf(int64_t factor) const;
bool isMultipleOf(int64_t factor);
MLIRContext *getContext() const;
MLIRContext *getContext();
protected:
explicit AffineExpr(Kind kind, MLIRContext *context)
@ -93,7 +94,7 @@ private:
MLIRContext *context;
};
inline raw_ostream &operator<<(raw_ostream &os, const AffineExpr &expr) {
inline raw_ostream &operator<<(raw_ostream &os, AffineExpr &expr) {
expr.print(os);
return os;
}
@ -104,25 +105,22 @@ inline raw_ostream &operator<<(raw_ostream &os, const AffineExpr &expr) {
/// away.
/// TODO(ntv): Remove all redundant MLIRContext* arguments through the API
/// TODO(ntv): Remove all uses of AffineExpr* in Parser.cpp
/// TODO(ntv): Remove all uses of AffineExpr* in learning/brain
/// TODO(ntv): Remove all uses of AffineExpr* in IR
/// TODO(ntv): Add extra out-of-class operators for int op AffineExprBaseRef
/// TODO(ntv): Rename
/// TODO(ntv): Drop const everywhere it makes sense in AffineExpr
/// TODO(ntv): remove const comment
/// TODO(ntv): pointer pair
template <typename AffineExprType> class AffineExprBaseRef {
public:
/* implicit */ AffineExprBaseRef(const AffineExprType *expr)
: expr(const_cast<AffineExprType *>(expr)) {}
/* implicit */ AffineExprBaseRef(AffineExprType *expr) : expr(expr) {}
AffineExprBaseRef(const AffineExprBaseRef &other) : expr(other.expr){};
AffineExprBaseRef &operator=(AffineExprBaseRef other) {
expr = other;
return *this;
};
bool operator==(AffineExprBaseRef other) { return expr == other; };
bool operator==(AffineExprBaseRef other) const { return expr == other.expr; };
AffineExprType *operator->() { return expr; }
AffineExprType const *operator->() const { return expr; }
/* implicit */ operator AffineExprType *() { return expr; }
bool operator!() { return expr == nullptr; }
@ -141,16 +139,13 @@ public:
AffineExprBaseRef operator%(uint64_t v) const;
AffineExprBaseRef operator%(AffineExprBaseRef other) const;
friend AffineExprType *
llvm::simplify_type<AffineExprBaseRef<AffineExprType>>::getSimplifiedValue(
AffineExprBaseRef<AffineExprType> &input);
private:
AffineExprType *expr;
};
using AffineExprRef = AffineExprBaseRef<AffineExpr>;
inline ::llvm::hash_code hash_value(AffineExprRef arg);
} // namespace mlir
namespace llvm {
@ -160,7 +155,25 @@ namespace llvm {
template <typename T> struct simplify_type<mlir::AffineExprBaseRef<T>> {
using SimpleType = T *;
static SimpleType getSimplifiedValue(mlir::AffineExprBaseRef<T> &input) {
return input.expr;
return input;
}
};
// AffineExprRef hash just like pointers
template <> struct DenseMapInfo<mlir::AffineExprRef> {
static mlir::AffineExprRef getEmptyKey() {
auto pointer = llvm::DenseMapInfo<mlir::AffineExpr *>::getEmptyKey();
return mlir::AffineExprRef(pointer);
}
static mlir::AffineExprRef getTombstoneKey() {
auto pointer = llvm::DenseMapInfo<mlir::AffineExpr *>::getTombstoneKey();
return mlir::AffineExprRef(pointer);
}
static unsigned getHashValue(mlir::AffineExprRef val) {
return mlir::hash_value(val);
}
static bool isEqual(mlir::AffineExprRef LHS, mlir::AffineExprRef RHS) {
return LHS == RHS;
}
};
@ -168,6 +181,11 @@ template <typename T> struct simplify_type<mlir::AffineExprBaseRef<T>> {
namespace mlir {
// Make AffineExprRef hashable.
inline ::llvm::hash_code hash_value(AffineExprRef arg) {
return ::llvm::hash_value(static_cast<AffineExpr *>(arg));
}
/// Affine binary operation expression. An affine binary operation could be an
/// add, mul, floordiv, ceildiv, or a modulo operation. (Subtraction is
/// represented through a multiply by -1 and add.) These expressions are always
@ -212,34 +230,35 @@ public:
static AffineExprRef getMod(AffineExprRef lhs, uint64_t rhs,
MLIRContext *context);
AffineExprRef getLHS() const { return lhs; }
AffineExprRef getRHS() const { return rhs; }
AffineExprRef getLHS() { return lhs; }
AffineExprRef getRHS() { return rhs; }
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const AffineExpr *expr) {
return expr->getKind() <= Kind::LAST_AFFINE_BINARY_OP;
return const_cast<AffineExpr *>(expr)->getKind() <=
Kind::LAST_AFFINE_BINARY_OP;
}
protected:
explicit AffineBinaryOpExpr(Kind kind, AffineExprRef lhs, AffineExprRef rhs,
MLIRContext *context);
AffineExpr *const lhs;
AffineExpr *const rhs;
const AffineExprRef lhs;
const AffineExprRef rhs;
private:
~AffineBinaryOpExpr() = delete;
// Simplification prior to construction of binary affine op expressions.
static AffineExpr *simplifyAdd(AffineExpr *lhs, AffineExpr *rhs,
MLIRContext *context);
static AffineExpr *simplifyMul(AffineExpr *lhs, AffineExpr *rhs,
MLIRContext *context);
static AffineExpr *simplifyFloorDiv(AffineExpr *lhs, AffineExpr *rhs,
MLIRContext *context);
static AffineExpr *simplifyCeilDiv(AffineExpr *lhs, AffineExpr *rhs,
MLIRContext *context);
static AffineExpr *simplifyMod(AffineExpr *lhs, AffineExpr *rhs,
MLIRContext *context);
static AffineExprRef simplifyAdd(AffineExprRef lhs, AffineExprRef rhs,
MLIRContext *context);
static AffineExprRef simplifyMul(AffineExprRef lhs, AffineExprRef rhs,
MLIRContext *context);
static AffineExprRef simplifyFloorDiv(AffineExprRef lhs, AffineExprRef rhs,
MLIRContext *context);
static AffineExprRef simplifyCeilDiv(AffineExprRef lhs, AffineExprRef rhs,
MLIRContext *context);
static AffineExprRef simplifyMod(AffineExprRef lhs, AffineExprRef rhs,
MLIRContext *context);
};
/// A dimensional identifier appearing in an affine expression.
@ -252,11 +271,11 @@ public:
static AffineExprBaseRef<AffineExpr> get(unsigned position,
MLIRContext *context);
unsigned getPosition() const { return position; }
unsigned getPosition() { return position; }
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const AffineExpr *expr) {
return expr->getKind() == Kind::DimId;
return const_cast<AffineExpr *>(expr)->getKind() == Kind::DimId;
}
private:
@ -278,11 +297,11 @@ public:
static AffineExprBaseRef<AffineExpr> get(unsigned position,
MLIRContext *context);
unsigned getPosition() const { return position; }
unsigned getPosition() { return position; }
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const AffineExpr *expr) {
return expr->getKind() == Kind::SymbolId;
return const_cast<AffineExpr *>(expr)->getKind() == Kind::SymbolId;
}
private:
@ -300,11 +319,11 @@ public:
static AffineExprBaseRef<AffineExpr> get(int64_t constant,
MLIRContext *context);
int64_t getValue() const { return constant; }
int64_t getValue() { return constant; }
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const AffineExpr *expr) {
return expr->getKind() == Kind::Constant;
return const_cast<AffineExpr *>(expr)->getKind() == Kind::Constant;
}
private:

View File

@ -78,7 +78,7 @@ template <typename SubClass, typename RetTy = void> class AffineExprVisitor {
// that you use to visit affine expressions...
public:
// Function to walk an AffineExpr (in post order).
RetTy walkPostOrder(AffineExpr *expr) {
RetTy walkPostOrder(AffineExprRef expr) {
static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
"Must instantiate with a derived type of AffineExprVisitor");
switch (expr->getKind()) {
@ -120,7 +120,7 @@ public:
}
// Function to visit an AffineExpr.
RetTy visit(AffineExpr *expr) {
RetTy visit(AffineExprRef expr) {
static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
"Must instantiate with a derived type of AffineExprVisitor");
switch (expr->getKind()) {

View File

@ -30,8 +30,10 @@
namespace mlir {
class MLIRContext;
class AffineExpr;
template <typename T> class AffineExprBaseRef;
using AffineExprRef = AffineExprBaseRef<AffineExpr>;
class MLIRContext;
/// A multi-dimensional affine map
/// Affine map's are immutable like Type's, and they are uniqued.
@ -41,8 +43,8 @@ class AffineExpr;
class AffineMap {
public:
static AffineMap *get(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr *> results,
ArrayRef<AffineExpr *> rangeSizes,
ArrayRef<AffineExprRef> results,
ArrayRef<AffineExprRef> rangeSizes,
MLIRContext *context);
/// Returns a single constant result affine map.
@ -51,60 +53,56 @@ public:
/// Returns true if the co-domain (or more loosely speaking, range) of this
/// map is bounded. Bounded affine maps have a size (extent) for each of
/// their range dimensions (more accurately co-domain dimensions).
bool isBounded() const { return rangeSizes != nullptr; }
bool isBounded() { return !rangeSizes.empty(); }
/// Returns true if this affine map is an identity affine map.
/// An identity affine map corresponds to an identity affine function on the
/// dimensional identifiers.
bool isIdentity() const;
bool isIdentity();
/// Returns true if this affine map is a single result constant function.
bool isSingleConstant() const;
bool isSingleConstant();
/// Returns the constant result of this map. This methods asserts that the map
/// has a single constant result.
int64_t getSingleConstantResult() const;
int64_t getSingleConstantResult();
// Prints affine map to 'os'.
void print(raw_ostream &os) const;
void dump() const;
void print(raw_ostream &os);
void dump();
unsigned getNumDims() const { return numDims; }
unsigned getNumSymbols() const { return numSymbols; }
unsigned getNumResults() const { return numResults; }
unsigned getNumInputs() const { return numDims + numSymbols; }
unsigned getNumDims() { return numDims; }
unsigned getNumSymbols() { return numSymbols; }
unsigned getNumResults() { return numResults; }
unsigned getNumInputs() { return numDims + numSymbols; }
ArrayRef<AffineExpr *> getResults() const {
return ArrayRef<AffineExpr *>(results, numResults);
}
ArrayRef<AffineExprRef> getResults() { return results; }
AffineExpr *getResult(unsigned idx) const { return results[idx]; }
AffineExprRef getResult(unsigned idx);
ArrayRef<AffineExpr *> getRangeSizes() const {
return rangeSizes ? ArrayRef<AffineExpr *>(rangeSizes, numResults)
: ArrayRef<AffineExpr *>();
}
ArrayRef<AffineExprRef> getRangeSizes() { return rangeSizes; }
private:
AffineMap(unsigned numDims, unsigned numSymbols, unsigned numResults,
AffineExpr *const *results, AffineExpr *const *rangeSizes);
ArrayRef<AffineExprRef> results,
ArrayRef<AffineExprRef> rangeSizes);
AffineMap(const AffineMap &) = delete;
void operator=(const AffineMap &) = delete;
const unsigned numDims;
const unsigned numSymbols;
const unsigned numResults;
unsigned numDims;
unsigned numSymbols;
unsigned numResults;
/// The affine expressions for this (multi-dimensional) map.
/// TODO: use trailing objects for this.
AffineExpr *const *const results;
ArrayRef<AffineExprRef> results;
/// The extents along each of the range dimensions if the map is bounded,
/// nullptr otherwise.
AffineExpr *const *const rangeSizes;
ArrayRef<AffineExprRef> rangeSizes;
};
} // end namespace mlir
} // end namespace mlir
#endif // MLIR_IR_AFFINE_MAP_H
#endif // MLIR_IR_AFFINE_MAP_H

View File

@ -120,17 +120,9 @@ public:
AffineExprRef getCeilDivExpr(AffineExprRef lhs, AffineExprRef rhs);
AffineExprRef getCeilDivExpr(AffineExprRef lhs, uint64_t rhs);
/// Creates a sum of products affine expression from constant coefficients.
/// If c_0, c_1, ... are the coefficients in the order corresponding to
/// dimensions, symbols, and the constant term, create the affine expression:
/// expr = c_0*d0 + c_1*d1 + ... + c_{ndims-1}*d_{ndims-1} + c_{..}*s0 +
/// c_{..}*s1 + ... + const
AffineExpr *getAddMulPureAffineExpr(unsigned numDims, unsigned numSymbols,
ArrayRef<int64_t> coeffs);
AffineMap *getAffineMap(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr *> results,
ArrayRef<AffineExpr *> rangeSizes);
ArrayRef<AffineExprRef> results,
ArrayRef<AffineExprRef> rangeSizes);
// Special cases of affine maps and integer sets
/// Returns a single constant result affine map with 0 dimensions and 0
@ -153,7 +145,7 @@ public:
// Integer set.
IntegerSet *getIntegerSet(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr *> constraints,
ArrayRef<AffineExprRef> constraints,
ArrayRef<bool> isEq);
// TODO: Helpers for affine map/exprs, etc.
protected:

View File

@ -47,51 +47,45 @@ class MLIRContext;
class IntegerSet {
public:
static IntegerSet *get(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr *> constraints,
ArrayRef<AffineExprRef> constraints,
ArrayRef<bool> eqFlags, MLIRContext *context);
unsigned getNumDims() const { return dimCount; }
unsigned getNumSymbols() const { return symbolCount; }
unsigned getNumOperands() const { return dimCount + symbolCount; }
unsigned getNumConstraints() const { return numConstraints; }
unsigned getNumDims() { return dimCount; }
unsigned getNumSymbols() { return symbolCount; }
unsigned getNumOperands() { return dimCount + symbolCount; }
unsigned getNumConstraints() { return numConstraints; }
ArrayRef<AffineExpr *> getConstraints() const {
return ArrayRef<AffineExpr *>(constraints, numConstraints);
}
ArrayRef<AffineExprRef> getConstraints() { return constraints; }
AffineExpr *getConstraint(unsigned idx) const {
return getConstraints()[idx];
}
AffineExprRef getConstraint(unsigned idx) { return getConstraints()[idx]; }
/// Returns the equality bits, which specify whether each of the constraints
/// is an equality or inequality.
ArrayRef<bool> getEqFlags() const {
return ArrayRef<bool>(eqFlags, numConstraints);
}
ArrayRef<bool> getEqFlags() { return eqFlags; }
/// Returns true if the idx^th constraint is an equality, false if it is an
/// inequality.
bool isEq(unsigned idx) const { return getEqFlags()[idx]; }
bool isEq(unsigned idx) { return getEqFlags()[idx]; }
void print(raw_ostream &os) const;
void dump() const;
void print(raw_ostream &os);
void dump();
private:
IntegerSet(unsigned dimCount, unsigned symbolCount, unsigned numConstraints,
AffineExpr *const *constraints, const bool *const eqFlags);
ArrayRef<AffineExprRef> constraints, ArrayRef<bool> eqFlags);
~IntegerSet() = delete;
const unsigned dimCount;
const unsigned symbolCount;
const unsigned numConstraints;
unsigned dimCount;
unsigned symbolCount;
unsigned numConstraints;
/// Array of affine constraints: a constaint is either an equality
/// (affine_expr == 0) or an inequality (affine_expr >= 0).
AffineExpr *const *const constraints;
ArrayRef<AffineExprRef> constraints;
// Bits to check whether a constraint is an equality or an inequality.
const bool *const eqFlags;
ArrayRef<bool> eqFlags;
};
} // end namespace mlir

View File

@ -28,8 +28,10 @@
#include "llvm/Support/raw_ostream.h"
namespace mlir {
class AffineMap;
class AffineExpr;
template <typename T> class AffineExprBaseRef;
using AffineExprRef = AffineExprBaseRef<AffineExpr>;
class AffineMap;
class Builder;
class Function;
@ -68,8 +70,8 @@ public:
virtual void printType(const Type *type) = 0;
virtual void printFunctionReference(const Function *func) = 0;
virtual void printAttribute(const Attribute *attr) = 0;
virtual void printAffineMap(const AffineMap *map) = 0;
virtual void printAffineExpr(const AffineExpr *expr) = 0;
virtual void printAffineMap(AffineMap *map) = 0;
virtual void printAffineExpr(AffineExprRef expr) = 0;
/// If the specified operation has attributes, print out an attribute
/// dictionary with their values. elidedAttrs allows the client to ignore
@ -104,7 +106,7 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const Attribute &attr) {
}
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const AffineMap &map) {
p.printAffineMap(&map);
p.printAffineMap(&const_cast<AffineMap &>(map));
return p;
}

View File

@ -30,14 +30,14 @@ using namespace mlir;
MutableAffineMap::MutableAffineMap(AffineMap *map, MLIRContext *context)
: numDims(map->getNumDims()), numSymbols(map->getNumSymbols()),
context(context) {
for (auto *result : map->getResults())
for (auto result : map->getResults())
results.push_back(result);
for (auto *rangeSize : map->getRangeSizes())
for (auto rangeSize : map->getRangeSizes())
results.push_back(rangeSize);
}
bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const {
if (results[idx]->isMultipleOf(factor))
if (const_cast<AffineExprRef &>(results[idx])->isMultipleOf(factor))
return true;
// TODO(bondhugula): use simplifyAffineExpr and FlatAffineConstraints to

View File

@ -38,7 +38,8 @@ getReducedConstBound(const HyperRectangularSet &set, unsigned *idx,
unsigned j = 0;
AffineBoundExprList::const_iterator it, e;
for (it = ubs.begin(), e = ubs.end(); it != e; it++, j++) {
if (auto *cExpr = dyn_cast<AffineConstantExpr>(*it)) {
if (auto *cExpr = const_cast<AffineConstantExpr *>(
dyn_cast<AffineConstantExpr>(*it))) {
if (val == None) {
val = cExpr->getValue();
*idx = j;

View File

@ -85,7 +85,7 @@ AffineExprRef AffineBinaryOpExpr::getMod(AffineExprRef lhs, uint64_t rhs,
/// Returns true if this expression is made out of only symbols and
/// constants (no dimensional identifiers).
bool AffineExpr::isSymbolicOrConstant() const {
bool AffineExpr::isSymbolicOrConstant() {
switch (getKind()) {
case Kind::Constant:
return true;
@ -108,7 +108,7 @@ bool AffineExpr::isSymbolicOrConstant() const {
/// Returns true if this is a pure affine expression, i.e., multiplication,
/// floordiv, ceildiv, and mod is only allowed w.r.t constants.
bool AffineExpr::isPureAffine() const {
bool AffineExpr::isPureAffine() {
switch (getKind()) {
case Kind::SymbolId:
case Kind::DimId:
@ -138,7 +138,7 @@ bool AffineExpr::isPureAffine() const {
}
/// Returns the greatest known integral divisor of this affine expression.
uint64_t AffineExpr::getLargestKnownDivisor() const {
uint64_t AffineExpr::getLargestKnownDivisor() {
AffineBinaryOpExpr *binExpr = nullptr;
switch (kind) {
case Kind::SymbolId:
@ -148,7 +148,7 @@ uint64_t AffineExpr::getLargestKnownDivisor() const {
case Kind::Constant:
return std::abs(cast<AffineConstantExpr>(this)->getValue());
case Kind::Mul: {
binExpr = cast<AffineBinaryOpExpr>(const_cast<AffineExpr *>(this));
binExpr = cast<AffineBinaryOpExpr>(this);
return binExpr->getLHS()->getLargestKnownDivisor() *
binExpr->getRHS()->getLargestKnownDivisor();
}
@ -157,7 +157,7 @@ uint64_t AffineExpr::getLargestKnownDivisor() const {
case Kind::FloorDiv:
case Kind::CeilDiv:
case Kind::Mod: {
binExpr = cast<AffineBinaryOpExpr>(const_cast<AffineExpr *>(this));
binExpr = cast<AffineBinaryOpExpr>(this);
return llvm::GreatestCommonDivisor64(
binExpr->getLHS()->getLargestKnownDivisor(),
binExpr->getRHS()->getLargestKnownDivisor());
@ -165,7 +165,7 @@ uint64_t AffineExpr::getLargestKnownDivisor() const {
}
}
bool AffineExpr::isMultipleOf(int64_t factor) const {
bool AffineExpr::isMultipleOf(int64_t factor) {
AffineBinaryOpExpr *binExpr = nullptr;
uint64_t l, u;
switch (kind) {
@ -176,7 +176,7 @@ bool AffineExpr::isMultipleOf(int64_t factor) const {
case Kind::Constant:
return cast<AffineConstantExpr>(this)->getValue() % factor == 0;
case Kind::Mul: {
binExpr = cast<AffineBinaryOpExpr>(const_cast<AffineExpr *>(this));
binExpr = cast<AffineBinaryOpExpr>(this);
// It's probably not worth optimizing this further (to not traverse the
// whole sub-tree under - it that would require a version of isMultipleOf
// that on a 'false' return also returns the largest known divisor).
@ -188,7 +188,7 @@ bool AffineExpr::isMultipleOf(int64_t factor) const {
case Kind::FloorDiv:
case Kind::CeilDiv:
case Kind::Mod: {
binExpr = cast<AffineBinaryOpExpr>(const_cast<AffineExpr *>(this));
binExpr = cast<AffineBinaryOpExpr>(this);
return llvm::GreatestCommonDivisor64(
binExpr->getLHS()->getLargestKnownDivisor(),
binExpr->getRHS()->getLargestKnownDivisor()) %
@ -198,7 +198,7 @@ bool AffineExpr::isMultipleOf(int64_t factor) const {
}
}
MLIRContext *AffineExpr::getContext() const { return context; }
MLIRContext *AffineExpr::getContext() { return context; }
template <> AffineExprRef AffineExprRef::operator+(int64_t v) const {
return AffineBinaryOpExpr::getAdd(expr, v, expr->getContext());

View File

@ -23,7 +23,8 @@
using namespace mlir;
AffineMap::AffineMap(unsigned numDims, unsigned numSymbols, unsigned numResults,
AffineExpr *const *results, AffineExpr *const *rangeSizes)
ArrayRef<AffineExprRef> results,
ArrayRef<AffineExprRef> rangeSizes)
: numDims(numDims), numSymbols(numSymbols), numResults(numResults),
results(results), rangeSizes(rangeSizes) {}
@ -33,30 +34,36 @@ AffineMap *AffineMap::getConstantMap(int64_t val, MLIRContext *context) {
{AffineConstantExpr::get(val, context)}, {}, context);
}
bool AffineMap::isIdentity() const {
bool AffineMap::isIdentity() {
if (getNumDims() != getNumResults())
return false;
ArrayRef<AffineExpr *> results = getResults();
ArrayRef<AffineExprRef> results = getResults();
for (unsigned i = 0, numDims = getNumDims(); i < numDims; ++i) {
auto *expr = dyn_cast<AffineDimExpr>(results[i]);
auto *expr =
const_cast<AffineDimExpr *>(dyn_cast<AffineDimExpr>(results[i]));
if (!expr || expr->getPosition() != i)
return false;
}
return true;
}
bool AffineMap::isSingleConstant() const {
bool AffineMap::isSingleConstant() {
return getNumResults() == 1 && isa<AffineConstantExpr>(getResult(0));
}
int64_t AffineMap::getSingleConstantResult() const {
int64_t AffineMap::getSingleConstantResult() {
assert(isSingleConstant() && "map must have a single constant result");
return cast<AffineConstantExpr>(getResult(0))->getValue();
return const_cast<AffineConstantExpr *>(
cast<AffineConstantExpr>(getResult(0)))
->getValue();
}
AffineExprRef AffineMap::getResult(unsigned idx) { return results[idx]; }
/// Simplify add expression. Return nullptr if it can't be simplified.
AffineExpr *AffineBinaryOpExpr::simplifyAdd(AffineExpr *lhs, AffineExpr *rhs,
MLIRContext *context) {
AffineExprRef AffineBinaryOpExpr::simplifyAdd(AffineExprRef lhs,
AffineExprRef rhs,
MLIRContext *context) {
auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
@ -80,16 +87,19 @@ AffineExpr *AffineBinaryOpExpr::simplifyAdd(AffineExpr *lhs, AffineExpr *rhs,
return lhs;
}
// Fold successive additions like (d0 + 2) + 3 into d0 + 5.
auto *lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
auto *lBin =
const_cast<AffineBinaryOpExpr *>(dyn_cast<AffineBinaryOpExpr>(lhs));
if (lBin && rhsConst && lBin->getKind() == Kind::Add) {
if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS()))
if (auto *lrhs = const_cast<AffineConstantExpr *>(
dyn_cast<AffineConstantExpr>(lBin->getRHS())))
return lBin->getLHS() + (lrhs->getValue() + rhsConst->getValue());
}
// When doing successive additions, bring constant to the right: turn (d0 + 2)
// + d1 into (d0 + d1) + 2.
if (lBin && lBin->getKind() == Kind::Add) {
if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS())) {
if (auto *lrhs = const_cast<AffineConstantExpr *>(
dyn_cast<AffineConstantExpr>(lBin->getRHS()))) {
return lBin->getLHS() + rhs + lrhs;
}
}
@ -98,8 +108,9 @@ AffineExpr *AffineBinaryOpExpr::simplifyAdd(AffineExpr *lhs, AffineExpr *rhs,
}
/// Simplify a multiply expression. Return nullptr if it can't be simplified.
AffineExpr *AffineBinaryOpExpr::simplifyMul(AffineExpr *lhs, AffineExpr *rhs,
MLIRContext *context) {
AffineExprRef AffineBinaryOpExpr::simplifyMul(AffineExprRef lhs,
AffineExprRef rhs,
MLIRContext *context) {
auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
@ -129,16 +140,19 @@ AffineExpr *AffineBinaryOpExpr::simplifyMul(AffineExpr *lhs, AffineExpr *rhs,
}
// Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
auto *lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
auto *lBin =
const_cast<AffineBinaryOpExpr *>(dyn_cast<AffineBinaryOpExpr>(lhs));
if (lBin && rhsConst && lBin->getKind() == Kind::Mul) {
if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS()))
if (auto *lrhs = const_cast<AffineConstantExpr *>(
dyn_cast<AffineConstantExpr>(lBin->getRHS())))
return lBin->getLHS() * (lrhs->getValue() * rhsConst->getValue());
}
// When doing successive multiplication, bring constant to the right: turn (d0
// * 2) * d1 into (d0 * d1) * 2.
if (lBin && lBin->getKind() == Kind::Mul) {
if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS())) {
if (auto *lrhs = const_cast<AffineConstantExpr *>(
dyn_cast<AffineConstantExpr>(lBin->getRHS()))) {
return (lBin->getLHS() * rhs) * lrhs;
}
}
@ -146,9 +160,9 @@ AffineExpr *AffineBinaryOpExpr::simplifyMul(AffineExpr *lhs, AffineExpr *rhs,
return nullptr;
}
AffineExpr *AffineBinaryOpExpr::simplifyFloorDiv(AffineExpr *lhs,
AffineExpr *rhs,
MLIRContext *context) {
AffineExprRef AffineBinaryOpExpr::simplifyFloorDiv(AffineExprRef lhs,
AffineExprRef rhs,
MLIRContext *context) {
auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
@ -162,9 +176,11 @@ AffineExpr *AffineBinaryOpExpr::simplifyFloorDiv(AffineExpr *lhs,
if (rhsConst->getValue() == 1)
return lhs;
auto *lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
auto *lBin =
const_cast<AffineBinaryOpExpr *>(dyn_cast<AffineBinaryOpExpr>(lhs));
if (lBin && lBin->getKind() == Kind::Mul) {
if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS())) {
if (auto *lrhs = const_cast<AffineConstantExpr *>(
dyn_cast<AffineConstantExpr>(lBin->getRHS()))) {
// rhsConst is known to be positive if a constant.
if (lrhs->getValue() % rhsConst->getValue() == 0)
return lBin->getLHS() * (lrhs->getValue() / rhsConst->getValue());
@ -175,9 +191,9 @@ AffineExpr *AffineBinaryOpExpr::simplifyFloorDiv(AffineExpr *lhs,
return nullptr;
}
AffineExpr *AffineBinaryOpExpr::simplifyCeilDiv(AffineExpr *lhs,
AffineExpr *rhs,
MLIRContext *context) {
AffineExprRef AffineBinaryOpExpr::simplifyCeilDiv(AffineExprRef lhs,
AffineExprRef rhs,
MLIRContext *context) {
auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
@ -191,9 +207,11 @@ AffineExpr *AffineBinaryOpExpr::simplifyCeilDiv(AffineExpr *lhs,
if (rhsConst->getValue() == 1)
return lhs;
auto *lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
auto *lBin =
const_cast<AffineBinaryOpExpr *>(dyn_cast<AffineBinaryOpExpr>(lhs));
if (lBin && lBin->getKind() == Kind::Mul) {
if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS())) {
if (auto *lrhs = const_cast<AffineConstantExpr *>(
dyn_cast<AffineConstantExpr>(lBin->getRHS()))) {
// rhsConst is known to be positive if a constant.
if (lrhs->getValue() % rhsConst->getValue() == 0)
return lBin->getLHS() * (lrhs->getValue() / rhsConst->getValue());
@ -204,8 +222,9 @@ AffineExpr *AffineBinaryOpExpr::simplifyCeilDiv(AffineExpr *lhs,
return nullptr;
}
AffineExpr *AffineBinaryOpExpr::simplifyMod(AffineExpr *lhs, AffineExpr *rhs,
MLIRContext *context) {
AffineExprRef AffineBinaryOpExpr::simplifyMod(AffineExprRef lhs,
AffineExprRef rhs,
MLIRContext *context) {
auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);

View File

@ -64,7 +64,7 @@ public:
// Initializes module state, populating affine map state.
void initialize(const Module *module);
int getAffineMapId(const AffineMap *affineMap) const {
int getAffineMapId(AffineMap *affineMap) const {
auto it = affineMapIds.find(affineMap);
if (it == affineMapIds.end()) {
return -1;
@ -72,9 +72,9 @@ public:
return it->second;
}
ArrayRef<const AffineMap *> getAffineMapIds() const { return affineMapsById; }
ArrayRef<AffineMap *> getAffineMapIds() const { return affineMapsById; }
int getIntegerSetId(const IntegerSet *integerSet) const {
int getIntegerSetId(IntegerSet *integerSet) const {
auto it = integerSetIds.find(integerSet);
if (it == integerSetIds.end()) {
return -1;
@ -82,19 +82,17 @@ public:
return it->second;
}
ArrayRef<const IntegerSet *> getIntegerSetIds() const {
return integerSetsById;
}
ArrayRef<IntegerSet *> getIntegerSetIds() const { return integerSetsById; }
private:
void recordAffineMapReference(const AffineMap *affineMap) {
void recordAffineMapReference(AffineMap *affineMap) {
if (affineMapIds.count(affineMap) == 0) {
affineMapIds[affineMap] = affineMapsById.size();
affineMapsById.push_back(affineMap);
}
}
void recordIntegerSetReference(const IntegerSet *integerSet) {
void recordIntegerSetReference(IntegerSet *integerSet) {
if (integerSetIds.count(integerSet) == 0) {
integerSetIds[integerSet] = integerSetsById.size();
integerSetsById.push_back(integerSet);
@ -102,7 +100,7 @@ private:
}
// Return true if this map could be printed using the shorthand form.
static bool hasShorthandForm(const AffineMap *boundMap) {
static bool hasShorthandForm(AffineMap *boundMap) {
if (boundMap->isSingleConstant())
return true;
@ -126,11 +124,11 @@ private:
void visitAttribute(const Attribute *attr);
void visitOperation(const Operation *op);
DenseMap<const AffineMap *, int> affineMapIds;
std::vector<const AffineMap *> affineMapsById;
DenseMap<AffineMap *, int> affineMapIds;
std::vector<AffineMap *> affineMapsById;
DenseMap<const IntegerSet *, int> integerSetIds;
std::vector<const IntegerSet *> integerSetsById;
DenseMap<IntegerSet *, int> integerSetIds;
std::vector<IntegerSet *> integerSetsById;
};
} // end anonymous namespace
@ -275,10 +273,10 @@ public:
void print(const CFGFunction *fn);
void print(const MLFunction *fn);
void printAffineMap(const AffineMap *map);
void printAffineExpr(const AffineExpr *expr);
void printAffineConstraint(const AffineExpr *expr, bool isEq);
void printIntegerSet(const IntegerSet *set);
void printAffineMap(AffineMap *map);
void printAffineExpr(AffineExprRef expr);
void printAffineConstraint(AffineExprRef expr, bool isEq);
void printIntegerSet(IntegerSet *set);
protected:
raw_ostream &os;
@ -290,9 +288,9 @@ protected:
ArrayRef<const char *> elidedAttrs = {});
void printFunctionResultType(const FunctionType *type);
void printAffineMapId(int affineMapId) const;
void printAffineMapReference(const AffineMap *affineMap);
void printAffineMapReference(AffineMap *affineMap);
void printIntegerSetId(int integerSetId) const;
void printIntegerSetReference(const IntegerSet *integerSet);
void printIntegerSetReference(IntegerSet *integerSet);
/// This enum is used to represent the binding stength of the enclosing
/// context that an AffineExpr is being printed in, so we can intelligently
@ -301,7 +299,7 @@ protected:
Weak, // + and -
Strong, // All other binary operators.
};
void printAffineExprInternal(const AffineExpr *expr,
void printAffineExprInternal(AffineExprRef expr,
BindingStrength enclosingTightness);
};
} // end anonymous namespace
@ -323,7 +321,7 @@ void ModulePrinter::printAffineMapId(int affineMapId) const {
os << "#map" << affineMapId;
}
void ModulePrinter::printAffineMapReference(const AffineMap *affineMap) {
void ModulePrinter::printAffineMapReference(AffineMap *affineMap) {
int mapId = state.getAffineMapId(affineMap);
if (mapId >= 0) {
// Map will be printed at top of module so print reference to its id.
@ -339,7 +337,7 @@ void ModulePrinter::printIntegerSetId(int integerSetId) const {
os << "@@set" << integerSetId;
}
void ModulePrinter::printIntegerSetReference(const IntegerSet *integerSet) {
void ModulePrinter::printIntegerSetReference(IntegerSet *integerSet) {
int setId;
if ((setId = state.getIntegerSetId(integerSet)) >= 0) {
// The set will be printed at top of module; so print reference to its id.
@ -572,12 +570,12 @@ void ModulePrinter::printType(const Type *type) {
// Affine expressions and maps
//===----------------------------------------------------------------------===//
void ModulePrinter::printAffineExpr(const AffineExpr *expr) {
void ModulePrinter::printAffineExpr(AffineExprRef expr) {
printAffineExprInternal(expr, BindingStrength::Weak);
}
void ModulePrinter::printAffineExprInternal(
const AffineExpr *expr, BindingStrength enclosingTightness) {
AffineExprRef expr, BindingStrength enclosingTightness) {
const char *binopSpelling = nullptr;
switch (expr->getKind()) {
case AffineExpr::Kind::SymbolId:
@ -628,10 +626,10 @@ void ModulePrinter::printAffineExprInternal(
// Pretty print addition to a product that has a negative operand as a
// subtraction.
AffineExpr *rhsExpr = binOp->getRHS();
AffineExprRef rhsExpr = binOp->getRHS();
if (auto *rhs = dyn_cast<AffineBinaryOpExpr>(rhsExpr)) {
if (rhs->getKind() == AffineExpr::Kind::Mul) {
AffineExpr *rrhsExpr = rhs->getRHS();
AffineExprRef rrhsExpr = rhs->getRHS();
if (auto *rrhs = dyn_cast<AffineConstantExpr>(rrhsExpr)) {
if (rrhs->getValue() == -1) {
printAffineExprInternal(binOp->getLHS(), BindingStrength::Weak);
@ -675,12 +673,12 @@ void ModulePrinter::printAffineExprInternal(
os << ')';
}
void ModulePrinter::printAffineConstraint(const AffineExpr *expr, bool isEq) {
void ModulePrinter::printAffineConstraint(AffineExprRef expr, bool isEq) {
printAffineExprInternal(expr, BindingStrength::Weak);
isEq ? os << " == 0" : os << " >= 0";
}
void ModulePrinter::printAffineMap(const AffineMap *map) {
void ModulePrinter::printAffineMap(AffineMap *map) {
// Dimension identifiers.
os << '(';
for (int i = 0; i < (int)map->getNumDims() - 1; ++i)
@ -704,7 +702,7 @@ void ModulePrinter::printAffineMap(const AffineMap *map) {
// Result affine expressions.
os << " -> (";
interleaveComma(map->getResults(),
[&](AffineExpr *expr) { printAffineExpr(expr); });
[&](AffineExprRef expr) { printAffineExpr(expr); });
os << ')';
if (!map->isBounded()) {
@ -714,11 +712,11 @@ void ModulePrinter::printAffineMap(const AffineMap *map) {
// Print range sizes for bounded affine maps.
os << " size (";
interleaveComma(map->getRangeSizes(),
[&](AffineExpr *expr) { printAffineExpr(expr); });
[&](AffineExprRef expr) { printAffineExpr(expr); });
os << ')';
}
void ModulePrinter::printIntegerSet(const IntegerSet *set) {
void ModulePrinter::printIntegerSet(IntegerSet *set) {
// Dimension identifiers.
os << '(';
for (unsigned i = 1; i < set->getNumDims(); ++i)
@ -853,13 +851,13 @@ public:
void printAttribute(const Attribute *attr) {
ModulePrinter::printAttribute(attr);
}
void printAffineMap(const AffineMap *map) {
void printAffineMap(AffineMap *map) {
return ModulePrinter::printAffineMapReference(map);
}
void printIntegerSet(const IntegerSet *set) {
void printIntegerSet(IntegerSet *set) {
return ModulePrinter::printIntegerSetReference(set);
}
void printAffineExpr(const AffineExpr *expr) {
void printAffineExpr(AffineExprRef expr) {
return ModulePrinter::printAffineExpr(expr);
}
void printFunctionReference(const Function *func) {
@ -1433,7 +1431,7 @@ void MLFunctionPrinter::printBound(AffineBound bound, const char *prefix) {
// Therefore, short-hand parsing and printing is only supported for
// zero-operand constant maps and single symbol operand identity maps.
if (map->getNumResults() == 1) {
AffineExpr *expr = map->getResult(0);
AffineExprRef expr = map->getResult(0);
// Print constant bound.
if (map->getNumDims() == 0 && map->getNumSymbols() == 0) {
@ -1498,32 +1496,32 @@ void Type::print(raw_ostream &os) const {
void Type::dump() const { print(llvm::errs()); }
void AffineMap::dump() const {
void AffineMap::dump() {
print(llvm::errs());
llvm::errs() << "\n";
}
void AffineExpr::dump() const {
void AffineExpr::dump() {
print(llvm::errs());
llvm::errs() << "\n";
}
void IntegerSet::dump() const {
void IntegerSet::dump() {
print(llvm::errs());
llvm::errs() << "\n";
}
void AffineExpr::print(raw_ostream &os) const {
void AffineExpr::print(raw_ostream &os) {
ModuleState state(/*no context is known*/ nullptr);
ModulePrinter(os, state).printAffineExpr(this);
}
void AffineMap::print(raw_ostream &os) const {
void AffineMap::print(raw_ostream &os) {
ModuleState state(/*no context is known*/ nullptr);
ModulePrinter(os, state).printAffineMap(this);
}
void IntegerSet::print(raw_ostream &os) const {
void IntegerSet::print(raw_ostream &os) {
ModuleState state(/*no context is known*/ nullptr);
ModulePrinter(os, state).printIntegerSet(this);
}

View File

@ -150,8 +150,8 @@ FunctionAttr *Builder::getFunctionAttr(const Function *value) {
//===----------------------------------------------------------------------===//
AffineMap *Builder::getAffineMap(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr *> results,
ArrayRef<AffineExpr *> rangeSizes) {
ArrayRef<AffineExprRef> results,
ArrayRef<AffineExprRef> rangeSizes) {
return AffineMap::get(dimCount, symbolCount, results, rangeSizes, context);
}
@ -224,7 +224,7 @@ AffineExprRef Builder::getCeilDivExpr(AffineExprRef lhs, uint64_t rhs) {
}
IntegerSet *Builder::getIntegerSet(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr *> constraints,
ArrayRef<AffineExprRef> constraints,
ArrayRef<bool> isEq) {
return IntegerSet::get(dimCount, symbolCount, constraints, isEq, context);
}
@ -251,9 +251,9 @@ AffineMap *Builder::getSingleDimShiftAffineMap(int64_t shift) {
}
AffineMap *Builder::getShiftedAffineMap(AffineMap *map, int64_t shift) {
SmallVector<AffineExpr *, 4> shiftedResults;
SmallVector<AffineExprRef, 4> shiftedResults;
shiftedResults.reserve(map->getNumResults());
for (auto *resultExpr : map->getResults()) {
for (auto resultExpr : map->getResults()) {
shiftedResults.push_back(getAddExpr(resultExpr, shift));
}
return AffineMap::get(map->getNumDims(), map->getNumSymbols(), shiftedResults,

View File

@ -22,8 +22,9 @@
using namespace mlir;
IntegerSet::IntegerSet(unsigned dimCount, unsigned symbolCount,
unsigned numConstraints, AffineExpr *const *constraints,
const bool *const eqFlags)
unsigned numConstraints,
ArrayRef<AffineExprRef> constraints,
ArrayRef<bool> eqFlags)
: dimCount(dimCount), symbolCount(symbolCount),
numConstraints(numConstraints), constraints(constraints),
eqFlags(eqFlags) {}

View File

@ -59,8 +59,8 @@ struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType *> {
struct AffineMapKeyInfo : DenseMapInfo<AffineMap *> {
// Affine maps are uniqued based on their dim/symbol counts and affine
// expressions.
using KeyTy = std::tuple<unsigned, unsigned, ArrayRef<AffineExpr *>,
ArrayRef<AffineExpr *>>;
using KeyTy = std::tuple<unsigned, unsigned, ArrayRef<AffineExprRef>,
ArrayRef<AffineExprRef>>;
using DenseMapInfo<AffineMap *>::getHashValue;
using DenseMapInfo<AffineMap *>::isEqual;
@ -71,7 +71,7 @@ struct AffineMapKeyInfo : DenseMapInfo<AffineMap *> {
hash_combine_range(std::get<3>(key).begin(), std::get<3>(key).end()));
}
static bool isEqual(const KeyTy &lhs, const AffineMap *rhs) {
static bool isEqual(const KeyTy &lhs, AffineMap *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
return lhs == std::make_tuple(rhs->getNumDims(), rhs->getNumSymbols(),
@ -224,7 +224,7 @@ public:
// Affine binary op expression uniquing. Figure out uniquing of dimensional
// or symbolic identifiers.
DenseMap<std::tuple<unsigned, AffineExpr *, AffineExpr *>, AffineExpr *>
DenseMap<std::tuple<unsigned, AffineExprRef, AffineExprRef>, AffineExprRef>
affineExprs;
// Uniqui'ing of AffineDimExpr, AffineSymbolExpr's by their position.
@ -800,8 +800,8 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs,
//===----------------------------------------------------------------------===//
AffineMap *AffineMap::get(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr *> results,
ArrayRef<AffineExpr *> rangeSizes,
ArrayRef<AffineExprRef> results,
ArrayRef<AffineExprRef> rangeSizes,
MLIRContext *context) {
// The number of results can't be zero.
assert(!results.empty());
@ -822,12 +822,12 @@ AffineMap *AffineMap::get(unsigned dimCount, unsigned symbolCount,
auto *res = impl.allocator.Allocate<AffineMap>();
// Copy the results and range sizes into the bump pointer.
results = impl.copyInto(ArrayRef<AffineExpr *>(results));
rangeSizes = impl.copyInto(ArrayRef<AffineExpr *>(rangeSizes));
results = impl.copyInto(results);
rangeSizes = impl.copyInto(rangeSizes);
// Initialize the memory using placement new.
new (res) AffineMap(dimCount, symbolCount, results.size(), results.data(),
rangeSizes.empty() ? nullptr : rangeSizes.data());
new (res)
AffineMap(dimCount, symbolCount, results.size(), results, rangeSizes);
// Cache and return it.
return *existing.first = res;
@ -843,15 +843,13 @@ AffineExprRef AffineBinaryOpExpr::get(AffineExpr::Kind kind, AffineExprRef lhs,
auto &impl = context->getImpl();
// Check if we already have this affine expression, and return it if we do.
AffineExpr *lhsExpr = lhs;
AffineExpr *rhsExpr = rhs;
auto keyValue = std::make_tuple((unsigned)kind, lhsExpr, rhsExpr);
auto keyValue = std::make_tuple((unsigned)kind, lhs, rhs);
auto cached = impl.affineExprs.find(keyValue);
if (cached != impl.affineExprs.end())
return cached->second;
// Simplify the expression if possible.
AffineExpr *simplified;
AffineExprRef simplified(nullptr);
switch (kind) {
case Kind::Add:
simplified = AffineBinaryOpExpr::simplifyAdd(lhs, rhs, context);
@ -940,7 +938,7 @@ AffineExprRef AffineConstantExpr::get(int64_t constant, MLIRContext *context) {
//===----------------------------------------------------------------------===//
IntegerSet *IntegerSet::get(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr *> constraints,
ArrayRef<AffineExprRef> constraints,
ArrayRef<bool> eqFlags, MLIRContext *context) {
assert(eqFlags.size() == constraints.size());
@ -950,10 +948,10 @@ IntegerSet *IntegerSet::get(unsigned dimCount, unsigned symbolCount,
auto *res = impl.allocator.Allocate<IntegerSet>();
// Copy the equalities and inequalities into the bump pointer.
constraints = impl.copyInto(ArrayRef<AffineExpr *>(constraints));
constraints = impl.copyInto(ArrayRef<AffineExprRef>(constraints));
eqFlags = impl.copyInto(ArrayRef<bool>(eqFlags));
// Initialize the memory using placement new.
return new (res) IntegerSet(dimCount, symbolCount, constraints.size(),
constraints.data(), eqFlags.data());
constraints, eqFlags);
}

View File

@ -193,7 +193,7 @@ public:
MLIRContext *context)
: numDims(numDims), operandConsts(operandConsts), context(context) {}
IntegerAttr *constantFold(AffineExpr *expr) {
IntegerAttr *constantFold(AffineExprRef expr) {
switch (expr->getKind()) {
case AffineExpr::Kind::Add:
return constantFoldBinExpr(
@ -224,7 +224,7 @@ public:
private:
IntegerAttr *
constantFoldBinExpr(AffineExpr *expr,
constantFoldBinExpr(AffineExprRef expr,
std::function<uint64_t(int64_t, uint64_t)> op) {
auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
auto *lhs = constantFold(binOpExpr->getLHS());
@ -254,7 +254,7 @@ bool AffineApplyOp::constantFold(ArrayRef<Attribute *> operands,
AffineExprConstantFolder exprFolder(map->getNumDims(), operands, context);
// Constant fold each AffineExpr in AffineMap and add to 'results'.
for (auto *expr : map->getResults()) {
for (auto expr : map->getResults()) {
results.push_back(exprFolder.constantFold(expr));
}
// Return false on success.

View File

@ -1238,7 +1238,7 @@ AffineMap *AffineParser::parseAffineMapInline() {
parseToken(Token::l_paren, "expected '(' at start of affine map range"))
return nullptr;
SmallVector<AffineExpr *, 4> exprs;
SmallVector<AffineExprRef, 4> exprs;
auto parseElt = [&]() -> ParseResult {
auto *elt = parseAffineExpr();
ParseResult res = elt ? ParseSuccess : ParseFailure;
@ -1257,7 +1257,7 @@ AffineMap *AffineParser::parseAffineMapInline() {
// dim-size ::= affine-expr | `min` `(` affine-expr (`,` affine-expr)+ `)`
// TODO(bondhugula): support for min of several affine expressions.
// TODO: check if sizes are non-negative whenever they are constant.
SmallVector<AffineExpr *, 4> rangeSizes;
SmallVector<AffineExprRef, 4> rangeSizes;
if (consumeIf(Token::kw_size)) {
// Location of the l_paren token (if it exists) for error reporting later.
auto loc = getToken().getLoc();
@ -2500,7 +2500,7 @@ IntegerSet *AffineParser::parseIntegerSetInline() {
"expected '(' at start of integer set constraint list"))
return nullptr;
SmallVector<AffineExpr *, 4> constraints;
SmallVector<AffineExprRef, 4> constraints;
SmallVector<bool, 4> isEqs;
auto parseElt = [&]() -> ParseResult {
bool isEq;

View File

@ -52,9 +52,7 @@ FunctionPass *mlir::createSimplifyAffineExprPass() {
}
AffineMap *MutableAffineMap::getAffineMap() {
SmallVector<AffineExpr *, 8> res(results.begin(), results.end());
SmallVector<AffineExpr *, 8> sizes(rangeSizes.begin(), rangeSizes.end());
return AffineMap::get(numDims, numSymbols, res, sizes, context);
return AffineMap::get(numDims, numSymbols, results, rangeSizes, context);
}
PassResult SimplifyAffineExpr::runOnMLFunction(MLFunction *f) {