forked from OSchip/llvm-project
[RFC][MLIR] Use AffineExprRef in place of AffineExpr* in IR
This CL starts by replacing AffineExpr* with value-type AffineExprRef in a few places in the IR. By a domino effect that is pretty telling of the inconsistencies in the codebase, const is removed where it makes sense. The rationale is that the decision was concisously made that unique'd types have pointer semantics without const specifier. This is fine but we should be consistent. In the end, the only logical invariant is that there should never be such a thing as a const AffineExpr*, const AffineMap* or const IntegerSet* in our codebase. This CL takes a number of shortcuts to killing const with fire, in particular forcing const AffineExprRef to return the underlying non-const AffineExpr*. This will be removed once AffineExpr* has disappeared in containers but for now such shortcuts allow a bit of sanity in this long quest for cleanups. The **only** places where const AffineExpr*, const AffineMap* or const IntegerSet* may still appear is by transitive needs from containers, comparison operators etc. There is still one major thing remaining here: figure out why cast/dyn_cast return me a const AffineXXX*, which in turn requires a bunch of ugly const_casts. I suspect this is due to the classof taking const AffineXXXExpr*. I wonder whether this is a side effect of 1., if it is coming from llvm itself (I'd doubt it) or something else (clattner@?) In light of this, the whole discussion about const makes total sense to me now and I would systematically apply the rule that in the end, we should never have any const XXX in our codebase for unique'd types (assuming we can remove them all in containers and no additional constness constraint is added on us from the outside world). PiperOrigin-RevId: 215811554
This commit is contained in:
parent
5b8017db18
commit
b55b407601
|
@ -24,6 +24,7 @@
|
||||||
#define MLIR_IR_AFFINE_EXPR_H
|
#define MLIR_IR_AFFINE_EXPR_H
|
||||||
|
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
|
#include "llvm/ADT/DenseMapInfo.h"
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
@ -58,26 +59,26 @@ public:
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Return the classification for this type.
|
/// Return the classification for this type.
|
||||||
Kind getKind() const { return kind; }
|
Kind getKind() { return kind; }
|
||||||
|
|
||||||
void print(raw_ostream &os) const;
|
void print(raw_ostream &os);
|
||||||
void dump() const;
|
void dump();
|
||||||
|
|
||||||
/// Returns true if this expression is made out of only symbols and
|
/// Returns true if this expression is made out of only symbols and
|
||||||
/// constants, i.e., it does not involve dimensional identifiers.
|
/// constants, i.e., it does not involve dimensional identifiers.
|
||||||
bool isSymbolicOrConstant() const;
|
bool isSymbolicOrConstant();
|
||||||
|
|
||||||
/// Returns true if this is a pure affine expression, i.e., multiplication,
|
/// Returns true if this is a pure affine expression, i.e., multiplication,
|
||||||
/// floordiv, ceildiv, and mod is only allowed w.r.t constants.
|
/// floordiv, ceildiv, and mod is only allowed w.r.t constants.
|
||||||
bool isPureAffine() const;
|
bool isPureAffine();
|
||||||
|
|
||||||
/// Returns the greatest known integral divisor of this affine expression.
|
/// Returns the greatest known integral divisor of this affine expression.
|
||||||
uint64_t getLargestKnownDivisor() const;
|
uint64_t getLargestKnownDivisor();
|
||||||
|
|
||||||
/// Return true if the affine expression is a multiple of 'factor'.
|
/// Return true if the affine expression is a multiple of 'factor'.
|
||||||
bool isMultipleOf(int64_t factor) const;
|
bool isMultipleOf(int64_t factor);
|
||||||
|
|
||||||
MLIRContext *getContext() const;
|
MLIRContext *getContext();
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
explicit AffineExpr(Kind kind, MLIRContext *context)
|
explicit AffineExpr(Kind kind, MLIRContext *context)
|
||||||
|
@ -93,7 +94,7 @@ private:
|
||||||
MLIRContext *context;
|
MLIRContext *context;
|
||||||
};
|
};
|
||||||
|
|
||||||
inline raw_ostream &operator<<(raw_ostream &os, const AffineExpr &expr) {
|
inline raw_ostream &operator<<(raw_ostream &os, AffineExpr &expr) {
|
||||||
expr.print(os);
|
expr.print(os);
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
@ -104,25 +105,22 @@ inline raw_ostream &operator<<(raw_ostream &os, const AffineExpr &expr) {
|
||||||
/// away.
|
/// away.
|
||||||
/// TODO(ntv): Remove all redundant MLIRContext* arguments through the API
|
/// TODO(ntv): Remove all redundant MLIRContext* arguments through the API
|
||||||
/// TODO(ntv): Remove all uses of AffineExpr* in Parser.cpp
|
/// TODO(ntv): Remove all uses of AffineExpr* in Parser.cpp
|
||||||
/// TODO(ntv): Remove all uses of AffineExpr* in learning/brain
|
|
||||||
/// TODO(ntv): Remove all uses of AffineExpr* in IR
|
|
||||||
/// TODO(ntv): Add extra out-of-class operators for int op AffineExprBaseRef
|
/// TODO(ntv): Add extra out-of-class operators for int op AffineExprBaseRef
|
||||||
/// TODO(ntv): Rename
|
/// TODO(ntv): Rename
|
||||||
/// TODO(ntv): Drop const everywhere it makes sense in AffineExpr
|
/// TODO(ntv): Drop const everywhere it makes sense in AffineExpr
|
||||||
/// TODO(ntv): remove const comment
|
/// TODO(ntv): remove const comment
|
||||||
|
/// TODO(ntv): pointer pair
|
||||||
template <typename AffineExprType> class AffineExprBaseRef {
|
template <typename AffineExprType> class AffineExprBaseRef {
|
||||||
public:
|
public:
|
||||||
/* implicit */ AffineExprBaseRef(const AffineExprType *expr)
|
/* implicit */ AffineExprBaseRef(AffineExprType *expr) : expr(expr) {}
|
||||||
: expr(const_cast<AffineExprType *>(expr)) {}
|
|
||||||
|
|
||||||
AffineExprBaseRef(const AffineExprBaseRef &other) : expr(other.expr){};
|
AffineExprBaseRef(const AffineExprBaseRef &other) : expr(other.expr){};
|
||||||
AffineExprBaseRef &operator=(AffineExprBaseRef other) {
|
AffineExprBaseRef &operator=(AffineExprBaseRef other) {
|
||||||
expr = other;
|
expr = other;
|
||||||
return *this;
|
return *this;
|
||||||
};
|
};
|
||||||
bool operator==(AffineExprBaseRef other) { return expr == other; };
|
bool operator==(AffineExprBaseRef other) const { return expr == other.expr; };
|
||||||
AffineExprType *operator->() { return expr; }
|
AffineExprType *operator->() { return expr; }
|
||||||
AffineExprType const *operator->() const { return expr; }
|
|
||||||
/* implicit */ operator AffineExprType *() { return expr; }
|
/* implicit */ operator AffineExprType *() { return expr; }
|
||||||
|
|
||||||
bool operator!() { return expr == nullptr; }
|
bool operator!() { return expr == nullptr; }
|
||||||
|
@ -141,16 +139,13 @@ public:
|
||||||
AffineExprBaseRef operator%(uint64_t v) const;
|
AffineExprBaseRef operator%(uint64_t v) const;
|
||||||
AffineExprBaseRef operator%(AffineExprBaseRef other) const;
|
AffineExprBaseRef operator%(AffineExprBaseRef other) const;
|
||||||
|
|
||||||
friend AffineExprType *
|
|
||||||
llvm::simplify_type<AffineExprBaseRef<AffineExprType>>::getSimplifiedValue(
|
|
||||||
AffineExprBaseRef<AffineExprType> &input);
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
AffineExprType *expr;
|
AffineExprType *expr;
|
||||||
};
|
};
|
||||||
|
|
||||||
using AffineExprRef = AffineExprBaseRef<AffineExpr>;
|
using AffineExprRef = AffineExprBaseRef<AffineExpr>;
|
||||||
|
|
||||||
|
inline ::llvm::hash_code hash_value(AffineExprRef arg);
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
namespace llvm {
|
namespace llvm {
|
||||||
|
@ -160,7 +155,25 @@ namespace llvm {
|
||||||
template <typename T> struct simplify_type<mlir::AffineExprBaseRef<T>> {
|
template <typename T> struct simplify_type<mlir::AffineExprBaseRef<T>> {
|
||||||
using SimpleType = T *;
|
using SimpleType = T *;
|
||||||
static SimpleType getSimplifiedValue(mlir::AffineExprBaseRef<T> &input) {
|
static SimpleType getSimplifiedValue(mlir::AffineExprBaseRef<T> &input) {
|
||||||
return input.expr;
|
return input;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// AffineExprRef hash just like pointers
|
||||||
|
template <> struct DenseMapInfo<mlir::AffineExprRef> {
|
||||||
|
static mlir::AffineExprRef getEmptyKey() {
|
||||||
|
auto pointer = llvm::DenseMapInfo<mlir::AffineExpr *>::getEmptyKey();
|
||||||
|
return mlir::AffineExprRef(pointer);
|
||||||
|
}
|
||||||
|
static mlir::AffineExprRef getTombstoneKey() {
|
||||||
|
auto pointer = llvm::DenseMapInfo<mlir::AffineExpr *>::getTombstoneKey();
|
||||||
|
return mlir::AffineExprRef(pointer);
|
||||||
|
}
|
||||||
|
static unsigned getHashValue(mlir::AffineExprRef val) {
|
||||||
|
return mlir::hash_value(val);
|
||||||
|
}
|
||||||
|
static bool isEqual(mlir::AffineExprRef LHS, mlir::AffineExprRef RHS) {
|
||||||
|
return LHS == RHS;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -168,6 +181,11 @@ template <typename T> struct simplify_type<mlir::AffineExprBaseRef<T>> {
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
|
||||||
|
// Make AffineExprRef hashable.
|
||||||
|
inline ::llvm::hash_code hash_value(AffineExprRef arg) {
|
||||||
|
return ::llvm::hash_value(static_cast<AffineExpr *>(arg));
|
||||||
|
}
|
||||||
|
|
||||||
/// Affine binary operation expression. An affine binary operation could be an
|
/// Affine binary operation expression. An affine binary operation could be an
|
||||||
/// add, mul, floordiv, ceildiv, or a modulo operation. (Subtraction is
|
/// add, mul, floordiv, ceildiv, or a modulo operation. (Subtraction is
|
||||||
/// represented through a multiply by -1 and add.) These expressions are always
|
/// represented through a multiply by -1 and add.) These expressions are always
|
||||||
|
@ -212,34 +230,35 @@ public:
|
||||||
static AffineExprRef getMod(AffineExprRef lhs, uint64_t rhs,
|
static AffineExprRef getMod(AffineExprRef lhs, uint64_t rhs,
|
||||||
MLIRContext *context);
|
MLIRContext *context);
|
||||||
|
|
||||||
AffineExprRef getLHS() const { return lhs; }
|
AffineExprRef getLHS() { return lhs; }
|
||||||
AffineExprRef getRHS() const { return rhs; }
|
AffineExprRef getRHS() { return rhs; }
|
||||||
|
|
||||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||||
static bool classof(const AffineExpr *expr) {
|
static bool classof(const AffineExpr *expr) {
|
||||||
return expr->getKind() <= Kind::LAST_AFFINE_BINARY_OP;
|
return const_cast<AffineExpr *>(expr)->getKind() <=
|
||||||
|
Kind::LAST_AFFINE_BINARY_OP;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
explicit AffineBinaryOpExpr(Kind kind, AffineExprRef lhs, AffineExprRef rhs,
|
explicit AffineBinaryOpExpr(Kind kind, AffineExprRef lhs, AffineExprRef rhs,
|
||||||
MLIRContext *context);
|
MLIRContext *context);
|
||||||
|
|
||||||
AffineExpr *const lhs;
|
const AffineExprRef lhs;
|
||||||
AffineExpr *const rhs;
|
const AffineExprRef rhs;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
~AffineBinaryOpExpr() = delete;
|
~AffineBinaryOpExpr() = delete;
|
||||||
// Simplification prior to construction of binary affine op expressions.
|
// Simplification prior to construction of binary affine op expressions.
|
||||||
static AffineExpr *simplifyAdd(AffineExpr *lhs, AffineExpr *rhs,
|
static AffineExprRef simplifyAdd(AffineExprRef lhs, AffineExprRef rhs,
|
||||||
MLIRContext *context);
|
MLIRContext *context);
|
||||||
static AffineExpr *simplifyMul(AffineExpr *lhs, AffineExpr *rhs,
|
static AffineExprRef simplifyMul(AffineExprRef lhs, AffineExprRef rhs,
|
||||||
MLIRContext *context);
|
MLIRContext *context);
|
||||||
static AffineExpr *simplifyFloorDiv(AffineExpr *lhs, AffineExpr *rhs,
|
static AffineExprRef simplifyFloorDiv(AffineExprRef lhs, AffineExprRef rhs,
|
||||||
MLIRContext *context);
|
MLIRContext *context);
|
||||||
static AffineExpr *simplifyCeilDiv(AffineExpr *lhs, AffineExpr *rhs,
|
static AffineExprRef simplifyCeilDiv(AffineExprRef lhs, AffineExprRef rhs,
|
||||||
MLIRContext *context);
|
MLIRContext *context);
|
||||||
static AffineExpr *simplifyMod(AffineExpr *lhs, AffineExpr *rhs,
|
static AffineExprRef simplifyMod(AffineExprRef lhs, AffineExprRef rhs,
|
||||||
MLIRContext *context);
|
MLIRContext *context);
|
||||||
};
|
};
|
||||||
|
|
||||||
/// A dimensional identifier appearing in an affine expression.
|
/// A dimensional identifier appearing in an affine expression.
|
||||||
|
@ -252,11 +271,11 @@ public:
|
||||||
static AffineExprBaseRef<AffineExpr> get(unsigned position,
|
static AffineExprBaseRef<AffineExpr> get(unsigned position,
|
||||||
MLIRContext *context);
|
MLIRContext *context);
|
||||||
|
|
||||||
unsigned getPosition() const { return position; }
|
unsigned getPosition() { return position; }
|
||||||
|
|
||||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||||
static bool classof(const AffineExpr *expr) {
|
static bool classof(const AffineExpr *expr) {
|
||||||
return expr->getKind() == Kind::DimId;
|
return const_cast<AffineExpr *>(expr)->getKind() == Kind::DimId;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -278,11 +297,11 @@ public:
|
||||||
static AffineExprBaseRef<AffineExpr> get(unsigned position,
|
static AffineExprBaseRef<AffineExpr> get(unsigned position,
|
||||||
MLIRContext *context);
|
MLIRContext *context);
|
||||||
|
|
||||||
unsigned getPosition() const { return position; }
|
unsigned getPosition() { return position; }
|
||||||
|
|
||||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||||
static bool classof(const AffineExpr *expr) {
|
static bool classof(const AffineExpr *expr) {
|
||||||
return expr->getKind() == Kind::SymbolId;
|
return const_cast<AffineExpr *>(expr)->getKind() == Kind::SymbolId;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -300,11 +319,11 @@ public:
|
||||||
static AffineExprBaseRef<AffineExpr> get(int64_t constant,
|
static AffineExprBaseRef<AffineExpr> get(int64_t constant,
|
||||||
MLIRContext *context);
|
MLIRContext *context);
|
||||||
|
|
||||||
int64_t getValue() const { return constant; }
|
int64_t getValue() { return constant; }
|
||||||
|
|
||||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||||
static bool classof(const AffineExpr *expr) {
|
static bool classof(const AffineExpr *expr) {
|
||||||
return expr->getKind() == Kind::Constant;
|
return const_cast<AffineExpr *>(expr)->getKind() == Kind::Constant;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -78,7 +78,7 @@ template <typename SubClass, typename RetTy = void> class AffineExprVisitor {
|
||||||
// that you use to visit affine expressions...
|
// that you use to visit affine expressions...
|
||||||
public:
|
public:
|
||||||
// Function to walk an AffineExpr (in post order).
|
// Function to walk an AffineExpr (in post order).
|
||||||
RetTy walkPostOrder(AffineExpr *expr) {
|
RetTy walkPostOrder(AffineExprRef expr) {
|
||||||
static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
|
static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
|
||||||
"Must instantiate with a derived type of AffineExprVisitor");
|
"Must instantiate with a derived type of AffineExprVisitor");
|
||||||
switch (expr->getKind()) {
|
switch (expr->getKind()) {
|
||||||
|
@ -120,7 +120,7 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
// Function to visit an AffineExpr.
|
// Function to visit an AffineExpr.
|
||||||
RetTy visit(AffineExpr *expr) {
|
RetTy visit(AffineExprRef expr) {
|
||||||
static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
|
static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
|
||||||
"Must instantiate with a derived type of AffineExprVisitor");
|
"Must instantiate with a derived type of AffineExprVisitor");
|
||||||
switch (expr->getKind()) {
|
switch (expr->getKind()) {
|
||||||
|
|
|
@ -30,8 +30,10 @@
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
|
||||||
class MLIRContext;
|
|
||||||
class AffineExpr;
|
class AffineExpr;
|
||||||
|
template <typename T> class AffineExprBaseRef;
|
||||||
|
using AffineExprRef = AffineExprBaseRef<AffineExpr>;
|
||||||
|
class MLIRContext;
|
||||||
|
|
||||||
/// A multi-dimensional affine map
|
/// A multi-dimensional affine map
|
||||||
/// Affine map's are immutable like Type's, and they are uniqued.
|
/// Affine map's are immutable like Type's, and they are uniqued.
|
||||||
|
@ -41,8 +43,8 @@ class AffineExpr;
|
||||||
class AffineMap {
|
class AffineMap {
|
||||||
public:
|
public:
|
||||||
static AffineMap *get(unsigned dimCount, unsigned symbolCount,
|
static AffineMap *get(unsigned dimCount, unsigned symbolCount,
|
||||||
ArrayRef<AffineExpr *> results,
|
ArrayRef<AffineExprRef> results,
|
||||||
ArrayRef<AffineExpr *> rangeSizes,
|
ArrayRef<AffineExprRef> rangeSizes,
|
||||||
MLIRContext *context);
|
MLIRContext *context);
|
||||||
|
|
||||||
/// Returns a single constant result affine map.
|
/// Returns a single constant result affine map.
|
||||||
|
@ -51,60 +53,56 @@ public:
|
||||||
/// Returns true if the co-domain (or more loosely speaking, range) of this
|
/// Returns true if the co-domain (or more loosely speaking, range) of this
|
||||||
/// map is bounded. Bounded affine maps have a size (extent) for each of
|
/// map is bounded. Bounded affine maps have a size (extent) for each of
|
||||||
/// their range dimensions (more accurately co-domain dimensions).
|
/// their range dimensions (more accurately co-domain dimensions).
|
||||||
bool isBounded() const { return rangeSizes != nullptr; }
|
bool isBounded() { return !rangeSizes.empty(); }
|
||||||
|
|
||||||
/// Returns true if this affine map is an identity affine map.
|
/// Returns true if this affine map is an identity affine map.
|
||||||
/// An identity affine map corresponds to an identity affine function on the
|
/// An identity affine map corresponds to an identity affine function on the
|
||||||
/// dimensional identifiers.
|
/// dimensional identifiers.
|
||||||
bool isIdentity() const;
|
bool isIdentity();
|
||||||
|
|
||||||
/// Returns true if this affine map is a single result constant function.
|
/// Returns true if this affine map is a single result constant function.
|
||||||
bool isSingleConstant() const;
|
bool isSingleConstant();
|
||||||
|
|
||||||
/// Returns the constant result of this map. This methods asserts that the map
|
/// Returns the constant result of this map. This methods asserts that the map
|
||||||
/// has a single constant result.
|
/// has a single constant result.
|
||||||
int64_t getSingleConstantResult() const;
|
int64_t getSingleConstantResult();
|
||||||
|
|
||||||
// Prints affine map to 'os'.
|
// Prints affine map to 'os'.
|
||||||
void print(raw_ostream &os) const;
|
void print(raw_ostream &os);
|
||||||
void dump() const;
|
void dump();
|
||||||
|
|
||||||
unsigned getNumDims() const { return numDims; }
|
unsigned getNumDims() { return numDims; }
|
||||||
unsigned getNumSymbols() const { return numSymbols; }
|
unsigned getNumSymbols() { return numSymbols; }
|
||||||
unsigned getNumResults() const { return numResults; }
|
unsigned getNumResults() { return numResults; }
|
||||||
unsigned getNumInputs() const { return numDims + numSymbols; }
|
unsigned getNumInputs() { return numDims + numSymbols; }
|
||||||
|
|
||||||
ArrayRef<AffineExpr *> getResults() const {
|
ArrayRef<AffineExprRef> getResults() { return results; }
|
||||||
return ArrayRef<AffineExpr *>(results, numResults);
|
|
||||||
}
|
|
||||||
|
|
||||||
AffineExpr *getResult(unsigned idx) const { return results[idx]; }
|
AffineExprRef getResult(unsigned idx);
|
||||||
|
|
||||||
ArrayRef<AffineExpr *> getRangeSizes() const {
|
ArrayRef<AffineExprRef> getRangeSizes() { return rangeSizes; }
|
||||||
return rangeSizes ? ArrayRef<AffineExpr *>(rangeSizes, numResults)
|
|
||||||
: ArrayRef<AffineExpr *>();
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
AffineMap(unsigned numDims, unsigned numSymbols, unsigned numResults,
|
AffineMap(unsigned numDims, unsigned numSymbols, unsigned numResults,
|
||||||
AffineExpr *const *results, AffineExpr *const *rangeSizes);
|
ArrayRef<AffineExprRef> results,
|
||||||
|
ArrayRef<AffineExprRef> rangeSizes);
|
||||||
|
|
||||||
AffineMap(const AffineMap &) = delete;
|
AffineMap(const AffineMap &) = delete;
|
||||||
void operator=(const AffineMap &) = delete;
|
void operator=(const AffineMap &) = delete;
|
||||||
|
|
||||||
const unsigned numDims;
|
unsigned numDims;
|
||||||
const unsigned numSymbols;
|
unsigned numSymbols;
|
||||||
const unsigned numResults;
|
unsigned numResults;
|
||||||
|
|
||||||
/// The affine expressions for this (multi-dimensional) map.
|
/// The affine expressions for this (multi-dimensional) map.
|
||||||
/// TODO: use trailing objects for this.
|
/// TODO: use trailing objects for this.
|
||||||
AffineExpr *const *const results;
|
ArrayRef<AffineExprRef> results;
|
||||||
|
|
||||||
/// The extents along each of the range dimensions if the map is bounded,
|
/// The extents along each of the range dimensions if the map is bounded,
|
||||||
/// nullptr otherwise.
|
/// nullptr otherwise.
|
||||||
AffineExpr *const *const rangeSizes;
|
ArrayRef<AffineExprRef> rangeSizes;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
|
||||||
#endif // MLIR_IR_AFFINE_MAP_H
|
#endif // MLIR_IR_AFFINE_MAP_H
|
||||||
|
|
|
@ -120,17 +120,9 @@ public:
|
||||||
AffineExprRef getCeilDivExpr(AffineExprRef lhs, AffineExprRef rhs);
|
AffineExprRef getCeilDivExpr(AffineExprRef lhs, AffineExprRef rhs);
|
||||||
AffineExprRef getCeilDivExpr(AffineExprRef lhs, uint64_t rhs);
|
AffineExprRef getCeilDivExpr(AffineExprRef lhs, uint64_t rhs);
|
||||||
|
|
||||||
/// Creates a sum of products affine expression from constant coefficients.
|
|
||||||
/// If c_0, c_1, ... are the coefficients in the order corresponding to
|
|
||||||
/// dimensions, symbols, and the constant term, create the affine expression:
|
|
||||||
/// expr = c_0*d0 + c_1*d1 + ... + c_{ndims-1}*d_{ndims-1} + c_{..}*s0 +
|
|
||||||
/// c_{..}*s1 + ... + const
|
|
||||||
AffineExpr *getAddMulPureAffineExpr(unsigned numDims, unsigned numSymbols,
|
|
||||||
ArrayRef<int64_t> coeffs);
|
|
||||||
|
|
||||||
AffineMap *getAffineMap(unsigned dimCount, unsigned symbolCount,
|
AffineMap *getAffineMap(unsigned dimCount, unsigned symbolCount,
|
||||||
ArrayRef<AffineExpr *> results,
|
ArrayRef<AffineExprRef> results,
|
||||||
ArrayRef<AffineExpr *> rangeSizes);
|
ArrayRef<AffineExprRef> rangeSizes);
|
||||||
|
|
||||||
// Special cases of affine maps and integer sets
|
// Special cases of affine maps and integer sets
|
||||||
/// Returns a single constant result affine map with 0 dimensions and 0
|
/// Returns a single constant result affine map with 0 dimensions and 0
|
||||||
|
@ -153,7 +145,7 @@ public:
|
||||||
|
|
||||||
// Integer set.
|
// Integer set.
|
||||||
IntegerSet *getIntegerSet(unsigned dimCount, unsigned symbolCount,
|
IntegerSet *getIntegerSet(unsigned dimCount, unsigned symbolCount,
|
||||||
ArrayRef<AffineExpr *> constraints,
|
ArrayRef<AffineExprRef> constraints,
|
||||||
ArrayRef<bool> isEq);
|
ArrayRef<bool> isEq);
|
||||||
// TODO: Helpers for affine map/exprs, etc.
|
// TODO: Helpers for affine map/exprs, etc.
|
||||||
protected:
|
protected:
|
||||||
|
|
|
@ -47,51 +47,45 @@ class MLIRContext;
|
||||||
class IntegerSet {
|
class IntegerSet {
|
||||||
public:
|
public:
|
||||||
static IntegerSet *get(unsigned dimCount, unsigned symbolCount,
|
static IntegerSet *get(unsigned dimCount, unsigned symbolCount,
|
||||||
ArrayRef<AffineExpr *> constraints,
|
ArrayRef<AffineExprRef> constraints,
|
||||||
ArrayRef<bool> eqFlags, MLIRContext *context);
|
ArrayRef<bool> eqFlags, MLIRContext *context);
|
||||||
|
|
||||||
unsigned getNumDims() const { return dimCount; }
|
unsigned getNumDims() { return dimCount; }
|
||||||
unsigned getNumSymbols() const { return symbolCount; }
|
unsigned getNumSymbols() { return symbolCount; }
|
||||||
unsigned getNumOperands() const { return dimCount + symbolCount; }
|
unsigned getNumOperands() { return dimCount + symbolCount; }
|
||||||
unsigned getNumConstraints() const { return numConstraints; }
|
unsigned getNumConstraints() { return numConstraints; }
|
||||||
|
|
||||||
ArrayRef<AffineExpr *> getConstraints() const {
|
ArrayRef<AffineExprRef> getConstraints() { return constraints; }
|
||||||
return ArrayRef<AffineExpr *>(constraints, numConstraints);
|
|
||||||
}
|
|
||||||
|
|
||||||
AffineExpr *getConstraint(unsigned idx) const {
|
AffineExprRef getConstraint(unsigned idx) { return getConstraints()[idx]; }
|
||||||
return getConstraints()[idx];
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the equality bits, which specify whether each of the constraints
|
/// Returns the equality bits, which specify whether each of the constraints
|
||||||
/// is an equality or inequality.
|
/// is an equality or inequality.
|
||||||
ArrayRef<bool> getEqFlags() const {
|
ArrayRef<bool> getEqFlags() { return eqFlags; }
|
||||||
return ArrayRef<bool>(eqFlags, numConstraints);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns true if the idx^th constraint is an equality, false if it is an
|
/// Returns true if the idx^th constraint is an equality, false if it is an
|
||||||
/// inequality.
|
/// inequality.
|
||||||
bool isEq(unsigned idx) const { return getEqFlags()[idx]; }
|
bool isEq(unsigned idx) { return getEqFlags()[idx]; }
|
||||||
|
|
||||||
void print(raw_ostream &os) const;
|
void print(raw_ostream &os);
|
||||||
void dump() const;
|
void dump();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
IntegerSet(unsigned dimCount, unsigned symbolCount, unsigned numConstraints,
|
IntegerSet(unsigned dimCount, unsigned symbolCount, unsigned numConstraints,
|
||||||
AffineExpr *const *constraints, const bool *const eqFlags);
|
ArrayRef<AffineExprRef> constraints, ArrayRef<bool> eqFlags);
|
||||||
|
|
||||||
~IntegerSet() = delete;
|
~IntegerSet() = delete;
|
||||||
|
|
||||||
const unsigned dimCount;
|
unsigned dimCount;
|
||||||
const unsigned symbolCount;
|
unsigned symbolCount;
|
||||||
const unsigned numConstraints;
|
unsigned numConstraints;
|
||||||
|
|
||||||
/// Array of affine constraints: a constaint is either an equality
|
/// Array of affine constraints: a constaint is either an equality
|
||||||
/// (affine_expr == 0) or an inequality (affine_expr >= 0).
|
/// (affine_expr == 0) or an inequality (affine_expr >= 0).
|
||||||
AffineExpr *const *const constraints;
|
ArrayRef<AffineExprRef> constraints;
|
||||||
|
|
||||||
// Bits to check whether a constraint is an equality or an inequality.
|
// Bits to check whether a constraint is an equality or an inequality.
|
||||||
const bool *const eqFlags;
|
ArrayRef<bool> eqFlags;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
|
|
@ -28,8 +28,10 @@
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
class AffineMap;
|
|
||||||
class AffineExpr;
|
class AffineExpr;
|
||||||
|
template <typename T> class AffineExprBaseRef;
|
||||||
|
using AffineExprRef = AffineExprBaseRef<AffineExpr>;
|
||||||
|
class AffineMap;
|
||||||
class Builder;
|
class Builder;
|
||||||
class Function;
|
class Function;
|
||||||
|
|
||||||
|
@ -68,8 +70,8 @@ public:
|
||||||
virtual void printType(const Type *type) = 0;
|
virtual void printType(const Type *type) = 0;
|
||||||
virtual void printFunctionReference(const Function *func) = 0;
|
virtual void printFunctionReference(const Function *func) = 0;
|
||||||
virtual void printAttribute(const Attribute *attr) = 0;
|
virtual void printAttribute(const Attribute *attr) = 0;
|
||||||
virtual void printAffineMap(const AffineMap *map) = 0;
|
virtual void printAffineMap(AffineMap *map) = 0;
|
||||||
virtual void printAffineExpr(const AffineExpr *expr) = 0;
|
virtual void printAffineExpr(AffineExprRef expr) = 0;
|
||||||
|
|
||||||
/// If the specified operation has attributes, print out an attribute
|
/// If the specified operation has attributes, print out an attribute
|
||||||
/// dictionary with their values. elidedAttrs allows the client to ignore
|
/// dictionary with their values. elidedAttrs allows the client to ignore
|
||||||
|
@ -104,7 +106,7 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const Attribute &attr) {
|
||||||
}
|
}
|
||||||
|
|
||||||
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const AffineMap &map) {
|
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const AffineMap &map) {
|
||||||
p.printAffineMap(&map);
|
p.printAffineMap(&const_cast<AffineMap &>(map));
|
||||||
return p;
|
return p;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -30,14 +30,14 @@ using namespace mlir;
|
||||||
MutableAffineMap::MutableAffineMap(AffineMap *map, MLIRContext *context)
|
MutableAffineMap::MutableAffineMap(AffineMap *map, MLIRContext *context)
|
||||||
: numDims(map->getNumDims()), numSymbols(map->getNumSymbols()),
|
: numDims(map->getNumDims()), numSymbols(map->getNumSymbols()),
|
||||||
context(context) {
|
context(context) {
|
||||||
for (auto *result : map->getResults())
|
for (auto result : map->getResults())
|
||||||
results.push_back(result);
|
results.push_back(result);
|
||||||
for (auto *rangeSize : map->getRangeSizes())
|
for (auto rangeSize : map->getRangeSizes())
|
||||||
results.push_back(rangeSize);
|
results.push_back(rangeSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const {
|
bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const {
|
||||||
if (results[idx]->isMultipleOf(factor))
|
if (const_cast<AffineExprRef &>(results[idx])->isMultipleOf(factor))
|
||||||
return true;
|
return true;
|
||||||
|
|
||||||
// TODO(bondhugula): use simplifyAffineExpr and FlatAffineConstraints to
|
// TODO(bondhugula): use simplifyAffineExpr and FlatAffineConstraints to
|
||||||
|
|
|
@ -38,7 +38,8 @@ getReducedConstBound(const HyperRectangularSet &set, unsigned *idx,
|
||||||
unsigned j = 0;
|
unsigned j = 0;
|
||||||
AffineBoundExprList::const_iterator it, e;
|
AffineBoundExprList::const_iterator it, e;
|
||||||
for (it = ubs.begin(), e = ubs.end(); it != e; it++, j++) {
|
for (it = ubs.begin(), e = ubs.end(); it != e; it++, j++) {
|
||||||
if (auto *cExpr = dyn_cast<AffineConstantExpr>(*it)) {
|
if (auto *cExpr = const_cast<AffineConstantExpr *>(
|
||||||
|
dyn_cast<AffineConstantExpr>(*it))) {
|
||||||
if (val == None) {
|
if (val == None) {
|
||||||
val = cExpr->getValue();
|
val = cExpr->getValue();
|
||||||
*idx = j;
|
*idx = j;
|
||||||
|
|
|
@ -85,7 +85,7 @@ AffineExprRef AffineBinaryOpExpr::getMod(AffineExprRef lhs, uint64_t rhs,
|
||||||
|
|
||||||
/// Returns true if this expression is made out of only symbols and
|
/// Returns true if this expression is made out of only symbols and
|
||||||
/// constants (no dimensional identifiers).
|
/// constants (no dimensional identifiers).
|
||||||
bool AffineExpr::isSymbolicOrConstant() const {
|
bool AffineExpr::isSymbolicOrConstant() {
|
||||||
switch (getKind()) {
|
switch (getKind()) {
|
||||||
case Kind::Constant:
|
case Kind::Constant:
|
||||||
return true;
|
return true;
|
||||||
|
@ -108,7 +108,7 @@ bool AffineExpr::isSymbolicOrConstant() const {
|
||||||
|
|
||||||
/// Returns true if this is a pure affine expression, i.e., multiplication,
|
/// Returns true if this is a pure affine expression, i.e., multiplication,
|
||||||
/// floordiv, ceildiv, and mod is only allowed w.r.t constants.
|
/// floordiv, ceildiv, and mod is only allowed w.r.t constants.
|
||||||
bool AffineExpr::isPureAffine() const {
|
bool AffineExpr::isPureAffine() {
|
||||||
switch (getKind()) {
|
switch (getKind()) {
|
||||||
case Kind::SymbolId:
|
case Kind::SymbolId:
|
||||||
case Kind::DimId:
|
case Kind::DimId:
|
||||||
|
@ -138,7 +138,7 @@ bool AffineExpr::isPureAffine() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the greatest known integral divisor of this affine expression.
|
/// Returns the greatest known integral divisor of this affine expression.
|
||||||
uint64_t AffineExpr::getLargestKnownDivisor() const {
|
uint64_t AffineExpr::getLargestKnownDivisor() {
|
||||||
AffineBinaryOpExpr *binExpr = nullptr;
|
AffineBinaryOpExpr *binExpr = nullptr;
|
||||||
switch (kind) {
|
switch (kind) {
|
||||||
case Kind::SymbolId:
|
case Kind::SymbolId:
|
||||||
|
@ -148,7 +148,7 @@ uint64_t AffineExpr::getLargestKnownDivisor() const {
|
||||||
case Kind::Constant:
|
case Kind::Constant:
|
||||||
return std::abs(cast<AffineConstantExpr>(this)->getValue());
|
return std::abs(cast<AffineConstantExpr>(this)->getValue());
|
||||||
case Kind::Mul: {
|
case Kind::Mul: {
|
||||||
binExpr = cast<AffineBinaryOpExpr>(const_cast<AffineExpr *>(this));
|
binExpr = cast<AffineBinaryOpExpr>(this);
|
||||||
return binExpr->getLHS()->getLargestKnownDivisor() *
|
return binExpr->getLHS()->getLargestKnownDivisor() *
|
||||||
binExpr->getRHS()->getLargestKnownDivisor();
|
binExpr->getRHS()->getLargestKnownDivisor();
|
||||||
}
|
}
|
||||||
|
@ -157,7 +157,7 @@ uint64_t AffineExpr::getLargestKnownDivisor() const {
|
||||||
case Kind::FloorDiv:
|
case Kind::FloorDiv:
|
||||||
case Kind::CeilDiv:
|
case Kind::CeilDiv:
|
||||||
case Kind::Mod: {
|
case Kind::Mod: {
|
||||||
binExpr = cast<AffineBinaryOpExpr>(const_cast<AffineExpr *>(this));
|
binExpr = cast<AffineBinaryOpExpr>(this);
|
||||||
return llvm::GreatestCommonDivisor64(
|
return llvm::GreatestCommonDivisor64(
|
||||||
binExpr->getLHS()->getLargestKnownDivisor(),
|
binExpr->getLHS()->getLargestKnownDivisor(),
|
||||||
binExpr->getRHS()->getLargestKnownDivisor());
|
binExpr->getRHS()->getLargestKnownDivisor());
|
||||||
|
@ -165,7 +165,7 @@ uint64_t AffineExpr::getLargestKnownDivisor() const {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AffineExpr::isMultipleOf(int64_t factor) const {
|
bool AffineExpr::isMultipleOf(int64_t factor) {
|
||||||
AffineBinaryOpExpr *binExpr = nullptr;
|
AffineBinaryOpExpr *binExpr = nullptr;
|
||||||
uint64_t l, u;
|
uint64_t l, u;
|
||||||
switch (kind) {
|
switch (kind) {
|
||||||
|
@ -176,7 +176,7 @@ bool AffineExpr::isMultipleOf(int64_t factor) const {
|
||||||
case Kind::Constant:
|
case Kind::Constant:
|
||||||
return cast<AffineConstantExpr>(this)->getValue() % factor == 0;
|
return cast<AffineConstantExpr>(this)->getValue() % factor == 0;
|
||||||
case Kind::Mul: {
|
case Kind::Mul: {
|
||||||
binExpr = cast<AffineBinaryOpExpr>(const_cast<AffineExpr *>(this));
|
binExpr = cast<AffineBinaryOpExpr>(this);
|
||||||
// It's probably not worth optimizing this further (to not traverse the
|
// It's probably not worth optimizing this further (to not traverse the
|
||||||
// whole sub-tree under - it that would require a version of isMultipleOf
|
// whole sub-tree under - it that would require a version of isMultipleOf
|
||||||
// that on a 'false' return also returns the largest known divisor).
|
// that on a 'false' return also returns the largest known divisor).
|
||||||
|
@ -188,7 +188,7 @@ bool AffineExpr::isMultipleOf(int64_t factor) const {
|
||||||
case Kind::FloorDiv:
|
case Kind::FloorDiv:
|
||||||
case Kind::CeilDiv:
|
case Kind::CeilDiv:
|
||||||
case Kind::Mod: {
|
case Kind::Mod: {
|
||||||
binExpr = cast<AffineBinaryOpExpr>(const_cast<AffineExpr *>(this));
|
binExpr = cast<AffineBinaryOpExpr>(this);
|
||||||
return llvm::GreatestCommonDivisor64(
|
return llvm::GreatestCommonDivisor64(
|
||||||
binExpr->getLHS()->getLargestKnownDivisor(),
|
binExpr->getLHS()->getLargestKnownDivisor(),
|
||||||
binExpr->getRHS()->getLargestKnownDivisor()) %
|
binExpr->getRHS()->getLargestKnownDivisor()) %
|
||||||
|
@ -198,7 +198,7 @@ bool AffineExpr::isMultipleOf(int64_t factor) const {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
MLIRContext *AffineExpr::getContext() const { return context; }
|
MLIRContext *AffineExpr::getContext() { return context; }
|
||||||
|
|
||||||
template <> AffineExprRef AffineExprRef::operator+(int64_t v) const {
|
template <> AffineExprRef AffineExprRef::operator+(int64_t v) const {
|
||||||
return AffineBinaryOpExpr::getAdd(expr, v, expr->getContext());
|
return AffineBinaryOpExpr::getAdd(expr, v, expr->getContext());
|
||||||
|
|
|
@ -23,7 +23,8 @@
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
AffineMap::AffineMap(unsigned numDims, unsigned numSymbols, unsigned numResults,
|
AffineMap::AffineMap(unsigned numDims, unsigned numSymbols, unsigned numResults,
|
||||||
AffineExpr *const *results, AffineExpr *const *rangeSizes)
|
ArrayRef<AffineExprRef> results,
|
||||||
|
ArrayRef<AffineExprRef> rangeSizes)
|
||||||
: numDims(numDims), numSymbols(numSymbols), numResults(numResults),
|
: numDims(numDims), numSymbols(numSymbols), numResults(numResults),
|
||||||
results(results), rangeSizes(rangeSizes) {}
|
results(results), rangeSizes(rangeSizes) {}
|
||||||
|
|
||||||
|
@ -33,30 +34,36 @@ AffineMap *AffineMap::getConstantMap(int64_t val, MLIRContext *context) {
|
||||||
{AffineConstantExpr::get(val, context)}, {}, context);
|
{AffineConstantExpr::get(val, context)}, {}, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AffineMap::isIdentity() const {
|
bool AffineMap::isIdentity() {
|
||||||
if (getNumDims() != getNumResults())
|
if (getNumDims() != getNumResults())
|
||||||
return false;
|
return false;
|
||||||
ArrayRef<AffineExpr *> results = getResults();
|
ArrayRef<AffineExprRef> results = getResults();
|
||||||
for (unsigned i = 0, numDims = getNumDims(); i < numDims; ++i) {
|
for (unsigned i = 0, numDims = getNumDims(); i < numDims; ++i) {
|
||||||
auto *expr = dyn_cast<AffineDimExpr>(results[i]);
|
auto *expr =
|
||||||
|
const_cast<AffineDimExpr *>(dyn_cast<AffineDimExpr>(results[i]));
|
||||||
if (!expr || expr->getPosition() != i)
|
if (!expr || expr->getPosition() != i)
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AffineMap::isSingleConstant() const {
|
bool AffineMap::isSingleConstant() {
|
||||||
return getNumResults() == 1 && isa<AffineConstantExpr>(getResult(0));
|
return getNumResults() == 1 && isa<AffineConstantExpr>(getResult(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t AffineMap::getSingleConstantResult() const {
|
int64_t AffineMap::getSingleConstantResult() {
|
||||||
assert(isSingleConstant() && "map must have a single constant result");
|
assert(isSingleConstant() && "map must have a single constant result");
|
||||||
return cast<AffineConstantExpr>(getResult(0))->getValue();
|
return const_cast<AffineConstantExpr *>(
|
||||||
|
cast<AffineConstantExpr>(getResult(0)))
|
||||||
|
->getValue();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AffineExprRef AffineMap::getResult(unsigned idx) { return results[idx]; }
|
||||||
|
|
||||||
/// Simplify add expression. Return nullptr if it can't be simplified.
|
/// Simplify add expression. Return nullptr if it can't be simplified.
|
||||||
AffineExpr *AffineBinaryOpExpr::simplifyAdd(AffineExpr *lhs, AffineExpr *rhs,
|
AffineExprRef AffineBinaryOpExpr::simplifyAdd(AffineExprRef lhs,
|
||||||
MLIRContext *context) {
|
AffineExprRef rhs,
|
||||||
|
MLIRContext *context) {
|
||||||
auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
|
auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
|
||||||
auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
|
auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
|
||||||
|
|
||||||
|
@ -80,16 +87,19 @@ AffineExpr *AffineBinaryOpExpr::simplifyAdd(AffineExpr *lhs, AffineExpr *rhs,
|
||||||
return lhs;
|
return lhs;
|
||||||
}
|
}
|
||||||
// Fold successive additions like (d0 + 2) + 3 into d0 + 5.
|
// Fold successive additions like (d0 + 2) + 3 into d0 + 5.
|
||||||
auto *lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
|
auto *lBin =
|
||||||
|
const_cast<AffineBinaryOpExpr *>(dyn_cast<AffineBinaryOpExpr>(lhs));
|
||||||
if (lBin && rhsConst && lBin->getKind() == Kind::Add) {
|
if (lBin && rhsConst && lBin->getKind() == Kind::Add) {
|
||||||
if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS()))
|
if (auto *lrhs = const_cast<AffineConstantExpr *>(
|
||||||
|
dyn_cast<AffineConstantExpr>(lBin->getRHS())))
|
||||||
return lBin->getLHS() + (lrhs->getValue() + rhsConst->getValue());
|
return lBin->getLHS() + (lrhs->getValue() + rhsConst->getValue());
|
||||||
}
|
}
|
||||||
|
|
||||||
// When doing successive additions, bring constant to the right: turn (d0 + 2)
|
// When doing successive additions, bring constant to the right: turn (d0 + 2)
|
||||||
// + d1 into (d0 + d1) + 2.
|
// + d1 into (d0 + d1) + 2.
|
||||||
if (lBin && lBin->getKind() == Kind::Add) {
|
if (lBin && lBin->getKind() == Kind::Add) {
|
||||||
if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS())) {
|
if (auto *lrhs = const_cast<AffineConstantExpr *>(
|
||||||
|
dyn_cast<AffineConstantExpr>(lBin->getRHS()))) {
|
||||||
return lBin->getLHS() + rhs + lrhs;
|
return lBin->getLHS() + rhs + lrhs;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -98,8 +108,9 @@ AffineExpr *AffineBinaryOpExpr::simplifyAdd(AffineExpr *lhs, AffineExpr *rhs,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Simplify a multiply expression. Return nullptr if it can't be simplified.
|
/// Simplify a multiply expression. Return nullptr if it can't be simplified.
|
||||||
AffineExpr *AffineBinaryOpExpr::simplifyMul(AffineExpr *lhs, AffineExpr *rhs,
|
AffineExprRef AffineBinaryOpExpr::simplifyMul(AffineExprRef lhs,
|
||||||
MLIRContext *context) {
|
AffineExprRef rhs,
|
||||||
|
MLIRContext *context) {
|
||||||
auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
|
auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
|
||||||
auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
|
auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
|
||||||
|
|
||||||
|
@ -129,16 +140,19 @@ AffineExpr *AffineBinaryOpExpr::simplifyMul(AffineExpr *lhs, AffineExpr *rhs,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
|
// Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
|
||||||
auto *lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
|
auto *lBin =
|
||||||
|
const_cast<AffineBinaryOpExpr *>(dyn_cast<AffineBinaryOpExpr>(lhs));
|
||||||
if (lBin && rhsConst && lBin->getKind() == Kind::Mul) {
|
if (lBin && rhsConst && lBin->getKind() == Kind::Mul) {
|
||||||
if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS()))
|
if (auto *lrhs = const_cast<AffineConstantExpr *>(
|
||||||
|
dyn_cast<AffineConstantExpr>(lBin->getRHS())))
|
||||||
return lBin->getLHS() * (lrhs->getValue() * rhsConst->getValue());
|
return lBin->getLHS() * (lrhs->getValue() * rhsConst->getValue());
|
||||||
}
|
}
|
||||||
|
|
||||||
// When doing successive multiplication, bring constant to the right: turn (d0
|
// When doing successive multiplication, bring constant to the right: turn (d0
|
||||||
// * 2) * d1 into (d0 * d1) * 2.
|
// * 2) * d1 into (d0 * d1) * 2.
|
||||||
if (lBin && lBin->getKind() == Kind::Mul) {
|
if (lBin && lBin->getKind() == Kind::Mul) {
|
||||||
if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS())) {
|
if (auto *lrhs = const_cast<AffineConstantExpr *>(
|
||||||
|
dyn_cast<AffineConstantExpr>(lBin->getRHS()))) {
|
||||||
return (lBin->getLHS() * rhs) * lrhs;
|
return (lBin->getLHS() * rhs) * lrhs;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -146,9 +160,9 @@ AffineExpr *AffineBinaryOpExpr::simplifyMul(AffineExpr *lhs, AffineExpr *rhs,
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
AffineExpr *AffineBinaryOpExpr::simplifyFloorDiv(AffineExpr *lhs,
|
AffineExprRef AffineBinaryOpExpr::simplifyFloorDiv(AffineExprRef lhs,
|
||||||
AffineExpr *rhs,
|
AffineExprRef rhs,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
|
auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
|
||||||
auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
|
auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
|
||||||
|
|
||||||
|
@ -162,9 +176,11 @@ AffineExpr *AffineBinaryOpExpr::simplifyFloorDiv(AffineExpr *lhs,
|
||||||
if (rhsConst->getValue() == 1)
|
if (rhsConst->getValue() == 1)
|
||||||
return lhs;
|
return lhs;
|
||||||
|
|
||||||
auto *lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
|
auto *lBin =
|
||||||
|
const_cast<AffineBinaryOpExpr *>(dyn_cast<AffineBinaryOpExpr>(lhs));
|
||||||
if (lBin && lBin->getKind() == Kind::Mul) {
|
if (lBin && lBin->getKind() == Kind::Mul) {
|
||||||
if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS())) {
|
if (auto *lrhs = const_cast<AffineConstantExpr *>(
|
||||||
|
dyn_cast<AffineConstantExpr>(lBin->getRHS()))) {
|
||||||
// rhsConst is known to be positive if a constant.
|
// rhsConst is known to be positive if a constant.
|
||||||
if (lrhs->getValue() % rhsConst->getValue() == 0)
|
if (lrhs->getValue() % rhsConst->getValue() == 0)
|
||||||
return lBin->getLHS() * (lrhs->getValue() / rhsConst->getValue());
|
return lBin->getLHS() * (lrhs->getValue() / rhsConst->getValue());
|
||||||
|
@ -175,9 +191,9 @@ AffineExpr *AffineBinaryOpExpr::simplifyFloorDiv(AffineExpr *lhs,
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
AffineExpr *AffineBinaryOpExpr::simplifyCeilDiv(AffineExpr *lhs,
|
AffineExprRef AffineBinaryOpExpr::simplifyCeilDiv(AffineExprRef lhs,
|
||||||
AffineExpr *rhs,
|
AffineExprRef rhs,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
|
auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
|
||||||
auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
|
auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
|
||||||
|
|
||||||
|
@ -191,9 +207,11 @@ AffineExpr *AffineBinaryOpExpr::simplifyCeilDiv(AffineExpr *lhs,
|
||||||
if (rhsConst->getValue() == 1)
|
if (rhsConst->getValue() == 1)
|
||||||
return lhs;
|
return lhs;
|
||||||
|
|
||||||
auto *lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
|
auto *lBin =
|
||||||
|
const_cast<AffineBinaryOpExpr *>(dyn_cast<AffineBinaryOpExpr>(lhs));
|
||||||
if (lBin && lBin->getKind() == Kind::Mul) {
|
if (lBin && lBin->getKind() == Kind::Mul) {
|
||||||
if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS())) {
|
if (auto *lrhs = const_cast<AffineConstantExpr *>(
|
||||||
|
dyn_cast<AffineConstantExpr>(lBin->getRHS()))) {
|
||||||
// rhsConst is known to be positive if a constant.
|
// rhsConst is known to be positive if a constant.
|
||||||
if (lrhs->getValue() % rhsConst->getValue() == 0)
|
if (lrhs->getValue() % rhsConst->getValue() == 0)
|
||||||
return lBin->getLHS() * (lrhs->getValue() / rhsConst->getValue());
|
return lBin->getLHS() * (lrhs->getValue() / rhsConst->getValue());
|
||||||
|
@ -204,8 +222,9 @@ AffineExpr *AffineBinaryOpExpr::simplifyCeilDiv(AffineExpr *lhs,
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
AffineExpr *AffineBinaryOpExpr::simplifyMod(AffineExpr *lhs, AffineExpr *rhs,
|
AffineExprRef AffineBinaryOpExpr::simplifyMod(AffineExprRef lhs,
|
||||||
MLIRContext *context) {
|
AffineExprRef rhs,
|
||||||
|
MLIRContext *context) {
|
||||||
auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
|
auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
|
||||||
auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
|
auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
|
||||||
|
|
||||||
|
|
|
@ -64,7 +64,7 @@ public:
|
||||||
// Initializes module state, populating affine map state.
|
// Initializes module state, populating affine map state.
|
||||||
void initialize(const Module *module);
|
void initialize(const Module *module);
|
||||||
|
|
||||||
int getAffineMapId(const AffineMap *affineMap) const {
|
int getAffineMapId(AffineMap *affineMap) const {
|
||||||
auto it = affineMapIds.find(affineMap);
|
auto it = affineMapIds.find(affineMap);
|
||||||
if (it == affineMapIds.end()) {
|
if (it == affineMapIds.end()) {
|
||||||
return -1;
|
return -1;
|
||||||
|
@ -72,9 +72,9 @@ public:
|
||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
ArrayRef<const AffineMap *> getAffineMapIds() const { return affineMapsById; }
|
ArrayRef<AffineMap *> getAffineMapIds() const { return affineMapsById; }
|
||||||
|
|
||||||
int getIntegerSetId(const IntegerSet *integerSet) const {
|
int getIntegerSetId(IntegerSet *integerSet) const {
|
||||||
auto it = integerSetIds.find(integerSet);
|
auto it = integerSetIds.find(integerSet);
|
||||||
if (it == integerSetIds.end()) {
|
if (it == integerSetIds.end()) {
|
||||||
return -1;
|
return -1;
|
||||||
|
@ -82,19 +82,17 @@ public:
|
||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
ArrayRef<const IntegerSet *> getIntegerSetIds() const {
|
ArrayRef<IntegerSet *> getIntegerSetIds() const { return integerSetsById; }
|
||||||
return integerSetsById;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void recordAffineMapReference(const AffineMap *affineMap) {
|
void recordAffineMapReference(AffineMap *affineMap) {
|
||||||
if (affineMapIds.count(affineMap) == 0) {
|
if (affineMapIds.count(affineMap) == 0) {
|
||||||
affineMapIds[affineMap] = affineMapsById.size();
|
affineMapIds[affineMap] = affineMapsById.size();
|
||||||
affineMapsById.push_back(affineMap);
|
affineMapsById.push_back(affineMap);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void recordIntegerSetReference(const IntegerSet *integerSet) {
|
void recordIntegerSetReference(IntegerSet *integerSet) {
|
||||||
if (integerSetIds.count(integerSet) == 0) {
|
if (integerSetIds.count(integerSet) == 0) {
|
||||||
integerSetIds[integerSet] = integerSetsById.size();
|
integerSetIds[integerSet] = integerSetsById.size();
|
||||||
integerSetsById.push_back(integerSet);
|
integerSetsById.push_back(integerSet);
|
||||||
|
@ -102,7 +100,7 @@ private:
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return true if this map could be printed using the shorthand form.
|
// Return true if this map could be printed using the shorthand form.
|
||||||
static bool hasShorthandForm(const AffineMap *boundMap) {
|
static bool hasShorthandForm(AffineMap *boundMap) {
|
||||||
if (boundMap->isSingleConstant())
|
if (boundMap->isSingleConstant())
|
||||||
return true;
|
return true;
|
||||||
|
|
||||||
|
@ -126,11 +124,11 @@ private:
|
||||||
void visitAttribute(const Attribute *attr);
|
void visitAttribute(const Attribute *attr);
|
||||||
void visitOperation(const Operation *op);
|
void visitOperation(const Operation *op);
|
||||||
|
|
||||||
DenseMap<const AffineMap *, int> affineMapIds;
|
DenseMap<AffineMap *, int> affineMapIds;
|
||||||
std::vector<const AffineMap *> affineMapsById;
|
std::vector<AffineMap *> affineMapsById;
|
||||||
|
|
||||||
DenseMap<const IntegerSet *, int> integerSetIds;
|
DenseMap<IntegerSet *, int> integerSetIds;
|
||||||
std::vector<const IntegerSet *> integerSetsById;
|
std::vector<IntegerSet *> integerSetsById;
|
||||||
};
|
};
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
@ -275,10 +273,10 @@ public:
|
||||||
void print(const CFGFunction *fn);
|
void print(const CFGFunction *fn);
|
||||||
void print(const MLFunction *fn);
|
void print(const MLFunction *fn);
|
||||||
|
|
||||||
void printAffineMap(const AffineMap *map);
|
void printAffineMap(AffineMap *map);
|
||||||
void printAffineExpr(const AffineExpr *expr);
|
void printAffineExpr(AffineExprRef expr);
|
||||||
void printAffineConstraint(const AffineExpr *expr, bool isEq);
|
void printAffineConstraint(AffineExprRef expr, bool isEq);
|
||||||
void printIntegerSet(const IntegerSet *set);
|
void printIntegerSet(IntegerSet *set);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
raw_ostream &os;
|
raw_ostream &os;
|
||||||
|
@ -290,9 +288,9 @@ protected:
|
||||||
ArrayRef<const char *> elidedAttrs = {});
|
ArrayRef<const char *> elidedAttrs = {});
|
||||||
void printFunctionResultType(const FunctionType *type);
|
void printFunctionResultType(const FunctionType *type);
|
||||||
void printAffineMapId(int affineMapId) const;
|
void printAffineMapId(int affineMapId) const;
|
||||||
void printAffineMapReference(const AffineMap *affineMap);
|
void printAffineMapReference(AffineMap *affineMap);
|
||||||
void printIntegerSetId(int integerSetId) const;
|
void printIntegerSetId(int integerSetId) const;
|
||||||
void printIntegerSetReference(const IntegerSet *integerSet);
|
void printIntegerSetReference(IntegerSet *integerSet);
|
||||||
|
|
||||||
/// This enum is used to represent the binding stength of the enclosing
|
/// This enum is used to represent the binding stength of the enclosing
|
||||||
/// context that an AffineExpr is being printed in, so we can intelligently
|
/// context that an AffineExpr is being printed in, so we can intelligently
|
||||||
|
@ -301,7 +299,7 @@ protected:
|
||||||
Weak, // + and -
|
Weak, // + and -
|
||||||
Strong, // All other binary operators.
|
Strong, // All other binary operators.
|
||||||
};
|
};
|
||||||
void printAffineExprInternal(const AffineExpr *expr,
|
void printAffineExprInternal(AffineExprRef expr,
|
||||||
BindingStrength enclosingTightness);
|
BindingStrength enclosingTightness);
|
||||||
};
|
};
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
@ -323,7 +321,7 @@ void ModulePrinter::printAffineMapId(int affineMapId) const {
|
||||||
os << "#map" << affineMapId;
|
os << "#map" << affineMapId;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ModulePrinter::printAffineMapReference(const AffineMap *affineMap) {
|
void ModulePrinter::printAffineMapReference(AffineMap *affineMap) {
|
||||||
int mapId = state.getAffineMapId(affineMap);
|
int mapId = state.getAffineMapId(affineMap);
|
||||||
if (mapId >= 0) {
|
if (mapId >= 0) {
|
||||||
// Map will be printed at top of module so print reference to its id.
|
// Map will be printed at top of module so print reference to its id.
|
||||||
|
@ -339,7 +337,7 @@ void ModulePrinter::printIntegerSetId(int integerSetId) const {
|
||||||
os << "@@set" << integerSetId;
|
os << "@@set" << integerSetId;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ModulePrinter::printIntegerSetReference(const IntegerSet *integerSet) {
|
void ModulePrinter::printIntegerSetReference(IntegerSet *integerSet) {
|
||||||
int setId;
|
int setId;
|
||||||
if ((setId = state.getIntegerSetId(integerSet)) >= 0) {
|
if ((setId = state.getIntegerSetId(integerSet)) >= 0) {
|
||||||
// The set will be printed at top of module; so print reference to its id.
|
// The set will be printed at top of module; so print reference to its id.
|
||||||
|
@ -572,12 +570,12 @@ void ModulePrinter::printType(const Type *type) {
|
||||||
// Affine expressions and maps
|
// Affine expressions and maps
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
void ModulePrinter::printAffineExpr(const AffineExpr *expr) {
|
void ModulePrinter::printAffineExpr(AffineExprRef expr) {
|
||||||
printAffineExprInternal(expr, BindingStrength::Weak);
|
printAffineExprInternal(expr, BindingStrength::Weak);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ModulePrinter::printAffineExprInternal(
|
void ModulePrinter::printAffineExprInternal(
|
||||||
const AffineExpr *expr, BindingStrength enclosingTightness) {
|
AffineExprRef expr, BindingStrength enclosingTightness) {
|
||||||
const char *binopSpelling = nullptr;
|
const char *binopSpelling = nullptr;
|
||||||
switch (expr->getKind()) {
|
switch (expr->getKind()) {
|
||||||
case AffineExpr::Kind::SymbolId:
|
case AffineExpr::Kind::SymbolId:
|
||||||
|
@ -628,10 +626,10 @@ void ModulePrinter::printAffineExprInternal(
|
||||||
|
|
||||||
// Pretty print addition to a product that has a negative operand as a
|
// Pretty print addition to a product that has a negative operand as a
|
||||||
// subtraction.
|
// subtraction.
|
||||||
AffineExpr *rhsExpr = binOp->getRHS();
|
AffineExprRef rhsExpr = binOp->getRHS();
|
||||||
if (auto *rhs = dyn_cast<AffineBinaryOpExpr>(rhsExpr)) {
|
if (auto *rhs = dyn_cast<AffineBinaryOpExpr>(rhsExpr)) {
|
||||||
if (rhs->getKind() == AffineExpr::Kind::Mul) {
|
if (rhs->getKind() == AffineExpr::Kind::Mul) {
|
||||||
AffineExpr *rrhsExpr = rhs->getRHS();
|
AffineExprRef rrhsExpr = rhs->getRHS();
|
||||||
if (auto *rrhs = dyn_cast<AffineConstantExpr>(rrhsExpr)) {
|
if (auto *rrhs = dyn_cast<AffineConstantExpr>(rrhsExpr)) {
|
||||||
if (rrhs->getValue() == -1) {
|
if (rrhs->getValue() == -1) {
|
||||||
printAffineExprInternal(binOp->getLHS(), BindingStrength::Weak);
|
printAffineExprInternal(binOp->getLHS(), BindingStrength::Weak);
|
||||||
|
@ -675,12 +673,12 @@ void ModulePrinter::printAffineExprInternal(
|
||||||
os << ')';
|
os << ')';
|
||||||
}
|
}
|
||||||
|
|
||||||
void ModulePrinter::printAffineConstraint(const AffineExpr *expr, bool isEq) {
|
void ModulePrinter::printAffineConstraint(AffineExprRef expr, bool isEq) {
|
||||||
printAffineExprInternal(expr, BindingStrength::Weak);
|
printAffineExprInternal(expr, BindingStrength::Weak);
|
||||||
isEq ? os << " == 0" : os << " >= 0";
|
isEq ? os << " == 0" : os << " >= 0";
|
||||||
}
|
}
|
||||||
|
|
||||||
void ModulePrinter::printAffineMap(const AffineMap *map) {
|
void ModulePrinter::printAffineMap(AffineMap *map) {
|
||||||
// Dimension identifiers.
|
// Dimension identifiers.
|
||||||
os << '(';
|
os << '(';
|
||||||
for (int i = 0; i < (int)map->getNumDims() - 1; ++i)
|
for (int i = 0; i < (int)map->getNumDims() - 1; ++i)
|
||||||
|
@ -704,7 +702,7 @@ void ModulePrinter::printAffineMap(const AffineMap *map) {
|
||||||
// Result affine expressions.
|
// Result affine expressions.
|
||||||
os << " -> (";
|
os << " -> (";
|
||||||
interleaveComma(map->getResults(),
|
interleaveComma(map->getResults(),
|
||||||
[&](AffineExpr *expr) { printAffineExpr(expr); });
|
[&](AffineExprRef expr) { printAffineExpr(expr); });
|
||||||
os << ')';
|
os << ')';
|
||||||
|
|
||||||
if (!map->isBounded()) {
|
if (!map->isBounded()) {
|
||||||
|
@ -714,11 +712,11 @@ void ModulePrinter::printAffineMap(const AffineMap *map) {
|
||||||
// Print range sizes for bounded affine maps.
|
// Print range sizes for bounded affine maps.
|
||||||
os << " size (";
|
os << " size (";
|
||||||
interleaveComma(map->getRangeSizes(),
|
interleaveComma(map->getRangeSizes(),
|
||||||
[&](AffineExpr *expr) { printAffineExpr(expr); });
|
[&](AffineExprRef expr) { printAffineExpr(expr); });
|
||||||
os << ')';
|
os << ')';
|
||||||
}
|
}
|
||||||
|
|
||||||
void ModulePrinter::printIntegerSet(const IntegerSet *set) {
|
void ModulePrinter::printIntegerSet(IntegerSet *set) {
|
||||||
// Dimension identifiers.
|
// Dimension identifiers.
|
||||||
os << '(';
|
os << '(';
|
||||||
for (unsigned i = 1; i < set->getNumDims(); ++i)
|
for (unsigned i = 1; i < set->getNumDims(); ++i)
|
||||||
|
@ -853,13 +851,13 @@ public:
|
||||||
void printAttribute(const Attribute *attr) {
|
void printAttribute(const Attribute *attr) {
|
||||||
ModulePrinter::printAttribute(attr);
|
ModulePrinter::printAttribute(attr);
|
||||||
}
|
}
|
||||||
void printAffineMap(const AffineMap *map) {
|
void printAffineMap(AffineMap *map) {
|
||||||
return ModulePrinter::printAffineMapReference(map);
|
return ModulePrinter::printAffineMapReference(map);
|
||||||
}
|
}
|
||||||
void printIntegerSet(const IntegerSet *set) {
|
void printIntegerSet(IntegerSet *set) {
|
||||||
return ModulePrinter::printIntegerSetReference(set);
|
return ModulePrinter::printIntegerSetReference(set);
|
||||||
}
|
}
|
||||||
void printAffineExpr(const AffineExpr *expr) {
|
void printAffineExpr(AffineExprRef expr) {
|
||||||
return ModulePrinter::printAffineExpr(expr);
|
return ModulePrinter::printAffineExpr(expr);
|
||||||
}
|
}
|
||||||
void printFunctionReference(const Function *func) {
|
void printFunctionReference(const Function *func) {
|
||||||
|
@ -1433,7 +1431,7 @@ void MLFunctionPrinter::printBound(AffineBound bound, const char *prefix) {
|
||||||
// Therefore, short-hand parsing and printing is only supported for
|
// Therefore, short-hand parsing and printing is only supported for
|
||||||
// zero-operand constant maps and single symbol operand identity maps.
|
// zero-operand constant maps and single symbol operand identity maps.
|
||||||
if (map->getNumResults() == 1) {
|
if (map->getNumResults() == 1) {
|
||||||
AffineExpr *expr = map->getResult(0);
|
AffineExprRef expr = map->getResult(0);
|
||||||
|
|
||||||
// Print constant bound.
|
// Print constant bound.
|
||||||
if (map->getNumDims() == 0 && map->getNumSymbols() == 0) {
|
if (map->getNumDims() == 0 && map->getNumSymbols() == 0) {
|
||||||
|
@ -1498,32 +1496,32 @@ void Type::print(raw_ostream &os) const {
|
||||||
|
|
||||||
void Type::dump() const { print(llvm::errs()); }
|
void Type::dump() const { print(llvm::errs()); }
|
||||||
|
|
||||||
void AffineMap::dump() const {
|
void AffineMap::dump() {
|
||||||
print(llvm::errs());
|
print(llvm::errs());
|
||||||
llvm::errs() << "\n";
|
llvm::errs() << "\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
void AffineExpr::dump() const {
|
void AffineExpr::dump() {
|
||||||
print(llvm::errs());
|
print(llvm::errs());
|
||||||
llvm::errs() << "\n";
|
llvm::errs() << "\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
void IntegerSet::dump() const {
|
void IntegerSet::dump() {
|
||||||
print(llvm::errs());
|
print(llvm::errs());
|
||||||
llvm::errs() << "\n";
|
llvm::errs() << "\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
void AffineExpr::print(raw_ostream &os) const {
|
void AffineExpr::print(raw_ostream &os) {
|
||||||
ModuleState state(/*no context is known*/ nullptr);
|
ModuleState state(/*no context is known*/ nullptr);
|
||||||
ModulePrinter(os, state).printAffineExpr(this);
|
ModulePrinter(os, state).printAffineExpr(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
void AffineMap::print(raw_ostream &os) const {
|
void AffineMap::print(raw_ostream &os) {
|
||||||
ModuleState state(/*no context is known*/ nullptr);
|
ModuleState state(/*no context is known*/ nullptr);
|
||||||
ModulePrinter(os, state).printAffineMap(this);
|
ModulePrinter(os, state).printAffineMap(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
void IntegerSet::print(raw_ostream &os) const {
|
void IntegerSet::print(raw_ostream &os) {
|
||||||
ModuleState state(/*no context is known*/ nullptr);
|
ModuleState state(/*no context is known*/ nullptr);
|
||||||
ModulePrinter(os, state).printIntegerSet(this);
|
ModulePrinter(os, state).printIntegerSet(this);
|
||||||
}
|
}
|
||||||
|
|
|
@ -150,8 +150,8 @@ FunctionAttr *Builder::getFunctionAttr(const Function *value) {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
AffineMap *Builder::getAffineMap(unsigned dimCount, unsigned symbolCount,
|
AffineMap *Builder::getAffineMap(unsigned dimCount, unsigned symbolCount,
|
||||||
ArrayRef<AffineExpr *> results,
|
ArrayRef<AffineExprRef> results,
|
||||||
ArrayRef<AffineExpr *> rangeSizes) {
|
ArrayRef<AffineExprRef> rangeSizes) {
|
||||||
return AffineMap::get(dimCount, symbolCount, results, rangeSizes, context);
|
return AffineMap::get(dimCount, symbolCount, results, rangeSizes, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -224,7 +224,7 @@ AffineExprRef Builder::getCeilDivExpr(AffineExprRef lhs, uint64_t rhs) {
|
||||||
}
|
}
|
||||||
|
|
||||||
IntegerSet *Builder::getIntegerSet(unsigned dimCount, unsigned symbolCount,
|
IntegerSet *Builder::getIntegerSet(unsigned dimCount, unsigned symbolCount,
|
||||||
ArrayRef<AffineExpr *> constraints,
|
ArrayRef<AffineExprRef> constraints,
|
||||||
ArrayRef<bool> isEq) {
|
ArrayRef<bool> isEq) {
|
||||||
return IntegerSet::get(dimCount, symbolCount, constraints, isEq, context);
|
return IntegerSet::get(dimCount, symbolCount, constraints, isEq, context);
|
||||||
}
|
}
|
||||||
|
@ -251,9 +251,9 @@ AffineMap *Builder::getSingleDimShiftAffineMap(int64_t shift) {
|
||||||
}
|
}
|
||||||
|
|
||||||
AffineMap *Builder::getShiftedAffineMap(AffineMap *map, int64_t shift) {
|
AffineMap *Builder::getShiftedAffineMap(AffineMap *map, int64_t shift) {
|
||||||
SmallVector<AffineExpr *, 4> shiftedResults;
|
SmallVector<AffineExprRef, 4> shiftedResults;
|
||||||
shiftedResults.reserve(map->getNumResults());
|
shiftedResults.reserve(map->getNumResults());
|
||||||
for (auto *resultExpr : map->getResults()) {
|
for (auto resultExpr : map->getResults()) {
|
||||||
shiftedResults.push_back(getAddExpr(resultExpr, shift));
|
shiftedResults.push_back(getAddExpr(resultExpr, shift));
|
||||||
}
|
}
|
||||||
return AffineMap::get(map->getNumDims(), map->getNumSymbols(), shiftedResults,
|
return AffineMap::get(map->getNumDims(), map->getNumSymbols(), shiftedResults,
|
||||||
|
|
|
@ -22,8 +22,9 @@
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
IntegerSet::IntegerSet(unsigned dimCount, unsigned symbolCount,
|
IntegerSet::IntegerSet(unsigned dimCount, unsigned symbolCount,
|
||||||
unsigned numConstraints, AffineExpr *const *constraints,
|
unsigned numConstraints,
|
||||||
const bool *const eqFlags)
|
ArrayRef<AffineExprRef> constraints,
|
||||||
|
ArrayRef<bool> eqFlags)
|
||||||
: dimCount(dimCount), symbolCount(symbolCount),
|
: dimCount(dimCount), symbolCount(symbolCount),
|
||||||
numConstraints(numConstraints), constraints(constraints),
|
numConstraints(numConstraints), constraints(constraints),
|
||||||
eqFlags(eqFlags) {}
|
eqFlags(eqFlags) {}
|
||||||
|
|
|
@ -59,8 +59,8 @@ struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType *> {
|
||||||
struct AffineMapKeyInfo : DenseMapInfo<AffineMap *> {
|
struct AffineMapKeyInfo : DenseMapInfo<AffineMap *> {
|
||||||
// Affine maps are uniqued based on their dim/symbol counts and affine
|
// Affine maps are uniqued based on their dim/symbol counts and affine
|
||||||
// expressions.
|
// expressions.
|
||||||
using KeyTy = std::tuple<unsigned, unsigned, ArrayRef<AffineExpr *>,
|
using KeyTy = std::tuple<unsigned, unsigned, ArrayRef<AffineExprRef>,
|
||||||
ArrayRef<AffineExpr *>>;
|
ArrayRef<AffineExprRef>>;
|
||||||
using DenseMapInfo<AffineMap *>::getHashValue;
|
using DenseMapInfo<AffineMap *>::getHashValue;
|
||||||
using DenseMapInfo<AffineMap *>::isEqual;
|
using DenseMapInfo<AffineMap *>::isEqual;
|
||||||
|
|
||||||
|
@ -71,7 +71,7 @@ struct AffineMapKeyInfo : DenseMapInfo<AffineMap *> {
|
||||||
hash_combine_range(std::get<3>(key).begin(), std::get<3>(key).end()));
|
hash_combine_range(std::get<3>(key).begin(), std::get<3>(key).end()));
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool isEqual(const KeyTy &lhs, const AffineMap *rhs) {
|
static bool isEqual(const KeyTy &lhs, AffineMap *rhs) {
|
||||||
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
||||||
return false;
|
return false;
|
||||||
return lhs == std::make_tuple(rhs->getNumDims(), rhs->getNumSymbols(),
|
return lhs == std::make_tuple(rhs->getNumDims(), rhs->getNumSymbols(),
|
||||||
|
@ -224,7 +224,7 @@ public:
|
||||||
|
|
||||||
// Affine binary op expression uniquing. Figure out uniquing of dimensional
|
// Affine binary op expression uniquing. Figure out uniquing of dimensional
|
||||||
// or symbolic identifiers.
|
// or symbolic identifiers.
|
||||||
DenseMap<std::tuple<unsigned, AffineExpr *, AffineExpr *>, AffineExpr *>
|
DenseMap<std::tuple<unsigned, AffineExprRef, AffineExprRef>, AffineExprRef>
|
||||||
affineExprs;
|
affineExprs;
|
||||||
|
|
||||||
// Uniqui'ing of AffineDimExpr, AffineSymbolExpr's by their position.
|
// Uniqui'ing of AffineDimExpr, AffineSymbolExpr's by their position.
|
||||||
|
@ -800,8 +800,8 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs,
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
AffineMap *AffineMap::get(unsigned dimCount, unsigned symbolCount,
|
AffineMap *AffineMap::get(unsigned dimCount, unsigned symbolCount,
|
||||||
ArrayRef<AffineExpr *> results,
|
ArrayRef<AffineExprRef> results,
|
||||||
ArrayRef<AffineExpr *> rangeSizes,
|
ArrayRef<AffineExprRef> rangeSizes,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
// The number of results can't be zero.
|
// The number of results can't be zero.
|
||||||
assert(!results.empty());
|
assert(!results.empty());
|
||||||
|
@ -822,12 +822,12 @@ AffineMap *AffineMap::get(unsigned dimCount, unsigned symbolCount,
|
||||||
auto *res = impl.allocator.Allocate<AffineMap>();
|
auto *res = impl.allocator.Allocate<AffineMap>();
|
||||||
|
|
||||||
// Copy the results and range sizes into the bump pointer.
|
// Copy the results and range sizes into the bump pointer.
|
||||||
results = impl.copyInto(ArrayRef<AffineExpr *>(results));
|
results = impl.copyInto(results);
|
||||||
rangeSizes = impl.copyInto(ArrayRef<AffineExpr *>(rangeSizes));
|
rangeSizes = impl.copyInto(rangeSizes);
|
||||||
|
|
||||||
// Initialize the memory using placement new.
|
// Initialize the memory using placement new.
|
||||||
new (res) AffineMap(dimCount, symbolCount, results.size(), results.data(),
|
new (res)
|
||||||
rangeSizes.empty() ? nullptr : rangeSizes.data());
|
AffineMap(dimCount, symbolCount, results.size(), results, rangeSizes);
|
||||||
|
|
||||||
// Cache and return it.
|
// Cache and return it.
|
||||||
return *existing.first = res;
|
return *existing.first = res;
|
||||||
|
@ -843,15 +843,13 @@ AffineExprRef AffineBinaryOpExpr::get(AffineExpr::Kind kind, AffineExprRef lhs,
|
||||||
auto &impl = context->getImpl();
|
auto &impl = context->getImpl();
|
||||||
|
|
||||||
// Check if we already have this affine expression, and return it if we do.
|
// Check if we already have this affine expression, and return it if we do.
|
||||||
AffineExpr *lhsExpr = lhs;
|
auto keyValue = std::make_tuple((unsigned)kind, lhs, rhs);
|
||||||
AffineExpr *rhsExpr = rhs;
|
|
||||||
auto keyValue = std::make_tuple((unsigned)kind, lhsExpr, rhsExpr);
|
|
||||||
auto cached = impl.affineExprs.find(keyValue);
|
auto cached = impl.affineExprs.find(keyValue);
|
||||||
if (cached != impl.affineExprs.end())
|
if (cached != impl.affineExprs.end())
|
||||||
return cached->second;
|
return cached->second;
|
||||||
|
|
||||||
// Simplify the expression if possible.
|
// Simplify the expression if possible.
|
||||||
AffineExpr *simplified;
|
AffineExprRef simplified(nullptr);
|
||||||
switch (kind) {
|
switch (kind) {
|
||||||
case Kind::Add:
|
case Kind::Add:
|
||||||
simplified = AffineBinaryOpExpr::simplifyAdd(lhs, rhs, context);
|
simplified = AffineBinaryOpExpr::simplifyAdd(lhs, rhs, context);
|
||||||
|
@ -940,7 +938,7 @@ AffineExprRef AffineConstantExpr::get(int64_t constant, MLIRContext *context) {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
IntegerSet *IntegerSet::get(unsigned dimCount, unsigned symbolCount,
|
IntegerSet *IntegerSet::get(unsigned dimCount, unsigned symbolCount,
|
||||||
ArrayRef<AffineExpr *> constraints,
|
ArrayRef<AffineExprRef> constraints,
|
||||||
ArrayRef<bool> eqFlags, MLIRContext *context) {
|
ArrayRef<bool> eqFlags, MLIRContext *context) {
|
||||||
assert(eqFlags.size() == constraints.size());
|
assert(eqFlags.size() == constraints.size());
|
||||||
|
|
||||||
|
@ -950,10 +948,10 @@ IntegerSet *IntegerSet::get(unsigned dimCount, unsigned symbolCount,
|
||||||
auto *res = impl.allocator.Allocate<IntegerSet>();
|
auto *res = impl.allocator.Allocate<IntegerSet>();
|
||||||
|
|
||||||
// Copy the equalities and inequalities into the bump pointer.
|
// Copy the equalities and inequalities into the bump pointer.
|
||||||
constraints = impl.copyInto(ArrayRef<AffineExpr *>(constraints));
|
constraints = impl.copyInto(ArrayRef<AffineExprRef>(constraints));
|
||||||
eqFlags = impl.copyInto(ArrayRef<bool>(eqFlags));
|
eqFlags = impl.copyInto(ArrayRef<bool>(eqFlags));
|
||||||
|
|
||||||
// Initialize the memory using placement new.
|
// Initialize the memory using placement new.
|
||||||
return new (res) IntegerSet(dimCount, symbolCount, constraints.size(),
|
return new (res) IntegerSet(dimCount, symbolCount, constraints.size(),
|
||||||
constraints.data(), eqFlags.data());
|
constraints, eqFlags);
|
||||||
}
|
}
|
||||||
|
|
|
@ -193,7 +193,7 @@ public:
|
||||||
MLIRContext *context)
|
MLIRContext *context)
|
||||||
: numDims(numDims), operandConsts(operandConsts), context(context) {}
|
: numDims(numDims), operandConsts(operandConsts), context(context) {}
|
||||||
|
|
||||||
IntegerAttr *constantFold(AffineExpr *expr) {
|
IntegerAttr *constantFold(AffineExprRef expr) {
|
||||||
switch (expr->getKind()) {
|
switch (expr->getKind()) {
|
||||||
case AffineExpr::Kind::Add:
|
case AffineExpr::Kind::Add:
|
||||||
return constantFoldBinExpr(
|
return constantFoldBinExpr(
|
||||||
|
@ -224,7 +224,7 @@ public:
|
||||||
|
|
||||||
private:
|
private:
|
||||||
IntegerAttr *
|
IntegerAttr *
|
||||||
constantFoldBinExpr(AffineExpr *expr,
|
constantFoldBinExpr(AffineExprRef expr,
|
||||||
std::function<uint64_t(int64_t, uint64_t)> op) {
|
std::function<uint64_t(int64_t, uint64_t)> op) {
|
||||||
auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||||
auto *lhs = constantFold(binOpExpr->getLHS());
|
auto *lhs = constantFold(binOpExpr->getLHS());
|
||||||
|
@ -254,7 +254,7 @@ bool AffineApplyOp::constantFold(ArrayRef<Attribute *> operands,
|
||||||
AffineExprConstantFolder exprFolder(map->getNumDims(), operands, context);
|
AffineExprConstantFolder exprFolder(map->getNumDims(), operands, context);
|
||||||
|
|
||||||
// Constant fold each AffineExpr in AffineMap and add to 'results'.
|
// Constant fold each AffineExpr in AffineMap and add to 'results'.
|
||||||
for (auto *expr : map->getResults()) {
|
for (auto expr : map->getResults()) {
|
||||||
results.push_back(exprFolder.constantFold(expr));
|
results.push_back(exprFolder.constantFold(expr));
|
||||||
}
|
}
|
||||||
// Return false on success.
|
// Return false on success.
|
||||||
|
|
|
@ -1238,7 +1238,7 @@ AffineMap *AffineParser::parseAffineMapInline() {
|
||||||
parseToken(Token::l_paren, "expected '(' at start of affine map range"))
|
parseToken(Token::l_paren, "expected '(' at start of affine map range"))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
SmallVector<AffineExpr *, 4> exprs;
|
SmallVector<AffineExprRef, 4> exprs;
|
||||||
auto parseElt = [&]() -> ParseResult {
|
auto parseElt = [&]() -> ParseResult {
|
||||||
auto *elt = parseAffineExpr();
|
auto *elt = parseAffineExpr();
|
||||||
ParseResult res = elt ? ParseSuccess : ParseFailure;
|
ParseResult res = elt ? ParseSuccess : ParseFailure;
|
||||||
|
@ -1257,7 +1257,7 @@ AffineMap *AffineParser::parseAffineMapInline() {
|
||||||
// dim-size ::= affine-expr | `min` `(` affine-expr (`,` affine-expr)+ `)`
|
// dim-size ::= affine-expr | `min` `(` affine-expr (`,` affine-expr)+ `)`
|
||||||
// TODO(bondhugula): support for min of several affine expressions.
|
// TODO(bondhugula): support for min of several affine expressions.
|
||||||
// TODO: check if sizes are non-negative whenever they are constant.
|
// TODO: check if sizes are non-negative whenever they are constant.
|
||||||
SmallVector<AffineExpr *, 4> rangeSizes;
|
SmallVector<AffineExprRef, 4> rangeSizes;
|
||||||
if (consumeIf(Token::kw_size)) {
|
if (consumeIf(Token::kw_size)) {
|
||||||
// Location of the l_paren token (if it exists) for error reporting later.
|
// Location of the l_paren token (if it exists) for error reporting later.
|
||||||
auto loc = getToken().getLoc();
|
auto loc = getToken().getLoc();
|
||||||
|
@ -2500,7 +2500,7 @@ IntegerSet *AffineParser::parseIntegerSetInline() {
|
||||||
"expected '(' at start of integer set constraint list"))
|
"expected '(' at start of integer set constraint list"))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
SmallVector<AffineExpr *, 4> constraints;
|
SmallVector<AffineExprRef, 4> constraints;
|
||||||
SmallVector<bool, 4> isEqs;
|
SmallVector<bool, 4> isEqs;
|
||||||
auto parseElt = [&]() -> ParseResult {
|
auto parseElt = [&]() -> ParseResult {
|
||||||
bool isEq;
|
bool isEq;
|
||||||
|
|
|
@ -52,9 +52,7 @@ FunctionPass *mlir::createSimplifyAffineExprPass() {
|
||||||
}
|
}
|
||||||
|
|
||||||
AffineMap *MutableAffineMap::getAffineMap() {
|
AffineMap *MutableAffineMap::getAffineMap() {
|
||||||
SmallVector<AffineExpr *, 8> res(results.begin(), results.end());
|
return AffineMap::get(numDims, numSymbols, results, rangeSizes, context);
|
||||||
SmallVector<AffineExpr *, 8> sizes(rangeSizes.begin(), rangeSizes.end());
|
|
||||||
return AffineMap::get(numDims, numSymbols, res, sizes, context);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
PassResult SimplifyAffineExpr::runOnMLFunction(MLFunction *f) {
|
PassResult SimplifyAffineExpr::runOnMLFunction(MLFunction *f) {
|
||||||
|
|
Loading…
Reference in New Issue