forked from OSchip/llvm-project
[MLIR] Value types for AffineXXXExpr
This CL makes AffineExprRef into a value type. Notably: 1. drops llvm isa, cast, dyn_cast on pointer type and uses member functions on the value type. It may be possible to still use classof (in a followup CL) 2. AffineBaseExprRef aggressively casts constness away: if we mean the type is immutable then let's jump in with both feet; 3. Drop implicit casts to the underlying pointer type because that always results in surprising behavior and is not needed in practice once enough cleanup has been applied. The remaining negative I see is that we still need to mix operator. and operator->. There is an ugly solution that forwards the methods but that ends up duplicating the class hierarchy which I tried to avoid as much as possible. But maybe it's not that bad anymore since AffineExpr.h would still contain a single class hierarchy (the duplication would be impl detail in.cpp) PiperOrigin-RevId: 216188003
This commit is contained in:
parent
d2d89cbc19
commit
4911978f7e
|
@ -26,6 +26,121 @@
|
|||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/DenseMapInfo.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include <type_traits>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class AffineExpr;
|
||||
class AffineBinaryOpExpr;
|
||||
class AffineDimExpr;
|
||||
class AffineSymbolExpr;
|
||||
class AffineConstantExpr;
|
||||
|
||||
/// Helper structure to build AffineExpr with intuitive operators in order to
|
||||
/// operate on chainable, lightweight value types instead of pointer types.
|
||||
/// This structure operates on immutable types so it freely casts constness
|
||||
/// away.
|
||||
/// TODO(ntv): Remove all redundant MLIRContext* arguments through the API
|
||||
/// TODO(ntv): Remove all uses of AffineExpr* in Parser.cpp
|
||||
/// TODO(ntv): Add extra out-of-class operators for int op AffineExprBaseRef
|
||||
/// TODO(ntv): Rename
|
||||
/// TODO(ntv): Drop const everywhere it makes sense in AffineExpr
|
||||
/// TODO(ntv): remove const comment
|
||||
/// TODO(ntv): pointer pair
|
||||
template <typename AffineExprType> class AffineExprBaseRef {
|
||||
public:
|
||||
typedef AffineExprBaseRef TemplateType;
|
||||
typedef AffineExprType ImplType;
|
||||
|
||||
AffineExprBaseRef() : expr(nullptr) {}
|
||||
/* implicit */ AffineExprBaseRef(const AffineExprType *expr)
|
||||
: expr(const_cast<AffineExprType *>(expr)) {}
|
||||
|
||||
AffineExprBaseRef(const AffineExprBaseRef &other) : expr(other.expr) {}
|
||||
AffineExprBaseRef &operator=(AffineExprBaseRef other) {
|
||||
expr = other.expr;
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool operator==(AffineExprBaseRef other) const { return expr == other.expr; }
|
||||
|
||||
AffineExprType *operator->() const { return expr; }
|
||||
|
||||
/* implicit */ operator AffineExprBaseRef<AffineExpr>() const {
|
||||
return const_cast<AffineExpr *>(static_cast<const AffineExpr *>(expr));
|
||||
}
|
||||
explicit operator bool() const { return expr; }
|
||||
|
||||
bool empty() const { return expr == nullptr; }
|
||||
bool operator!() const { return expr == nullptr; }
|
||||
|
||||
template <typename U> bool isa() const {
|
||||
using PtrType = typename U::ImplType;
|
||||
return llvm::isa<PtrType>(const_cast<AffineExprType *>(this->expr));
|
||||
}
|
||||
template <typename U> U dyn_cast() const {
|
||||
using PtrType = typename U::ImplType;
|
||||
return U(llvm::dyn_cast<PtrType>(const_cast<AffineExprType *>(this->expr)));
|
||||
}
|
||||
template <typename U> U cast() const {
|
||||
using PtrType = typename U::ImplType;
|
||||
return U(llvm::cast<PtrType>(const_cast<AffineExprType *>(this->expr)));
|
||||
}
|
||||
|
||||
AffineExprBaseRef operator+(int64_t v) const;
|
||||
AffineExprBaseRef operator+(AffineExprBaseRef other) const;
|
||||
AffineExprBaseRef operator-() const;
|
||||
AffineExprBaseRef operator-(int64_t v) const;
|
||||
AffineExprBaseRef operator-(AffineExprBaseRef other) const;
|
||||
AffineExprBaseRef operator*(int64_t v) const;
|
||||
AffineExprBaseRef operator*(AffineExprBaseRef other) const;
|
||||
AffineExprBaseRef floorDiv(uint64_t v) const;
|
||||
AffineExprBaseRef floorDiv(AffineExprBaseRef other) const;
|
||||
AffineExprBaseRef ceilDiv(uint64_t v) const;
|
||||
AffineExprBaseRef ceilDiv(AffineExprBaseRef other) const;
|
||||
AffineExprBaseRef operator%(uint64_t v) const;
|
||||
AffineExprBaseRef operator%(AffineExprBaseRef other) const;
|
||||
|
||||
friend ::llvm::hash_code hash_value(AffineExprBaseRef arg);
|
||||
|
||||
private:
|
||||
AffineExprType *expr;
|
||||
};
|
||||
|
||||
using AffineExprRef = AffineExprBaseRef<AffineExpr>;
|
||||
using AffineBinaryOpExprRef = AffineExprBaseRef<AffineBinaryOpExpr>;
|
||||
using AffineDimExprRef = AffineExprBaseRef<AffineDimExpr>;
|
||||
using AffineSymbolExprRef = AffineExprBaseRef<AffineSymbolExpr>;
|
||||
using AffineConstantExprRef = AffineExprBaseRef<AffineConstantExpr>;
|
||||
|
||||
// Make AffineExprRef hashable.
|
||||
inline ::llvm::hash_code hash_value(AffineExprRef arg) {
|
||||
return ::llvm::hash_value(static_cast<AffineExpr *>(arg.expr));
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
namespace llvm {
|
||||
|
||||
// AffineExprRef hash just like pointers
|
||||
template <> struct DenseMapInfo<mlir::AffineExprRef> {
|
||||
static mlir::AffineExprRef getEmptyKey() {
|
||||
auto pointer = llvm::DenseMapInfo<mlir::AffineExpr *>::getEmptyKey();
|
||||
return mlir::AffineExprRef(pointer);
|
||||
}
|
||||
static mlir::AffineExprRef getTombstoneKey() {
|
||||
auto pointer = llvm::DenseMapInfo<mlir::AffineExpr *>::getTombstoneKey();
|
||||
return mlir::AffineExprRef(pointer);
|
||||
}
|
||||
static unsigned getHashValue(mlir::AffineExprRef val) {
|
||||
return mlir::hash_value(val);
|
||||
}
|
||||
static bool isEqual(mlir::AffineExprRef LHS, mlir::AffineExprRef RHS) {
|
||||
return LHS == RHS;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace llvm
|
||||
|
||||
namespace mlir {
|
||||
|
||||
|
@ -99,93 +214,6 @@ inline raw_ostream &operator<<(raw_ostream &os, AffineExpr &expr) {
|
|||
return os;
|
||||
}
|
||||
|
||||
/// Helper structure to build AffineExpr with intuitive operators in order to
|
||||
/// operate on chainable, lightweight value types instead of pointer types.
|
||||
/// This structure operates on immutable types so it freely casts constness
|
||||
/// away.
|
||||
/// TODO(ntv): Remove all redundant MLIRContext* arguments through the API
|
||||
/// TODO(ntv): Remove all uses of AffineExpr* in Parser.cpp
|
||||
/// TODO(ntv): Add extra out-of-class operators for int op AffineExprBaseRef
|
||||
/// TODO(ntv): Rename
|
||||
/// TODO(ntv): Drop const everywhere it makes sense in AffineExpr
|
||||
/// TODO(ntv): remove const comment
|
||||
/// TODO(ntv): pointer pair
|
||||
template <typename AffineExprType> class AffineExprBaseRef {
|
||||
public:
|
||||
/* implicit */ AffineExprBaseRef(AffineExprType *expr) : expr(expr) {}
|
||||
|
||||
AffineExprBaseRef(const AffineExprBaseRef &other) : expr(other.expr){};
|
||||
AffineExprBaseRef &operator=(AffineExprBaseRef other) {
|
||||
expr = other;
|
||||
return *this;
|
||||
};
|
||||
bool operator==(AffineExprBaseRef other) const { return expr == other.expr; };
|
||||
AffineExprType *operator->() { return expr; }
|
||||
/* implicit */ operator AffineExprType *() { return expr; }
|
||||
|
||||
bool operator!() { return expr == nullptr; }
|
||||
|
||||
AffineExprBaseRef operator+(int64_t v) const;
|
||||
AffineExprBaseRef operator+(AffineExprBaseRef other) const;
|
||||
AffineExprBaseRef operator-() const;
|
||||
AffineExprBaseRef operator-(int64_t v) const;
|
||||
AffineExprBaseRef operator-(AffineExprBaseRef other) const;
|
||||
AffineExprBaseRef operator*(int64_t v) const;
|
||||
AffineExprBaseRef operator*(AffineExprBaseRef other) const;
|
||||
AffineExprBaseRef floorDiv(uint64_t v) const;
|
||||
AffineExprBaseRef floorDiv(AffineExprBaseRef other) const;
|
||||
AffineExprBaseRef ceilDiv(uint64_t v) const;
|
||||
AffineExprBaseRef ceilDiv(AffineExprBaseRef other) const;
|
||||
AffineExprBaseRef operator%(uint64_t v) const;
|
||||
AffineExprBaseRef operator%(AffineExprBaseRef other) const;
|
||||
|
||||
private:
|
||||
AffineExprType *expr;
|
||||
};
|
||||
|
||||
using AffineExprRef = AffineExprBaseRef<AffineExpr>;
|
||||
|
||||
inline ::llvm::hash_code hash_value(AffineExprRef arg);
|
||||
} // namespace mlir
|
||||
|
||||
namespace llvm {
|
||||
|
||||
/// This helper structure allows classof/isa/cast/dyn_cast to operate on
|
||||
/// AffineExprBaseRef<T>.
|
||||
template <typename T> struct simplify_type<mlir::AffineExprBaseRef<T>> {
|
||||
using SimpleType = T *;
|
||||
static SimpleType getSimplifiedValue(mlir::AffineExprBaseRef<T> &input) {
|
||||
return input;
|
||||
}
|
||||
};
|
||||
|
||||
// AffineExprRef hash just like pointers
|
||||
template <> struct DenseMapInfo<mlir::AffineExprRef> {
|
||||
static mlir::AffineExprRef getEmptyKey() {
|
||||
auto pointer = llvm::DenseMapInfo<mlir::AffineExpr *>::getEmptyKey();
|
||||
return mlir::AffineExprRef(pointer);
|
||||
}
|
||||
static mlir::AffineExprRef getTombstoneKey() {
|
||||
auto pointer = llvm::DenseMapInfo<mlir::AffineExpr *>::getTombstoneKey();
|
||||
return mlir::AffineExprRef(pointer);
|
||||
}
|
||||
static unsigned getHashValue(mlir::AffineExprRef val) {
|
||||
return mlir::hash_value(val);
|
||||
}
|
||||
static bool isEqual(mlir::AffineExprRef LHS, mlir::AffineExprRef RHS) {
|
||||
return LHS == RHS;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace llvm
|
||||
|
||||
namespace mlir {
|
||||
|
||||
// Make AffineExprRef hashable.
|
||||
inline ::llvm::hash_code hash_value(AffineExprRef arg) {
|
||||
return ::llvm::hash_value(static_cast<AffineExpr *>(arg));
|
||||
}
|
||||
|
||||
/// Affine binary operation expression. An affine binary operation could be an
|
||||
/// add, mul, floordiv, ceildiv, or a modulo operation. (Subtraction is
|
||||
/// represented through a multiply by -1 and add.) These expressions are always
|
||||
|
|
|
@ -46,7 +46,7 @@ namespace mlir {
|
|||
/// struct DimExprCounter : public AffineExprVisitor<DimExprCounter> {
|
||||
/// unsigned numDimExprs;
|
||||
/// DimExprCounter() : numDimExprs(0) {}
|
||||
/// void visitAffineDimExpr(AffineDimExpr *expr) { ++numDimExprs; }
|
||||
/// void visitAffineDimExpr(AffineDimExprRef expr) { ++numDimExprs; }
|
||||
/// };
|
||||
///
|
||||
/// And this class would be used like this:
|
||||
|
@ -83,39 +83,39 @@ public:
|
|||
"Must instantiate with a derived type of AffineExprVisitor");
|
||||
switch (expr->getKind()) {
|
||||
case AffineExpr::Kind::Add: {
|
||||
auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
|
||||
walkOperandsPostOrder(binOpExpr);
|
||||
return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
|
||||
}
|
||||
case AffineExpr::Kind::Mul: {
|
||||
auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
|
||||
walkOperandsPostOrder(binOpExpr);
|
||||
return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
|
||||
}
|
||||
case AffineExpr::Kind::Mod: {
|
||||
auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
|
||||
walkOperandsPostOrder(binOpExpr);
|
||||
return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
|
||||
}
|
||||
case AffineExpr::Kind::FloorDiv: {
|
||||
auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
|
||||
walkOperandsPostOrder(binOpExpr);
|
||||
return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
|
||||
}
|
||||
case AffineExpr::Kind::CeilDiv: {
|
||||
auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
|
||||
walkOperandsPostOrder(binOpExpr);
|
||||
return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
|
||||
}
|
||||
case AffineExpr::Kind::Constant:
|
||||
return static_cast<SubClass *>(this)->visitConstantExpr(
|
||||
cast<AffineConstantExpr>(expr));
|
||||
expr.cast<AffineConstantExprRef>());
|
||||
case AffineExpr::Kind::DimId:
|
||||
return static_cast<SubClass *>(this)->visitDimExpr(
|
||||
cast<AffineDimExpr>(expr));
|
||||
expr.cast<AffineDimExprRef>());
|
||||
case AffineExpr::Kind::SymbolId:
|
||||
return static_cast<SubClass *>(this)->visitSymbolExpr(
|
||||
cast<AffineSymbolExpr>(expr));
|
||||
expr.cast<AffineSymbolExprRef>());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -125,34 +125,34 @@ public:
|
|||
"Must instantiate with a derived type of AffineExprVisitor");
|
||||
switch (expr->getKind()) {
|
||||
case AffineExpr::Kind::Add: {
|
||||
auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
|
||||
return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
|
||||
}
|
||||
case AffineExpr::Kind::Mul: {
|
||||
auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
|
||||
return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
|
||||
}
|
||||
case AffineExpr::Kind::Mod: {
|
||||
auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
|
||||
return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
|
||||
}
|
||||
case AffineExpr::Kind::FloorDiv: {
|
||||
auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
|
||||
return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
|
||||
}
|
||||
case AffineExpr::Kind::CeilDiv: {
|
||||
auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
|
||||
return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
|
||||
}
|
||||
case AffineExpr::Kind::Constant:
|
||||
return static_cast<SubClass *>(this)->visitConstantExpr(
|
||||
cast<AffineConstantExpr>(expr));
|
||||
expr.cast<AffineConstantExprRef>());
|
||||
case AffineExpr::Kind::DimId:
|
||||
return static_cast<SubClass *>(this)->visitDimExpr(
|
||||
cast<AffineDimExpr>(expr));
|
||||
expr.cast<AffineDimExprRef>());
|
||||
case AffineExpr::Kind::SymbolId:
|
||||
return static_cast<SubClass *>(this)->visitSymbolExpr(
|
||||
cast<AffineSymbolExpr>(expr));
|
||||
expr.cast<AffineSymbolExprRef>());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -166,29 +166,29 @@ public:
|
|||
|
||||
// Default visit methods. Note that the default op-specific binary op visit
|
||||
// methods call the general visitAffineBinaryOpExpr visit method.
|
||||
void visitAffineBinaryOpExpr(AffineBinaryOpExpr *expr) {}
|
||||
void visitAddExpr(AffineBinaryOpExpr *expr) {
|
||||
void visitAffineBinaryOpExpr(AffineBinaryOpExprRef expr) {}
|
||||
void visitAddExpr(AffineBinaryOpExprRef expr) {
|
||||
static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
|
||||
}
|
||||
void visitMulExpr(AffineBinaryOpExpr *expr) {
|
||||
void visitMulExpr(AffineBinaryOpExprRef expr) {
|
||||
static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
|
||||
}
|
||||
void visitModExpr(AffineBinaryOpExpr *expr) {
|
||||
void visitModExpr(AffineBinaryOpExprRef expr) {
|
||||
static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
|
||||
}
|
||||
void visitFloorDivExpr(AffineBinaryOpExpr *expr) {
|
||||
void visitFloorDivExpr(AffineBinaryOpExprRef expr) {
|
||||
static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
|
||||
}
|
||||
void visitCeilDivExpr(AffineBinaryOpExpr *expr) {
|
||||
void visitCeilDivExpr(AffineBinaryOpExprRef expr) {
|
||||
static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
|
||||
}
|
||||
void visitConstantExpr(AffineConstantExpr *expr) {}
|
||||
void visitAffineDimExpr(AffineDimExpr *expr) {}
|
||||
void visitAffineSymbolExpr(AffineSymbolExpr *expr) {}
|
||||
void visitConstantExpr(AffineConstantExprRef expr) {}
|
||||
void visitAffineDimExpr(AffineDimExprRef expr) {}
|
||||
void visitAffineSymbolExpr(AffineSymbolExprRef expr) {}
|
||||
|
||||
private:
|
||||
// Walk the operands - each operand is itself walked in post order.
|
||||
void walkOperandsPostOrder(AffineBinaryOpExpr *expr) {
|
||||
void walkOperandsPostOrder(AffineBinaryOpExprRef expr) {
|
||||
walkPostOrder(expr->getLHS());
|
||||
walkPostOrder(expr->getRHS());
|
||||
}
|
||||
|
|
|
@ -139,10 +139,10 @@ public:
|
|||
operandExprStack.reserve(8);
|
||||
}
|
||||
|
||||
void visitMulExpr(AffineBinaryOpExpr *expr) {
|
||||
void visitMulExpr(AffineBinaryOpExprRef expr) {
|
||||
assert(operandExprStack.size() >= 2);
|
||||
// This is a pure affine expr; the RHS will be a constant.
|
||||
assert(isa<AffineConstantExpr>(expr->getRHS()));
|
||||
assert(expr->getRHS().isa<AffineConstantExprRef>());
|
||||
// Get the RHS constant.
|
||||
auto rhsConst = operandExprStack.back()[getConstantIndex()];
|
||||
operandExprStack.pop_back();
|
||||
|
@ -153,7 +153,7 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
void visitAddExpr(AffineBinaryOpExpr *expr) {
|
||||
void visitAddExpr(AffineBinaryOpExprRef expr) {
|
||||
assert(operandExprStack.size() >= 2);
|
||||
const auto &rhs = operandExprStack.back();
|
||||
auto &lhs = operandExprStack[operandExprStack.size() - 2];
|
||||
|
@ -166,10 +166,10 @@ public:
|
|||
operandExprStack.pop_back();
|
||||
}
|
||||
|
||||
void visitModExpr(AffineBinaryOpExpr *expr) {
|
||||
void visitModExpr(AffineBinaryOpExprRef expr) {
|
||||
assert(operandExprStack.size() >= 2);
|
||||
// This is a pure affine expr; the RHS will be a constant.
|
||||
assert(isa<AffineConstantExpr>(expr->getRHS()));
|
||||
assert(expr->getRHS().isa<AffineConstantExprRef>());
|
||||
auto rhsConst = operandExprStack.back()[getConstantIndex()];
|
||||
operandExprStack.pop_back();
|
||||
auto &lhs = operandExprStack.back();
|
||||
|
@ -195,32 +195,32 @@ public:
|
|||
AffineConstantExpr::get(rhsConst, context), context));
|
||||
lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
|
||||
}
|
||||
void visitCeilDivExpr(AffineBinaryOpExpr *expr) {
|
||||
void visitCeilDivExpr(AffineBinaryOpExprRef expr) {
|
||||
visitDivExpr(expr, /*isCeil=*/true);
|
||||
}
|
||||
void visitFloorDivExpr(AffineBinaryOpExpr *expr) {
|
||||
void visitFloorDivExpr(AffineBinaryOpExprRef expr) {
|
||||
visitDivExpr(expr, /*isCeil=*/false);
|
||||
}
|
||||
void visitDimExpr(AffineDimExpr *expr) {
|
||||
void visitDimExpr(AffineDimExprRef expr) {
|
||||
operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
|
||||
auto &eq = operandExprStack.back();
|
||||
eq[getDimStartIndex() + expr->getPosition()] = 1;
|
||||
}
|
||||
void visitSymbolExpr(AffineSymbolExpr *expr) {
|
||||
void visitSymbolExpr(AffineSymbolExprRef expr) {
|
||||
operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
|
||||
auto &eq = operandExprStack.back();
|
||||
eq[getSymbolStartIndex() + expr->getPosition()] = 1;
|
||||
}
|
||||
void visitConstantExpr(AffineConstantExpr *expr) {
|
||||
void visitConstantExpr(AffineConstantExprRef expr) {
|
||||
operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
|
||||
auto &eq = operandExprStack.back();
|
||||
eq[getConstantIndex()] = expr->getValue();
|
||||
}
|
||||
|
||||
private:
|
||||
void visitDivExpr(AffineBinaryOpExpr *expr, bool isCeil) {
|
||||
void visitDivExpr(AffineBinaryOpExprRef expr, bool isCeil) {
|
||||
assert(operandExprStack.size() >= 2);
|
||||
assert(isa<AffineConstantExpr>(expr->getRHS()));
|
||||
assert(expr->getRHS().isa<AffineConstantExprRef>());
|
||||
// This is a pure affine expr; the RHS is a positive constant.
|
||||
auto rhsConst = operandExprStack.back()[getConstantIndex()];
|
||||
// TODO(bondhugula): handle division by zero at the same time the issue is
|
||||
|
|
|
@ -38,8 +38,7 @@ getReducedConstBound(const HyperRectangularSet &set, unsigned *idx,
|
|||
unsigned j = 0;
|
||||
AffineBoundExprList::const_iterator it, e;
|
||||
for (it = ubs.begin(), e = ubs.end(); it != e; it++, j++) {
|
||||
if (auto *cExpr = const_cast<AffineConstantExpr *>(
|
||||
dyn_cast<AffineConstantExpr>(*it))) {
|
||||
if (auto cExpr = it->dyn_cast<AffineConstantExprRef>()) {
|
||||
if (val == None) {
|
||||
val = cExpr->getValue();
|
||||
*idx = j;
|
||||
|
@ -69,7 +68,7 @@ static void mergeBounds(const HyperRectangularSet &set,
|
|||
}
|
||||
if (it == lhsList.end()) {
|
||||
// There can only be one constant affine expr in this bound list.
|
||||
if (auto cExpr = dyn_cast<AffineConstantExpr>(expr)) {
|
||||
if (auto cExpr = expr.dyn_cast<AffineConstantExprRef>()) {
|
||||
unsigned idx;
|
||||
if (lb) {
|
||||
auto cb = getReducedConstBound(
|
||||
|
|
|
@ -61,7 +61,7 @@ AffineExprRef mlir::getTripCountExpr(const ForStmt &forStmt) {
|
|||
auto loopSpanExpr = simplifyAffineExpr(
|
||||
ubExpr - lbExpr + 1, std::max(lbMap->getNumDims(), ubMap->getNumDims()),
|
||||
std::max(lbMap->getNumSymbols(), ubMap->getNumSymbols()));
|
||||
auto *cExpr = dyn_cast<AffineConstantExpr>(loopSpanExpr);
|
||||
auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExprRef>();
|
||||
if (!cExpr)
|
||||
return AffineBinaryOpExpr::getCeilDiv(loopSpanExpr, step, context);
|
||||
loopSpan = cExpr->getValue();
|
||||
|
@ -81,7 +81,10 @@ AffineExprRef mlir::getTripCountExpr(const ForStmt &forStmt) {
|
|||
llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForStmt &forStmt) {
|
||||
auto tripCountExpr = getTripCountExpr(forStmt);
|
||||
|
||||
if (auto constExpr = dyn_cast_or_null<AffineConstantExpr>(tripCountExpr))
|
||||
if (!tripCountExpr)
|
||||
return None;
|
||||
|
||||
if (auto constExpr = tripCountExpr.dyn_cast<AffineConstantExprRef>())
|
||||
return constExpr->getValue();
|
||||
|
||||
return None;
|
||||
|
@ -96,7 +99,7 @@ uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) {
|
|||
if (!tripCountExpr)
|
||||
return 1;
|
||||
|
||||
if (auto constExpr = dyn_cast<AffineConstantExpr>(tripCountExpr)) {
|
||||
if (auto constExpr = tripCountExpr.dyn_cast<AffineConstantExprRef>()) {
|
||||
uint64_t tripCount = constExpr->getValue();
|
||||
|
||||
// 0 iteration loops (greatest divisor is 2^64 - 1).
|
||||
|
|
|
@ -27,10 +27,10 @@ AffineBinaryOpExpr::AffineBinaryOpExpr(Kind kind, AffineExprRef lhs,
|
|||
// We verify affine op expr forms at construction time.
|
||||
switch (kind) {
|
||||
case Kind::Add:
|
||||
assert(!isa<AffineConstantExpr>(lhs));
|
||||
assert(!lhs.isa<AffineConstantExprRef>());
|
||||
break;
|
||||
case Kind::Mul:
|
||||
assert(!isa<AffineConstantExpr>(lhs));
|
||||
assert(!lhs.isa<AffineConstantExprRef>());
|
||||
assert(rhs->isSymbolicOrConstant());
|
||||
break;
|
||||
case Kind::FloorDiv:
|
||||
|
@ -124,15 +124,15 @@ bool AffineExpr::isPureAffine() {
|
|||
// possible, allowing this to merge into the next case.
|
||||
auto *op = cast<AffineBinaryOpExpr>(this);
|
||||
return op->getLHS()->isPureAffine() && op->getRHS()->isPureAffine() &&
|
||||
(isa<AffineConstantExpr>(op->getLHS()) ||
|
||||
isa<AffineConstantExpr>(op->getRHS()));
|
||||
(op->getLHS().isa<AffineConstantExprRef>() ||
|
||||
op->getRHS().isa<AffineConstantExprRef>());
|
||||
}
|
||||
case Kind::FloorDiv:
|
||||
case Kind::CeilDiv:
|
||||
case Kind::Mod: {
|
||||
auto *op = cast<AffineBinaryOpExpr>(this);
|
||||
return op->getLHS()->isPureAffine() &&
|
||||
isa<AffineConstantExpr>(op->getRHS());
|
||||
op->getRHS().isa<AffineConstantExprRef>();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -214,7 +214,7 @@ template <> AffineExprRef AffineExprRef::operator*(AffineExprRef other) const {
|
|||
}
|
||||
// Unary minus, delegate to operator*.
|
||||
template <> AffineExprRef AffineExprRef::operator-() const {
|
||||
return *this * (-1);
|
||||
return AffineBinaryOpExpr::getMul(expr, -1, expr->getContext());
|
||||
}
|
||||
// Delegate to operator+.
|
||||
template <> AffineExprRef AffineExprRef::operator-(int64_t v) const {
|
||||
|
|
|
@ -55,14 +55,15 @@ public:
|
|||
return constantFoldBinExpr(
|
||||
expr, [](int64_t lhs, uint64_t rhs) { return ceilDiv(lhs, rhs); });
|
||||
case AffineExpr::Kind::Constant:
|
||||
return IntegerAttr::get(cast<AffineConstantExpr>(expr)->getValue(),
|
||||
return IntegerAttr::get(expr.cast<AffineConstantExprRef>()->getValue(),
|
||||
expr->getContext());
|
||||
case AffineExpr::Kind::DimId:
|
||||
return dyn_cast_or_null<IntegerAttr>(
|
||||
operandConsts[cast<AffineDimExpr>(expr)->getPosition()]);
|
||||
operandConsts[expr.cast<AffineDimExprRef>()->getPosition()]);
|
||||
case AffineExpr::Kind::SymbolId:
|
||||
return dyn_cast_or_null<IntegerAttr>(
|
||||
operandConsts[numDims + cast<AffineSymbolExpr>(expr)->getPosition()]);
|
||||
operandConsts[numDims +
|
||||
expr.cast<AffineSymbolExprRef>()->getPosition()]);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -70,7 +71,7 @@ private:
|
|||
IntegerAttr *
|
||||
constantFoldBinExpr(AffineExprRef expr,
|
||||
std::function<uint64_t(int64_t, uint64_t)> op) {
|
||||
auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
|
||||
auto *lhs = constantFold(binOpExpr->getLHS());
|
||||
auto *rhs = constantFold(binOpExpr->getRHS());
|
||||
if (!lhs || !rhs)
|
||||
|
@ -104,8 +105,7 @@ bool AffineMap::isIdentity() {
|
|||
return false;
|
||||
ArrayRef<AffineExprRef> results = getResults();
|
||||
for (unsigned i = 0, numDims = getNumDims(); i < numDims; ++i) {
|
||||
auto *expr =
|
||||
const_cast<AffineDimExpr *>(dyn_cast<AffineDimExpr>(results[i]));
|
||||
auto expr = results[i].dyn_cast<AffineDimExprRef>();
|
||||
if (!expr || expr->getPosition() != i)
|
||||
return false;
|
||||
}
|
||||
|
@ -113,14 +113,12 @@ bool AffineMap::isIdentity() {
|
|||
}
|
||||
|
||||
bool AffineMap::isSingleConstant() {
|
||||
return getNumResults() == 1 && isa<AffineConstantExpr>(getResult(0));
|
||||
return getNumResults() == 1 && getResult(0).isa<AffineConstantExprRef>();
|
||||
}
|
||||
|
||||
int64_t AffineMap::getSingleConstantResult() {
|
||||
assert(isSingleConstant() && "map must have a single constant result");
|
||||
return const_cast<AffineConstantExpr *>(
|
||||
cast<AffineConstantExpr>(getResult(0)))
|
||||
->getValue();
|
||||
return getResult(0).cast<AffineConstantExprRef>()->getValue();
|
||||
}
|
||||
|
||||
AffineExprRef AffineMap::getResult(unsigned idx) { return results[idx]; }
|
||||
|
@ -129,8 +127,8 @@ AffineExprRef AffineMap::getResult(unsigned idx) { return results[idx]; }
|
|||
AffineExprRef AffineBinaryOpExpr::simplifyAdd(AffineExprRef lhs,
|
||||
AffineExprRef rhs,
|
||||
MLIRContext *context) {
|
||||
auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
|
||||
auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
|
||||
auto lhsConst = lhs.dyn_cast<AffineConstantExprRef>();
|
||||
auto rhsConst = rhs.dyn_cast<AffineConstantExprRef>();
|
||||
|
||||
// Fold if both LHS, RHS are a constant.
|
||||
if (lhsConst && rhsConst)
|
||||
|
@ -139,7 +137,7 @@ AffineExprRef AffineBinaryOpExpr::simplifyAdd(AffineExprRef lhs,
|
|||
|
||||
// Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
|
||||
// If only one of them is a symbolic expressions, make it the RHS.
|
||||
if (isa<AffineConstantExpr>(lhs) ||
|
||||
if (lhs.isa<AffineConstantExprRef>() ||
|
||||
(lhs->isSymbolicOrConstant() && !rhs->isSymbolicOrConstant())) {
|
||||
return AffineBinaryOpExpr::getAdd(rhs, lhs, context);
|
||||
}
|
||||
|
@ -152,19 +150,16 @@ AffineExprRef AffineBinaryOpExpr::simplifyAdd(AffineExprRef lhs,
|
|||
return lhs;
|
||||
}
|
||||
// Fold successive additions like (d0 + 2) + 3 into d0 + 5.
|
||||
auto *lBin =
|
||||
const_cast<AffineBinaryOpExpr *>(dyn_cast<AffineBinaryOpExpr>(lhs));
|
||||
auto lBin = lhs.dyn_cast<AffineBinaryOpExprRef>();
|
||||
if (lBin && rhsConst && lBin->getKind() == Kind::Add) {
|
||||
if (auto *lrhs = const_cast<AffineConstantExpr *>(
|
||||
dyn_cast<AffineConstantExpr>(lBin->getRHS())))
|
||||
if (auto lrhs = lBin->getRHS().dyn_cast<AffineConstantExprRef>())
|
||||
return lBin->getLHS() + (lrhs->getValue() + rhsConst->getValue());
|
||||
}
|
||||
|
||||
// When doing successive additions, bring constant to the right: turn (d0 + 2)
|
||||
// + d1 into (d0 + d1) + 2.
|
||||
if (lBin && lBin->getKind() == Kind::Add) {
|
||||
if (auto *lrhs = const_cast<AffineConstantExpr *>(
|
||||
dyn_cast<AffineConstantExpr>(lBin->getRHS()))) {
|
||||
if (auto lrhs = lBin->getRHS().dyn_cast<AffineConstantExprRef>()) {
|
||||
return lBin->getLHS() + rhs + lrhs;
|
||||
}
|
||||
}
|
||||
|
@ -176,8 +171,8 @@ AffineExprRef AffineBinaryOpExpr::simplifyAdd(AffineExprRef lhs,
|
|||
AffineExprRef AffineBinaryOpExpr::simplifyMul(AffineExprRef lhs,
|
||||
AffineExprRef rhs,
|
||||
MLIRContext *context) {
|
||||
auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
|
||||
auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
|
||||
auto lhsConst = lhs.dyn_cast<AffineConstantExprRef>();
|
||||
auto rhsConst = rhs.dyn_cast<AffineConstantExprRef>();
|
||||
|
||||
if (lhsConst && rhsConst)
|
||||
return AffineConstantExpr::get(lhsConst->getValue() * rhsConst->getValue(),
|
||||
|
@ -188,7 +183,7 @@ AffineExprRef AffineBinaryOpExpr::simplifyMul(AffineExprRef lhs,
|
|||
// Canonicalize the mul expression so that the constant/symbolic term is the
|
||||
// RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
|
||||
// constant. (Note that a constant is trivially symbolic).
|
||||
if (!rhs->isSymbolicOrConstant() || isa<AffineConstantExpr>(lhs)) {
|
||||
if (!rhs->isSymbolicOrConstant() || lhs.isa<AffineConstantExprRef>()) {
|
||||
// At least one of them has to be symbolic.
|
||||
return AffineBinaryOpExpr::getMul(rhs, lhs, context);
|
||||
}
|
||||
|
@ -205,19 +200,16 @@ AffineExprRef AffineBinaryOpExpr::simplifyMul(AffineExprRef lhs,
|
|||
}
|
||||
|
||||
// Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
|
||||
auto *lBin =
|
||||
const_cast<AffineBinaryOpExpr *>(dyn_cast<AffineBinaryOpExpr>(lhs));
|
||||
auto lBin = lhs.dyn_cast<AffineBinaryOpExprRef>();
|
||||
if (lBin && rhsConst && lBin->getKind() == Kind::Mul) {
|
||||
if (auto *lrhs = const_cast<AffineConstantExpr *>(
|
||||
dyn_cast<AffineConstantExpr>(lBin->getRHS())))
|
||||
if (auto lrhs = lBin->getRHS().dyn_cast<AffineConstantExprRef>())
|
||||
return lBin->getLHS() * (lrhs->getValue() * rhsConst->getValue());
|
||||
}
|
||||
|
||||
// When doing successive multiplication, bring constant to the right: turn (d0
|
||||
// * 2) * d1 into (d0 * d1) * 2.
|
||||
if (lBin && lBin->getKind() == Kind::Mul) {
|
||||
if (auto *lrhs = const_cast<AffineConstantExpr *>(
|
||||
dyn_cast<AffineConstantExpr>(lBin->getRHS()))) {
|
||||
if (auto lrhs = lBin->getRHS().dyn_cast<AffineConstantExprRef>()) {
|
||||
return (lBin->getLHS() * rhs) * lrhs;
|
||||
}
|
||||
}
|
||||
|
@ -228,8 +220,8 @@ AffineExprRef AffineBinaryOpExpr::simplifyMul(AffineExprRef lhs,
|
|||
AffineExprRef AffineBinaryOpExpr::simplifyFloorDiv(AffineExprRef lhs,
|
||||
AffineExprRef rhs,
|
||||
MLIRContext *context) {
|
||||
auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
|
||||
auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
|
||||
auto lhsConst = lhs.dyn_cast<AffineConstantExprRef>();
|
||||
auto rhsConst = rhs.dyn_cast<AffineConstantExprRef>();
|
||||
|
||||
if (lhsConst && rhsConst)
|
||||
return AffineConstantExpr::get(
|
||||
|
@ -241,11 +233,9 @@ AffineExprRef AffineBinaryOpExpr::simplifyFloorDiv(AffineExprRef lhs,
|
|||
if (rhsConst->getValue() == 1)
|
||||
return lhs;
|
||||
|
||||
auto *lBin =
|
||||
const_cast<AffineBinaryOpExpr *>(dyn_cast<AffineBinaryOpExpr>(lhs));
|
||||
auto lBin = lhs.dyn_cast<AffineBinaryOpExprRef>();
|
||||
if (lBin && lBin->getKind() == Kind::Mul) {
|
||||
if (auto *lrhs = const_cast<AffineConstantExpr *>(
|
||||
dyn_cast<AffineConstantExpr>(lBin->getRHS()))) {
|
||||
if (auto lrhs = lBin->getRHS().dyn_cast<AffineConstantExprRef>()) {
|
||||
// rhsConst is known to be positive if a constant.
|
||||
if (lrhs->getValue() % rhsConst->getValue() == 0)
|
||||
return lBin->getLHS() * (lrhs->getValue() / rhsConst->getValue());
|
||||
|
@ -259,8 +249,8 @@ AffineExprRef AffineBinaryOpExpr::simplifyFloorDiv(AffineExprRef lhs,
|
|||
AffineExprRef AffineBinaryOpExpr::simplifyCeilDiv(AffineExprRef lhs,
|
||||
AffineExprRef rhs,
|
||||
MLIRContext *context) {
|
||||
auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
|
||||
auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
|
||||
auto lhsConst = lhs.dyn_cast<AffineConstantExprRef>();
|
||||
auto rhsConst = rhs.dyn_cast<AffineConstantExprRef>();
|
||||
|
||||
if (lhsConst && rhsConst)
|
||||
return AffineConstantExpr::get(
|
||||
|
@ -272,11 +262,9 @@ AffineExprRef AffineBinaryOpExpr::simplifyCeilDiv(AffineExprRef lhs,
|
|||
if (rhsConst->getValue() == 1)
|
||||
return lhs;
|
||||
|
||||
auto *lBin =
|
||||
const_cast<AffineBinaryOpExpr *>(dyn_cast<AffineBinaryOpExpr>(lhs));
|
||||
auto lBin = lhs.dyn_cast<AffineBinaryOpExprRef>();
|
||||
if (lBin && lBin->getKind() == Kind::Mul) {
|
||||
if (auto *lrhs = const_cast<AffineConstantExpr *>(
|
||||
dyn_cast<AffineConstantExpr>(lBin->getRHS()))) {
|
||||
if (auto lrhs = lBin->getRHS().dyn_cast<AffineConstantExprRef>()) {
|
||||
// rhsConst is known to be positive if a constant.
|
||||
if (lrhs->getValue() % rhsConst->getValue() == 0)
|
||||
return lBin->getLHS() * (lrhs->getValue() / rhsConst->getValue());
|
||||
|
@ -290,8 +278,8 @@ AffineExprRef AffineBinaryOpExpr::simplifyCeilDiv(AffineExprRef lhs,
|
|||
AffineExprRef AffineBinaryOpExpr::simplifyMod(AffineExprRef lhs,
|
||||
AffineExprRef rhs,
|
||||
MLIRContext *context) {
|
||||
auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
|
||||
auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
|
||||
auto lhsConst = lhs.dyn_cast<AffineConstantExprRef>();
|
||||
auto rhsConst = rhs.dyn_cast<AffineConstantExprRef>();
|
||||
|
||||
if (lhsConst && rhsConst)
|
||||
return AffineConstantExpr::get(
|
||||
|
|
|
@ -107,8 +107,8 @@ private:
|
|||
// Check if the affine map is single dim id or single symbol identity -
|
||||
// (i)->(i) or ()[s]->(i)
|
||||
return boundMap->getNumInputs() == 1 && boundMap->getNumResults() == 1 &&
|
||||
(isa<AffineDimExpr>(boundMap->getResult(0)) ||
|
||||
isa<AffineSymbolExpr>(boundMap->getResult(0)));
|
||||
(boundMap->getResult(0).isa<AffineDimExprRef>() ||
|
||||
boundMap->getResult(0).isa<AffineSymbolExprRef>());
|
||||
}
|
||||
|
||||
// Visit functions.
|
||||
|
@ -579,13 +579,13 @@ void ModulePrinter::printAffineExprInternal(
|
|||
const char *binopSpelling = nullptr;
|
||||
switch (expr->getKind()) {
|
||||
case AffineExpr::Kind::SymbolId:
|
||||
os << 's' << cast<AffineSymbolExpr>(expr)->getPosition();
|
||||
os << 's' << expr.cast<AffineSymbolExprRef>()->getPosition();
|
||||
return;
|
||||
case AffineExpr::Kind::DimId:
|
||||
os << 'd' << cast<AffineDimExpr>(expr)->getPosition();
|
||||
os << 'd' << expr.cast<AffineDimExprRef>()->getPosition();
|
||||
return;
|
||||
case AffineExpr::Kind::Constant:
|
||||
os << cast<AffineConstantExpr>(expr)->getValue();
|
||||
os << expr.cast<AffineConstantExprRef>()->getValue();
|
||||
return;
|
||||
case AffineExpr::Kind::Add:
|
||||
binopSpelling = " + ";
|
||||
|
@ -604,7 +604,7 @@ void ModulePrinter::printAffineExprInternal(
|
|||
break;
|
||||
}
|
||||
|
||||
auto *binOp = cast<AffineBinaryOpExpr>(expr);
|
||||
auto binOp = expr.cast<AffineBinaryOpExprRef>();
|
||||
|
||||
// Handle tightly binding binary operators.
|
||||
if (binOp->getKind() != AffineExpr::Kind::Add) {
|
||||
|
@ -627,10 +627,10 @@ void ModulePrinter::printAffineExprInternal(
|
|||
// Pretty print addition to a product that has a negative operand as a
|
||||
// subtraction.
|
||||
AffineExprRef rhsExpr = binOp->getRHS();
|
||||
if (auto *rhs = dyn_cast<AffineBinaryOpExpr>(rhsExpr)) {
|
||||
if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExprRef>()) {
|
||||
if (rhs->getKind() == AffineExpr::Kind::Mul) {
|
||||
AffineExprRef rrhsExpr = rhs->getRHS();
|
||||
if (auto *rrhs = dyn_cast<AffineConstantExpr>(rrhsExpr)) {
|
||||
if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExprRef>()) {
|
||||
if (rrhs->getValue() == -1) {
|
||||
printAffineExprInternal(binOp->getLHS(), BindingStrength::Weak);
|
||||
os << " - ";
|
||||
|
@ -655,7 +655,7 @@ void ModulePrinter::printAffineExprInternal(
|
|||
}
|
||||
|
||||
// Pretty print addition to a negative number as a subtraction.
|
||||
if (auto *rhs = dyn_cast<AffineConstantExpr>(rhsExpr)) {
|
||||
if (auto rhs = rhsExpr.dyn_cast<AffineConstantExprRef>()) {
|
||||
if (rhs->getValue() < 0) {
|
||||
printAffineExprInternal(binOp->getLHS(), BindingStrength::Weak);
|
||||
os << " - " << -rhs->getValue();
|
||||
|
@ -1435,7 +1435,7 @@ void MLFunctionPrinter::printBound(AffineBound bound, const char *prefix) {
|
|||
|
||||
// Print constant bound.
|
||||
if (map->getNumDims() == 0 && map->getNumSymbols() == 0) {
|
||||
if (auto *constExpr = dyn_cast<AffineConstantExpr>(expr)) {
|
||||
if (auto constExpr = expr.dyn_cast<AffineConstantExprRef>()) {
|
||||
os << constExpr->getValue();
|
||||
return;
|
||||
}
|
||||
|
@ -1444,7 +1444,7 @@ void MLFunctionPrinter::printBound(AffineBound bound, const char *prefix) {
|
|||
// Print bound that consists of a single SSA symbol if the map is over a
|
||||
// single symbol.
|
||||
if (map->getNumDims() == 0 && map->getNumSymbols() == 1) {
|
||||
if (auto *symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
|
||||
if (auto symExpr = expr.dyn_cast<AffineSymbolExprRef>()) {
|
||||
printOperand(bound.getOperand(0));
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -831,37 +831,38 @@ private:
|
|||
// Identifier lists for polyhedral structures.
|
||||
ParseResult parseDimIdList(unsigned &numDims);
|
||||
ParseResult parseSymbolIdList(unsigned &numSymbols);
|
||||
ParseResult parseIdentifierDefinition(AffineExpr *idExpr);
|
||||
ParseResult parseIdentifierDefinition(AffineExprRef idExpr);
|
||||
|
||||
AffineExpr *parseAffineExpr();
|
||||
AffineExpr *parseParentheticalExpr();
|
||||
AffineExpr *parseNegateExpression(AffineExpr *lhs);
|
||||
AffineExpr *parseIntegerExpr();
|
||||
AffineExpr *parseBareIdExpr();
|
||||
AffineExprRef parseAffineExpr();
|
||||
AffineExprRef parseParentheticalExpr();
|
||||
AffineExprRef parseNegateExpression(AffineExprRef lhs);
|
||||
AffineExprRef parseIntegerExpr();
|
||||
AffineExprRef parseBareIdExpr();
|
||||
|
||||
AffineExpr *getBinaryAffineOpExpr(AffineHighPrecOp op, AffineExpr *lhs,
|
||||
AffineExpr *rhs, SMLoc opLoc);
|
||||
AffineExpr *getBinaryAffineOpExpr(AffineLowPrecOp op, AffineExpr *lhs,
|
||||
AffineExpr *rhs);
|
||||
AffineExpr *parseAffineOperandExpr(AffineExpr *lhs);
|
||||
AffineExpr *parseAffineLowPrecOpExpr(AffineExpr *llhs,
|
||||
AffineLowPrecOp llhsOp);
|
||||
AffineExpr *parseAffineHighPrecOpExpr(AffineExpr *llhs,
|
||||
AffineHighPrecOp llhsOp,
|
||||
SMLoc llhsOpLoc);
|
||||
AffineExpr *parseAffineConstraint(bool *isEq);
|
||||
AffineExprRef getBinaryAffineOpExpr(AffineHighPrecOp op, AffineExprRef lhs,
|
||||
AffineExprRef rhs, SMLoc opLoc);
|
||||
AffineExprRef getBinaryAffineOpExpr(AffineLowPrecOp op, AffineExprRef lhs,
|
||||
AffineExprRef rhs);
|
||||
AffineExprRef parseAffineOperandExpr(AffineExprRef lhs);
|
||||
AffineExprRef parseAffineLowPrecOpExpr(AffineExprRef llhs,
|
||||
AffineLowPrecOp llhsOp);
|
||||
AffineExprRef parseAffineHighPrecOpExpr(AffineExprRef llhs,
|
||||
AffineHighPrecOp llhsOp,
|
||||
SMLoc llhsOpLoc);
|
||||
AffineExprRef parseAffineConstraint(bool *isEq);
|
||||
|
||||
private:
|
||||
SmallVector<std::pair<StringRef, AffineExpr *>, 4> dimsAndSymbols;
|
||||
SmallVector<std::pair<StringRef, AffineExprRef>, 4> dimsAndSymbols;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
/// Create an affine binary high precedence op expression (mul's, div's, mod).
|
||||
/// opLoc is the location of the op token to be used to report errors
|
||||
/// for non-conforming expressions.
|
||||
AffineExpr *AffineParser::getBinaryAffineOpExpr(AffineHighPrecOp op,
|
||||
AffineExpr *lhs,
|
||||
AffineExpr *rhs, SMLoc opLoc) {
|
||||
AffineExprRef AffineParser::getBinaryAffineOpExpr(AffineHighPrecOp op,
|
||||
AffineExprRef lhs,
|
||||
AffineExprRef rhs,
|
||||
SMLoc opLoc) {
|
||||
// TODO: make the error location info accurate.
|
||||
switch (op) {
|
||||
case Mul:
|
||||
|
@ -899,9 +900,9 @@ AffineExpr *AffineParser::getBinaryAffineOpExpr(AffineHighPrecOp op,
|
|||
}
|
||||
|
||||
/// Create an affine binary low precedence op expression (add, sub).
|
||||
AffineExpr *AffineParser::getBinaryAffineOpExpr(AffineLowPrecOp op,
|
||||
AffineExpr *lhs,
|
||||
AffineExpr *rhs) {
|
||||
AffineExprRef AffineParser::getBinaryAffineOpExpr(AffineLowPrecOp op,
|
||||
AffineExprRef lhs,
|
||||
AffineExprRef rhs) {
|
||||
switch (op) {
|
||||
case AffineLowPrecOp::Add:
|
||||
return builder.getAddExpr(lhs, rhs);
|
||||
|
@ -959,10 +960,10 @@ AffineHighPrecOp AffineParser::consumeIfHighPrecOp() {
|
|||
/// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is
|
||||
/// null. llhsOpLoc is the location of the llhsOp token that will be used to
|
||||
/// report an error for non-conforming expressions.
|
||||
AffineExpr *AffineParser::parseAffineHighPrecOpExpr(AffineExpr *llhs,
|
||||
AffineHighPrecOp llhsOp,
|
||||
SMLoc llhsOpLoc) {
|
||||
AffineExpr *lhs = parseAffineOperandExpr(llhs);
|
||||
AffineExprRef AffineParser::parseAffineHighPrecOpExpr(AffineExprRef llhs,
|
||||
AffineHighPrecOp llhsOp,
|
||||
SMLoc llhsOpLoc) {
|
||||
AffineExprRef lhs = parseAffineOperandExpr(llhs);
|
||||
if (!lhs)
|
||||
return nullptr;
|
||||
|
||||
|
@ -970,7 +971,7 @@ AffineExpr *AffineParser::parseAffineHighPrecOpExpr(AffineExpr *llhs,
|
|||
auto opLoc = getToken().getLoc();
|
||||
if (AffineHighPrecOp op = consumeIfHighPrecOp()) {
|
||||
if (llhs) {
|
||||
AffineExpr *expr = getBinaryAffineOpExpr(llhsOp, llhs, lhs, opLoc);
|
||||
AffineExprRef expr = getBinaryAffineOpExpr(llhsOp, llhs, lhs, opLoc);
|
||||
if (!expr)
|
||||
return nullptr;
|
||||
return parseAffineHighPrecOpExpr(expr, op, opLoc);
|
||||
|
@ -990,13 +991,13 @@ AffineExpr *AffineParser::parseAffineHighPrecOpExpr(AffineExpr *llhs,
|
|||
/// Parse an affine expression inside parentheses.
|
||||
///
|
||||
/// affine-expr ::= `(` affine-expr `)`
|
||||
AffineExpr *AffineParser::parseParentheticalExpr() {
|
||||
AffineExprRef AffineParser::parseParentheticalExpr() {
|
||||
if (parseToken(Token::l_paren, "expected '('"))
|
||||
return nullptr;
|
||||
if (getToken().is(Token::r_paren))
|
||||
return (emitError("no expression inside parentheses"), nullptr);
|
||||
|
||||
auto *expr = parseAffineExpr();
|
||||
auto expr = parseAffineExpr();
|
||||
if (!expr)
|
||||
return nullptr;
|
||||
if (parseToken(Token::r_paren, "expected ')'"))
|
||||
|
@ -1008,11 +1009,11 @@ AffineExpr *AffineParser::parseParentheticalExpr() {
|
|||
/// Parse the negation expression.
|
||||
///
|
||||
/// affine-expr ::= `-` affine-expr
|
||||
AffineExpr *AffineParser::parseNegateExpression(AffineExpr *lhs) {
|
||||
AffineExprRef AffineParser::parseNegateExpression(AffineExprRef lhs) {
|
||||
if (parseToken(Token::minus, "expected '-'"))
|
||||
return nullptr;
|
||||
|
||||
AffineExpr *operand = parseAffineOperandExpr(lhs);
|
||||
AffineExprRef operand = parseAffineOperandExpr(lhs);
|
||||
// Since negation has the highest precedence of all ops (including high
|
||||
// precedence ops) but lower than parentheses, we are only going to use
|
||||
// parseAffineOperandExpr instead of parseAffineExpr here.
|
||||
|
@ -1027,7 +1028,7 @@ AffineExpr *AffineParser::parseNegateExpression(AffineExpr *lhs) {
|
|||
/// Parse a bare id that may appear in an affine expression.
|
||||
///
|
||||
/// affine-expr ::= bare-id
|
||||
AffineExpr *AffineParser::parseBareIdExpr() {
|
||||
AffineExprRef AffineParser::parseBareIdExpr() {
|
||||
if (getToken().isNot(Token::bare_identifier))
|
||||
return (emitError("expected bare identifier"), nullptr);
|
||||
|
||||
|
@ -1045,7 +1046,7 @@ AffineExpr *AffineParser::parseBareIdExpr() {
|
|||
/// Parse a positive integral constant appearing in an affine expression.
|
||||
///
|
||||
/// affine-expr ::= integer-literal
|
||||
AffineExpr *AffineParser::parseIntegerExpr() {
|
||||
AffineExprRef AffineParser::parseIntegerExpr() {
|
||||
auto val = getToken().getUInt64IntegerValue();
|
||||
if (!val.hasValue() || (int64_t)val.getValue() < 0)
|
||||
return (emitError("constant too large for index"), nullptr);
|
||||
|
@ -1063,7 +1064,7 @@ AffineExpr *AffineParser::parseIntegerExpr() {
|
|||
// operand expression, it's an op expression and will be parsed via
|
||||
// parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and -l
|
||||
// are valid operands that will be parsed by this function.
|
||||
AffineExpr *AffineParser::parseAffineOperandExpr(AffineExpr *lhs) {
|
||||
AffineExprRef AffineParser::parseAffineOperandExpr(AffineExprRef lhs) {
|
||||
switch (getToken().getKind()) {
|
||||
case Token::bare_identifier:
|
||||
return parseBareIdExpr();
|
||||
|
@ -1113,16 +1114,16 @@ AffineExpr *AffineParser::parseAffineOperandExpr(AffineExpr *lhs) {
|
|||
/// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function
|
||||
/// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where (e2*e3)
|
||||
/// will be parsed using parseAffineHighPrecOpExpr().
|
||||
AffineExpr *AffineParser::parseAffineLowPrecOpExpr(AffineExpr *llhs,
|
||||
AffineLowPrecOp llhsOp) {
|
||||
AffineExpr *lhs;
|
||||
AffineExprRef AffineParser::parseAffineLowPrecOpExpr(AffineExprRef llhs,
|
||||
AffineLowPrecOp llhsOp) {
|
||||
AffineExprRef lhs;
|
||||
if (!(lhs = parseAffineOperandExpr(llhs)))
|
||||
return nullptr;
|
||||
|
||||
// Found an LHS. Deal with the ops.
|
||||
if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) {
|
||||
if (llhs) {
|
||||
AffineExpr *sum = getBinaryAffineOpExpr(llhsOp, llhs, lhs);
|
||||
AffineExprRef sum = getBinaryAffineOpExpr(llhsOp, llhs, lhs);
|
||||
return parseAffineLowPrecOpExpr(sum, lOp);
|
||||
}
|
||||
// No LLHS, get RHS and form the expression.
|
||||
|
@ -1132,13 +1133,13 @@ AffineExpr *AffineParser::parseAffineLowPrecOpExpr(AffineExpr *llhs,
|
|||
if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) {
|
||||
// We have a higher precedence op here. Get the rhs operand for the llhs
|
||||
// through parseAffineHighPrecOpExpr.
|
||||
AffineExpr *highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc);
|
||||
AffineExprRef highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc);
|
||||
if (!highRes)
|
||||
return nullptr;
|
||||
|
||||
// If llhs is null, the product forms the first operand of the yet to be
|
||||
// found expression. If non-null, the op to associate with llhs is llhsOp.
|
||||
AffineExpr *expr =
|
||||
AffineExprRef expr =
|
||||
llhs ? getBinaryAffineOpExpr(llhsOp, llhs, highRes) : highRes;
|
||||
|
||||
// Recurse for subsequent low prec op's after the affine high prec op
|
||||
|
@ -1169,14 +1170,14 @@ AffineExpr *AffineParser::parseAffineLowPrecOpExpr(AffineExpr *llhs,
|
|||
/// Additional conditions are checked depending on the production. For eg., one
|
||||
/// of the operands for `*` has to be either constant/symbolic; the second
|
||||
/// operand for floordiv, ceildiv, and mod has to be a positive integer.
|
||||
AffineExpr *AffineParser::parseAffineExpr() {
|
||||
AffineExprRef AffineParser::parseAffineExpr() {
|
||||
return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp);
|
||||
}
|
||||
|
||||
/// Parse a dim or symbol from the lists appearing before the actual expressions
|
||||
/// of the affine map. Update our state to store the dimensional/symbolic
|
||||
/// identifier.
|
||||
ParseResult AffineParser::parseIdentifierDefinition(AffineExpr *idExpr) {
|
||||
ParseResult AffineParser::parseIdentifierDefinition(AffineExprRef idExpr) {
|
||||
if (getToken().isNot(Token::bare_identifier))
|
||||
return emitError("expected bare identifier");
|
||||
|
||||
|
@ -1240,7 +1241,7 @@ AffineMap *AffineParser::parseAffineMapInline() {
|
|||
|
||||
SmallVector<AffineExprRef, 4> exprs;
|
||||
auto parseElt = [&]() -> ParseResult {
|
||||
auto *elt = parseAffineExpr();
|
||||
auto elt = parseAffineExpr();
|
||||
ParseResult res = elt ? ParseSuccess : ParseFailure;
|
||||
exprs.push_back(elt);
|
||||
return res;
|
||||
|
@ -1266,7 +1267,7 @@ AffineMap *AffineParser::parseAffineMapInline() {
|
|||
|
||||
auto parseRangeSize = [&]() -> ParseResult {
|
||||
auto loc = getToken().getLoc();
|
||||
auto *elt = parseAffineExpr();
|
||||
auto elt = parseAffineExpr();
|
||||
if (!elt)
|
||||
return ParseFailure;
|
||||
|
||||
|
@ -2445,8 +2446,8 @@ ParseResult MLFunctionParser::parseBound(SmallVectorImpl<MLValue *> &operands,
|
|||
/// isEq is set to true if the parsed constraint is an equality, false if it is
|
||||
/// an inequality (greater than or equal).
|
||||
///
|
||||
AffineExpr *AffineParser::parseAffineConstraint(bool *isEq) {
|
||||
AffineExpr *expr = parseAffineExpr();
|
||||
AffineExprRef AffineParser::parseAffineConstraint(bool *isEq) {
|
||||
AffineExprRef expr = parseAffineExpr();
|
||||
if (!expr)
|
||||
return nullptr;
|
||||
|
||||
|
@ -2504,7 +2505,7 @@ IntegerSet *AffineParser::parseIntegerSetInline() {
|
|||
SmallVector<bool, 4> isEqs;
|
||||
auto parseElt = [&]() -> ParseResult {
|
||||
bool isEq;
|
||||
auto *elt = parseAffineConstraint(&isEq);
|
||||
auto elt = parseAffineConstraint(&isEq);
|
||||
ParseResult res = elt ? ParseSuccess : ParseFailure;
|
||||
if (elt) {
|
||||
constraints.push_back(elt);
|
||||
|
|
|
@ -53,6 +53,7 @@ bool mlir::replaceAllMemRefUsesWith(MLValue *oldMemRef, MLValue *newMemRef,
|
|||
ArrayRef<SSAValue *> extraIndices,
|
||||
AffineMap *indexRemap) {
|
||||
unsigned newMemRefRank = cast<MemRefType>(newMemRef->getType())->getRank();
|
||||
(void)newMemRefRank; // unused in opt mode
|
||||
unsigned oldMemRefRank = cast<MemRefType>(oldMemRef->getType())->getRank();
|
||||
(void)newMemRefRank;
|
||||
if (indexRemap) {
|
||||
|
|
Loading…
Reference in New Issue