[MLIR] Cleanup AffineExpr

This CL introduces a series of cleanups for AffineExpr value types:
1. to make it clear that the value types should be used, the pointer
AffineExpr types are put in the detail namespace. Unfortunately, since the
value type operator-> only forwards to the underlying pointer type, we
still
need to expose this in the include file for now;
2. AffineExprKind is ok to use, it thus comes out of detail and thus of
AffineExpr
3. getAffineDimExpr, getAffineSymbolExpr, getAffineConstantExpr are
similarly
extracted as free functions and their naming is mande consistent across
Builder, MLContext and AffineExpr
4. AffineBinaryOpEx::simplify functions are made into static free
functions.
In particular it is moved away from AffineMap.cpp where it does not belong
5. operator AffineExprType is made explicit
6. uses the binary operators everywhere possible
7. drops the pointer usage everywhere outside of AffineExpr.cpp,
MLIRContext.cpp and AsmPrinter.cpp

PiperOrigin-RevId: 216207212
This commit is contained in:
Nicolas Vasilache 2018-10-08 10:20:25 -07:00 committed by jpienaar
parent 4911978f7e
commit ce2edea135
19 changed files with 506 additions and 468 deletions

View File

@ -28,9 +28,13 @@
namespace mlir {
namespace detail {
class AffineExpr;
} // namespace detail
template <typename T> class AffineExprBaseRef;
using AffineExprRef = AffineExprBaseRef<AffineExpr>;
using AffineExprRef = AffineExprBaseRef<detail::AffineExpr>;
class MLIRContext;
/// Simplify an affine expression through flattening and some amount of

View File

@ -26,9 +26,13 @@
namespace mlir {
namespace detail {
class AffineExpr;
} // namespace detail
template <typename T> class AffineExprBaseRef;
using AffineExprRef = AffineExprBaseRef<AffineExpr>;
using AffineExprRef = AffineExprBaseRef<detail::AffineExpr>;
class ForStmt;
/// Returns the trip count of the loop as an affine expression if the latter is

View File

@ -1,4 +1,4 @@
//===- AffineMap.h - MLIR Affine Map Class ----------------------*- C++ -*-===//
//===- AffineExpr.h - MLIR Affine Expr Class --------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
@ -30,22 +30,47 @@
namespace mlir {
class MLIRContext;
namespace detail {
class AffineExpr;
class AffineBinaryOpExpr;
class AffineDimExpr;
class AffineSymbolExpr;
class AffineConstantExpr;
} // namespace detail
enum class AffineExprKind {
Add,
// RHS of mul is always a constant or a symbolic expression.
Mul,
// RHS of mod is always a constant or a symbolic expression.
Mod,
// RHS of floordiv is always a constant or a symbolic expression.
FloorDiv,
// RHS of ceildiv is always a constant or a symbolic expression.
CeilDiv,
/// This is a marker for the last affine binary op. The range of binary
/// op's is expected to be this element and earlier.
LAST_AFFINE_BINARY_OP = CeilDiv,
// Constant integer.
Constant,
// Dimensional identifier.
DimId,
// Symbolic identifier.
SymbolId,
};
/// Helper structure to build AffineExpr with intuitive operators in order to
/// operate on chainable, lightweight value types instead of pointer types.
/// This structure operates on immutable types so it freely casts constness
/// away.
/// operate on chainable, lightweight, immutable value types instead of pointer
/// types.
/// TODO(ntv): Remove all redundant MLIRContext* arguments through the API
/// TODO(ntv): Remove all uses of AffineExpr* in Parser.cpp
/// TODO(ntv): Add extra out-of-class operators for int op AffineExprBaseRef
/// TODO(ntv): Rename
/// TODO(ntv): Drop const everywhere it makes sense in AffineExpr
/// TODO(ntv): remove const comment
/// TODO(ntv): pointer pair
template <typename AffineExprType> class AffineExprBaseRef {
public:
@ -64,15 +89,17 @@ public:
bool operator==(AffineExprBaseRef other) const { return expr == other.expr; }
AffineExprType *operator->() const { return expr; }
/* implicit */ operator AffineExprBaseRef<AffineExpr>() const {
return const_cast<AffineExpr *>(static_cast<const AffineExpr *>(expr));
explicit operator AffineExprType *() const {
return const_cast<AffineExprType *>(expr);
}
/* implicit */ operator AffineExprBaseRef<detail::AffineExpr>() const {
return const_cast<detail::AffineExpr *>(
static_cast<const detail::AffineExpr *>(expr));
}
explicit operator bool() const { return expr; }
bool empty() const { return expr == nullptr; }
bool operator!() const { return expr == nullptr; }
AffineExprType *operator->() const { return expr; }
template <typename U> bool isa() const {
using PtrType = typename U::ImplType;
@ -107,74 +134,30 @@ private:
AffineExprType *expr;
};
using AffineExprRef = AffineExprBaseRef<AffineExpr>;
using AffineBinaryOpExprRef = AffineExprBaseRef<AffineBinaryOpExpr>;
using AffineDimExprRef = AffineExprBaseRef<AffineDimExpr>;
using AffineSymbolExprRef = AffineExprBaseRef<AffineSymbolExpr>;
using AffineConstantExprRef = AffineExprBaseRef<AffineConstantExpr>;
using AffineExprRef = AffineExprBaseRef<detail::AffineExpr>;
using AffineBinaryOpExprRef = AffineExprBaseRef<detail::AffineBinaryOpExpr>;
using AffineDimExprRef = AffineExprBaseRef<detail::AffineDimExpr>;
using AffineSymbolExprRef = AffineExprBaseRef<detail::AffineSymbolExpr>;
using AffineConstantExprRef = AffineExprBaseRef<detail::AffineConstantExpr>;
// Make AffineExprRef hashable.
inline ::llvm::hash_code hash_value(AffineExprRef arg) {
return ::llvm::hash_value(static_cast<AffineExpr *>(arg.expr));
return ::llvm::hash_value(static_cast<detail::AffineExpr *>(arg.expr));
}
} // namespace mlir
// These free functions allow clients of the API to not use classes in detail.
AffineExprRef getAffineDimExpr(unsigned position, MLIRContext *context);
AffineExprRef getAffineSymbolExpr(unsigned position, MLIRContext *context);
AffineExprRef getAffineConstantExpr(int64_t constant, MLIRContext *context);
namespace llvm {
// AffineExprRef hash just like pointers
template <> struct DenseMapInfo<mlir::AffineExprRef> {
static mlir::AffineExprRef getEmptyKey() {
auto pointer = llvm::DenseMapInfo<mlir::AffineExpr *>::getEmptyKey();
return mlir::AffineExprRef(pointer);
}
static mlir::AffineExprRef getTombstoneKey() {
auto pointer = llvm::DenseMapInfo<mlir::AffineExpr *>::getTombstoneKey();
return mlir::AffineExprRef(pointer);
}
static unsigned getHashValue(mlir::AffineExprRef val) {
return mlir::hash_value(val);
}
static bool isEqual(mlir::AffineExprRef LHS, mlir::AffineExprRef RHS) {
return LHS == RHS;
}
};
} // namespace llvm
namespace mlir {
class MLIRContext;
namespace detail {
/// A one-dimensional affine expression.
/// AffineExpression's are immutable (like Type's)
class AffineExpr {
public:
enum class Kind {
Add,
// RHS of mul is always a constant or a symbolic expression.
Mul,
// RHS of mod is always a constant or a symbolic expression.
Mod,
// RHS of floordiv is always a constant or a symbolic expression.
FloorDiv,
// RHS of ceildiv is always a constant or a symbolic expression.
CeilDiv,
/// This is a marker for the last affine binary op. The range of binary
/// op's is expected to be this element and earlier.
LAST_AFFINE_BINARY_OP = CeilDiv,
// Constant integer.
Constant,
// Dimensional identifier.
DimId,
// Symbolic identifier.
SymbolId,
};
/// Return the classification for this type.
Kind getKind() { return kind; }
AffineExprKind getKind() { return kind; }
void print(raw_ostream &os);
void dump();
@ -196,7 +179,7 @@ public:
MLIRContext *getContext();
protected:
explicit AffineExpr(Kind kind, MLIRContext *context)
explicit AffineExpr(AffineExprKind kind, MLIRContext *context)
: kind(kind), context(context) {}
~AffineExpr() {}
@ -205,12 +188,12 @@ private:
void operator=(const AffineExpr &) = delete;
/// Classification of the subclass
const Kind kind;
const AffineExprKind kind;
MLIRContext *context;
};
inline raw_ostream &operator<<(raw_ostream &os, AffineExpr &expr) {
expr.print(os);
inline raw_ostream &operator<<(raw_ostream &os, AffineExprRef &expr) {
expr->print(os);
return os;
}
@ -222,11 +205,11 @@ inline raw_ostream &operator<<(raw_ostream &os, AffineExpr &expr) {
/// the op type: see checks in the constructor.
class AffineBinaryOpExpr : public AffineExpr {
public:
static AffineExprRef get(Kind kind, AffineExprRef lhs, AffineExprRef rhs,
MLIRContext *context);
static AffineExprRef get(AffineExprKind kind, AffineExprRef lhs,
AffineExprRef rhs, MLIRContext *context);
static AffineExprRef getAdd(AffineExprRef lhs, AffineExprRef rhs,
MLIRContext *context) {
return get(AffineExpr::Kind::Add, lhs, rhs, context);
return get(AffineExprKind::Add, lhs, rhs, context);
}
static AffineExprRef getAdd(AffineExprRef expr, int64_t rhs,
MLIRContext *context);
@ -235,25 +218,25 @@ public:
static AffineExprRef getMul(AffineExprRef lhs, AffineExprRef rhs,
MLIRContext *context) {
return get(AffineExpr::Kind::Mul, lhs, rhs, context);
return get(AffineExprKind::Mul, lhs, rhs, context);
}
static AffineExprRef getMul(AffineExprRef expr, int64_t rhs,
MLIRContext *context);
static AffineExprRef getFloorDiv(AffineExprRef lhs, AffineExprRef rhs,
MLIRContext *context) {
return get(AffineExpr::Kind::FloorDiv, lhs, rhs, context);
return get(AffineExprKind::FloorDiv, lhs, rhs, context);
}
static AffineExprRef getFloorDiv(AffineExprRef lhs, uint64_t rhs,
MLIRContext *context);
static AffineExprRef getCeilDiv(AffineExprRef lhs, AffineExprRef rhs,
MLIRContext *context) {
return get(AffineExpr::Kind::CeilDiv, lhs, rhs, context);
return get(AffineExprKind::CeilDiv, lhs, rhs, context);
}
static AffineExprRef getCeilDiv(AffineExprRef lhs, uint64_t rhs,
MLIRContext *context);
static AffineExprRef getMod(AffineExprRef lhs, AffineExprRef rhs,
MLIRContext *context) {
return get(AffineExpr::Kind::Mod, lhs, rhs, context);
return get(AffineExprKind::Mod, lhs, rhs, context);
}
static AffineExprRef getMod(AffineExprRef lhs, uint64_t rhs,
MLIRContext *context);
@ -264,29 +247,18 @@ public:
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const AffineExpr *expr) {
return const_cast<AffineExpr *>(expr)->getKind() <=
Kind::LAST_AFFINE_BINARY_OP;
AffineExprKind::LAST_AFFINE_BINARY_OP;
}
protected:
explicit AffineBinaryOpExpr(Kind kind, AffineExprRef lhs, AffineExprRef rhs,
MLIRContext *context);
explicit AffineBinaryOpExpr(AffineExprKind kind, AffineExprRef lhs,
AffineExprRef rhs, MLIRContext *context);
const AffineExprRef lhs;
const AffineExprRef rhs;
private:
~AffineBinaryOpExpr() = delete;
// Simplification prior to construction of binary affine op expressions.
static AffineExprRef simplifyAdd(AffineExprRef lhs, AffineExprRef rhs,
MLIRContext *context);
static AffineExprRef simplifyMul(AffineExprRef lhs, AffineExprRef rhs,
MLIRContext *context);
static AffineExprRef simplifyFloorDiv(AffineExprRef lhs, AffineExprRef rhs,
MLIRContext *context);
static AffineExprRef simplifyCeilDiv(AffineExprRef lhs, AffineExprRef rhs,
MLIRContext *context);
static AffineExprRef simplifyMod(AffineExprRef lhs, AffineExprRef rhs,
MLIRContext *context);
};
/// A dimensional identifier appearing in an affine expression.
@ -303,13 +275,16 @@ public:
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const AffineExpr *expr) {
return const_cast<AffineExpr *>(expr)->getKind() == Kind::DimId;
return const_cast<AffineExpr *>(expr)->getKind() == AffineExprKind::DimId;
}
friend AffineExprRef mlir::getAffineDimExpr(unsigned position,
MLIRContext *context);
private:
~AffineDimExpr() = delete;
explicit AffineDimExpr(unsigned position, MLIRContext *context)
: AffineExpr(Kind::DimId, context), position(position) {}
: AffineExpr(AffineExprKind::DimId, context), position(position) {}
/// Position of this identifier in the argument list.
unsigned position;
@ -329,13 +304,17 @@ public:
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const AffineExpr *expr) {
return const_cast<AffineExpr *>(expr)->getKind() == Kind::SymbolId;
return const_cast<AffineExpr *>(expr)->getKind() ==
AffineExprKind::SymbolId;
}
friend AffineExprRef mlir::getAffineSymbolExpr(unsigned position,
MLIRContext *context);
private:
~AffineSymbolExpr() = delete;
explicit AffineSymbolExpr(unsigned position, MLIRContext *context)
: AffineExpr(Kind::SymbolId, context), position(position) {}
: AffineExpr(AffineExprKind::SymbolId, context), position(position) {}
/// Position of this identifier in the symbol list.
unsigned position;
@ -351,18 +330,47 @@ public:
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const AffineExpr *expr) {
return const_cast<AffineExpr *>(expr)->getKind() == Kind::Constant;
return const_cast<AffineExpr *>(expr)->getKind() ==
AffineExprKind::Constant;
}
friend AffineExprRef mlir::getAffineConstantExpr(int64_t constant,
MLIRContext *context);
private:
~AffineConstantExpr() = delete;
explicit AffineConstantExpr(int64_t constant, MLIRContext *context)
: AffineExpr(Kind::Constant, context), constant(constant) {}
: AffineExpr(AffineExprKind::Constant, context), constant(constant) {}
// The constant.
int64_t constant;
};
} // end namespace mlir
} // end namespace detail
} // namespace mlir
namespace llvm {
// AffineExprRef hash just like pointers
template <> struct DenseMapInfo<mlir::AffineExprRef> {
static mlir::AffineExprRef getEmptyKey() {
auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlir::AffineExprRef(
static_cast<mlir::AffineExprRef::ImplType *>(pointer));
}
static mlir::AffineExprRef getTombstoneKey() {
auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
return mlir::AffineExprRef(
static_cast<mlir::AffineExprRef::ImplType *>(pointer));
}
static unsigned getHashValue(mlir::AffineExprRef val) {
return mlir::hash_value(val);
}
static bool isEqual(mlir::AffineExprRef LHS, mlir::AffineExprRef RHS) {
return LHS == RHS;
}
};
} // namespace llvm
#endif // MLIR_IR_AFFINE_EXPR_H

View File

@ -82,39 +82,42 @@ public:
static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
"Must instantiate with a derived type of AffineExprVisitor");
switch (expr->getKind()) {
case AffineExpr::Kind::Add: {
case AffineExprKind::Add: {
auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
walkOperandsPostOrder(binOpExpr);
return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
}
case AffineExpr::Kind::Mul: {
case AffineExprKind::Mul: {
auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
walkOperandsPostOrder(binOpExpr);
return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
}
case AffineExpr::Kind::Mod: {
case AffineExprKind::Mod: {
auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
walkOperandsPostOrder(binOpExpr);
return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
}
case AffineExpr::Kind::FloorDiv: {
case AffineExprKind::FloorDiv: {
auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
walkOperandsPostOrder(binOpExpr);
return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
}
case AffineExpr::Kind::CeilDiv: {
case AffineExprKind::CeilDiv: {
auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
walkOperandsPostOrder(binOpExpr);
return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
}
case AffineExpr::Kind::Constant:
case AffineExprKind::Constant:
return static_cast<SubClass *>(this)->visitConstantExpr(
expr.cast<AffineConstantExprRef>());
case AffineExpr::Kind::DimId:
case AffineExprKind::DimId:
return static_cast<SubClass *>(this)->visitDimExpr(
expr.cast<AffineDimExprRef>());
case AffineExpr::Kind::SymbolId:
case AffineExprKind::SymbolId:
return static_cast<SubClass *>(this)->visitSymbolExpr(
expr.cast<AffineSymbolExprRef>());
}
}
@ -124,34 +127,37 @@ public:
static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
"Must instantiate with a derived type of AffineExprVisitor");
switch (expr->getKind()) {
case AffineExpr::Kind::Add: {
case AffineExprKind::Add: {
auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
}
case AffineExpr::Kind::Mul: {
case AffineExprKind::Mul: {
auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
}
case AffineExpr::Kind::Mod: {
case AffineExprKind::Mod: {
auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
}
case AffineExpr::Kind::FloorDiv: {
case AffineExprKind::FloorDiv: {
auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
}
case AffineExpr::Kind::CeilDiv: {
case AffineExprKind::CeilDiv: {
auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
}
case AffineExpr::Kind::Constant:
case AffineExprKind::Constant:
return static_cast<SubClass *>(this)->visitConstantExpr(
expr.cast<AffineConstantExprRef>());
case AffineExpr::Kind::DimId:
case AffineExprKind::DimId:
return static_cast<SubClass *>(this)->visitDimExpr(
expr.cast<AffineDimExprRef>());
case AffineExpr::Kind::SymbolId:
case AffineExprKind::SymbolId:
return static_cast<SubClass *>(this)->visitSymbolExpr(
expr.cast<AffineSymbolExprRef>());
}
}

View File

@ -30,9 +30,13 @@
namespace mlir {
namespace detail {
class AffineExpr;
} // namespace detail
template <typename T> class AffineExprBaseRef;
using AffineExprRef = AffineExprBaseRef<AffineExpr>;
using AffineExprRef = AffineExprBaseRef<detail::AffineExpr>;
class Attribute;
class MLIRContext;

View File

@ -18,12 +18,20 @@
#ifndef MLIR_IR_BUILDERS_H
#define MLIR_IR_BUILDERS_H
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Statements.h"
namespace mlir {
namespace detail {
class AffineExpr;
} // namespace detail
template <typename T> class AffineExprBaseRef;
using AffineExprRef = AffineExprBaseRef<detail::AffineExpr>;
class MLIRContext;
class Module;
class UnknownLoc;
@ -45,10 +53,6 @@ class ArrayAttr;
class FunctionAttr;
class AffineMapAttr;
class AffineMap;
class AffineExpr;
class AffineConstantExpr;
class AffineDimExpr;
class AffineSymbolExpr;
/// This class is a general helper class for creating context-global objects
/// like types, attributes, and affine expressions.
@ -104,9 +108,9 @@ public:
FunctionAttr *getFunctionAttr(const Function *value);
// Affine expressions and affine maps.
AffineExprRef getDimExpr(unsigned position);
AffineExprRef getSymbolExpr(unsigned position);
AffineExprRef getConstantExpr(int64_t constant);
AffineExprRef getAffineDimExpr(unsigned position);
AffineExprRef getAffineSymbolExpr(unsigned position);
AffineExprRef getAffineConstantExpr(int64_t constant);
AffineExprRef getAddExpr(AffineExprRef lhs, AffineExprRef rhs);
AffineExprRef getAddExpr(AffineExprRef lhs, int64_t rhs);
AffineExprRef getSubExpr(AffineExprRef lhs, AffineExprRef rhs);

View File

@ -28,9 +28,14 @@
#include "llvm/Support/raw_ostream.h"
namespace mlir {
namespace detail {
class AffineExpr;
} // namespace detail
template <typename T> class AffineExprBaseRef;
using AffineExprRef = AffineExprBaseRef<AffineExpr>;
using AffineExprRef = AffineExprBaseRef<detail::AffineExpr>;
class AffineMap;
class Builder;
class Function;

View File

@ -40,14 +40,14 @@ static AffineExprRef toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
assert(eq.size() - numDims - numSymbols - 1 == localExprs.size() &&
"unexpected number of local expressions");
auto expr = AffineConstantExpr::get(0, context);
auto expr = getAffineConstantExpr(0, context);
// Dimensions and symbols.
for (unsigned j = 0; j < numDims + numSymbols; j++) {
if (eq[j] == 0) {
continue;
}
auto id = j < numDims ? AffineDimExpr::get(j, context)
: AffineSymbolExpr::get(j - numDims, context);
auto id = j < numDims ? getAffineDimExpr(j, context)
: getAffineSymbolExpr(j - numDims, context);
expr = expr + id * eq[j];
}
@ -190,9 +190,9 @@ public:
// Add an existential quantifier. expr1 % expr2 is replaced by (expr1 -
// q * expr2) where q is the existential quantifier introduced.
addLocalId(AffineBinaryOpExpr::getFloorDiv(
toAffineExpr(lhs, numDims, numSymbols, localExprs, context),
AffineConstantExpr::get(rhsConst, context), context));
auto a = toAffineExpr(lhs, numDims, numSymbols, localExprs, context);
auto b = getAffineConstantExpr(rhsConst, context);
addLocalId(a.floorDiv(b));
lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
}
void visitCeilDivExpr(AffineBinaryOpExprRef expr) {
@ -249,11 +249,13 @@ private:
// the ceil/floor expr (simplified up until here). Add an existential
// quantifier to express its result, i.e., expr1 div expr2 is replaced
// by a new identifier, q.
auto divKind =
isCeil ? AffineExpr::Kind::CeilDiv : AffineExpr::Kind::FloorDiv;
addLocalId(AffineBinaryOpExpr::get(
divKind, toAffineExpr(lhs, numDims, numSymbols, localExprs, context),
AffineConstantExpr::get(denominator, context), context));
auto a = toAffineExpr(lhs, numDims, numSymbols, localExprs, context);
auto b = getAffineConstantExpr(denominator, context);
if (isCeil) {
addLocalId(a.ceilDiv(b));
} else {
addLocalId(a.floorDiv(b));
}
lhs.assign(lhs.size(), 0);
lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
}

View File

@ -37,7 +37,7 @@ MutableAffineMap::MutableAffineMap(AffineMap *map, MLIRContext *context)
}
bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const {
if (const_cast<AffineExprRef &>(results[idx])->isMultipleOf(factor))
if (results[idx]->isMultipleOf(factor))
return true;
// TODO(bondhugula): use simplifyAffineExpr and FlatAffineConstraints to

View File

@ -63,7 +63,7 @@ AffineExprRef mlir::getTripCountExpr(const ForStmt &forStmt) {
std::max(lbMap->getNumSymbols(), ubMap->getNumSymbols()));
auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExprRef>();
if (!cExpr)
return AffineBinaryOpExpr::getCeilDiv(loopSpanExpr, step, context);
return loopSpanExpr.ceilDiv(step);
loopSpan = cExpr->getValue();
}
@ -71,8 +71,8 @@ AffineExprRef mlir::getTripCountExpr(const ForStmt &forStmt) {
if (loopSpan < 0)
return 0;
return AffineConstantExpr::get(static_cast<uint64_t>(ceilDiv(loopSpan, step)),
context);
return getAffineConstantExpr(static_cast<uint64_t>(ceilDiv(loopSpan, step)),
context);
}
/// Returns the trip count of the loop if it's a constant, None otherwise. This

View File

@ -20,27 +20,53 @@
#include "llvm/ADT/STLExtras.h"
using namespace mlir;
using namespace mlir::detail;
AffineBinaryOpExpr::AffineBinaryOpExpr(Kind kind, AffineExprRef lhs,
/// Returns true if this expression is made out of only symbols and
/// constants (no dimensional identifiers).
bool AffineExpr::isSymbolicOrConstant() {
switch (getKind()) {
case AffineExprKind::Constant:
return true;
case AffineExprKind::DimId:
return false;
case AffineExprKind::SymbolId:
return true;
case AffineExprKind::Add:
case AffineExprKind::Mul:
case AffineExprKind::FloorDiv:
case AffineExprKind::CeilDiv:
case AffineExprKind::Mod: {
auto *expr = cast<AffineBinaryOpExpr>(this);
return expr->getLHS()->isSymbolicOrConstant() &&
expr->getRHS()->isSymbolicOrConstant();
}
}
}
////////////////////////////////// Details /////////////////////////////////////
AffineBinaryOpExpr::AffineBinaryOpExpr(AffineExprKind kind, AffineExprRef lhs,
AffineExprRef rhs, MLIRContext *context)
: AffineExpr(kind, context), lhs(lhs), rhs(rhs) {
// We verify affine op expr forms at construction time.
switch (kind) {
case Kind::Add:
case AffineExprKind::Add:
assert(!lhs.isa<AffineConstantExprRef>());
break;
case Kind::Mul:
case AffineExprKind::Mul:
assert(!lhs.isa<AffineConstantExprRef>());
assert(rhs->isSymbolicOrConstant());
assert(AffineExprRef(rhs)->isSymbolicOrConstant());
break;
case Kind::FloorDiv:
assert(rhs->isSymbolicOrConstant());
case AffineExprKind::FloorDiv:
assert(AffineExprRef(rhs)->isSymbolicOrConstant());
break;
case Kind::CeilDiv:
assert(rhs->isSymbolicOrConstant());
case AffineExprKind::CeilDiv:
assert(AffineExprRef(rhs)->isSymbolicOrConstant());
break;
case Kind::Mod:
assert(rhs->isSymbolicOrConstant());
case AffineExprKind::Mod:
assert(AffineExprRef(rhs)->isSymbolicOrConstant());
break;
default:
llvm_unreachable("unexpected binary affine expr");
@ -49,77 +75,54 @@ AffineBinaryOpExpr::AffineBinaryOpExpr(Kind kind, AffineExprRef lhs,
AffineExprRef AffineBinaryOpExpr::getSub(AffineExprRef lhs, AffineExprRef rhs,
MLIRContext *context) {
return getAdd(lhs, getMul(rhs, AffineConstantExpr::get(-1, context), context),
return getAdd(lhs, getMul(rhs, getAffineConstantExpr(-1, context), context),
context);
}
AffineExprRef AffineBinaryOpExpr::getAdd(AffineExprRef expr, int64_t rhs,
MLIRContext *context) {
return get(AffineExpr::Kind::Add, expr, AffineConstantExpr::get(rhs, context),
return get(AffineExprKind::Add, expr, getAffineConstantExpr(rhs, context),
context);
}
AffineExprRef AffineBinaryOpExpr::getMul(AffineExprRef expr, int64_t rhs,
MLIRContext *context) {
return get(AffineExpr::Kind::Mul, expr, AffineConstantExpr::get(rhs, context),
return get(AffineExprKind::Mul, expr, getAffineConstantExpr(rhs, context),
context);
}
AffineExprRef AffineBinaryOpExpr::getFloorDiv(AffineExprRef lhs, uint64_t rhs,
MLIRContext *context) {
return get(AffineExpr::Kind::FloorDiv, lhs,
AffineConstantExpr::get(rhs, context), context);
return get(AffineExprKind::FloorDiv, lhs, getAffineConstantExpr(rhs, context),
context);
}
AffineExprRef AffineBinaryOpExpr::getCeilDiv(AffineExprRef lhs, uint64_t rhs,
MLIRContext *context) {
return get(AffineExpr::Kind::CeilDiv, lhs,
AffineConstantExpr::get(rhs, context), context);
return get(AffineExprKind::CeilDiv, lhs, getAffineConstantExpr(rhs, context),
context);
}
AffineExprRef AffineBinaryOpExpr::getMod(AffineExprRef lhs, uint64_t rhs,
MLIRContext *context) {
return get(AffineExpr::Kind::Mod, lhs, AffineConstantExpr::get(rhs, context),
return get(AffineExprKind::Mod, lhs, getAffineConstantExpr(rhs, context),
context);
}
/// Returns true if this expression is made out of only symbols and
/// constants (no dimensional identifiers).
bool AffineExpr::isSymbolicOrConstant() {
switch (getKind()) {
case Kind::Constant:
return true;
case Kind::DimId:
return false;
case Kind::SymbolId:
return true;
case Kind::Add:
case Kind::Mul:
case Kind::FloorDiv:
case Kind::CeilDiv:
case Kind::Mod: {
auto *expr = cast<AffineBinaryOpExpr>(this);
return expr->getLHS()->isSymbolicOrConstant() &&
expr->getRHS()->isSymbolicOrConstant();
}
}
}
/// Returns true if this is a pure affine expression, i.e., multiplication,
/// floordiv, ceildiv, and mod is only allowed w.r.t constants.
bool AffineExpr::isPureAffine() {
switch (getKind()) {
case Kind::SymbolId:
case Kind::DimId:
case Kind::Constant:
case AffineExprKind::SymbolId:
case AffineExprKind::DimId:
case AffineExprKind::Constant:
return true;
case Kind::Add: {
case AffineExprKind::Add: {
auto *op = cast<AffineBinaryOpExpr>(this);
return op->getLHS()->isPureAffine() && op->getRHS()->isPureAffine();
}
case Kind::Mul: {
case AffineExprKind::Mul: {
// TODO: Canonicalize the constants in binary operators to the RHS when
// possible, allowing this to merge into the next case.
auto *op = cast<AffineBinaryOpExpr>(this);
@ -127,9 +130,9 @@ bool AffineExpr::isPureAffine() {
(op->getLHS().isa<AffineConstantExprRef>() ||
op->getRHS().isa<AffineConstantExprRef>());
}
case Kind::FloorDiv:
case Kind::CeilDiv:
case Kind::Mod: {
case AffineExprKind::FloorDiv:
case AffineExprKind::CeilDiv:
case AffineExprKind::Mod: {
auto *op = cast<AffineBinaryOpExpr>(this);
return op->getLHS()->isPureAffine() &&
op->getRHS().isa<AffineConstantExprRef>();
@ -139,24 +142,24 @@ bool AffineExpr::isPureAffine() {
/// Returns the greatest known integral divisor of this affine expression.
uint64_t AffineExpr::getLargestKnownDivisor() {
AffineBinaryOpExpr *binExpr = nullptr;
switch (kind) {
case Kind::SymbolId:
AffineBinaryOpExprRef binExpr;
switch (getKind()) {
case AffineExprKind::SymbolId:
LLVM_FALLTHROUGH;
case Kind::DimId:
case AffineExprKind::DimId:
return 1;
case Kind::Constant:
case AffineExprKind::Constant:
return std::abs(cast<AffineConstantExpr>(this)->getValue());
case Kind::Mul: {
case AffineExprKind::Mul: {
binExpr = cast<AffineBinaryOpExpr>(this);
return binExpr->getLHS()->getLargestKnownDivisor() *
binExpr->getRHS()->getLargestKnownDivisor();
}
case Kind::Add:
case AffineExprKind::Add:
LLVM_FALLTHROUGH;
case Kind::FloorDiv:
case Kind::CeilDiv:
case Kind::Mod: {
case AffineExprKind::FloorDiv:
case AffineExprKind::CeilDiv:
case AffineExprKind::Mod: {
binExpr = cast<AffineBinaryOpExpr>(this);
return llvm::GreatestCommonDivisor64(
binExpr->getLHS()->getLargestKnownDivisor(),
@ -166,16 +169,16 @@ uint64_t AffineExpr::getLargestKnownDivisor() {
}
bool AffineExpr::isMultipleOf(int64_t factor) {
AffineBinaryOpExpr *binExpr = nullptr;
AffineBinaryOpExpr *binExpr;
uint64_t l, u;
switch (kind) {
case Kind::SymbolId:
switch (getKind()) {
case AffineExprKind::SymbolId:
LLVM_FALLTHROUGH;
case Kind::DimId:
case AffineExprKind::DimId:
return factor * factor == 1;
case Kind::Constant:
case AffineExprKind::Constant:
return cast<AffineConstantExpr>(this)->getValue() % factor == 0;
case Kind::Mul: {
case AffineExprKind::Mul: {
binExpr = cast<AffineBinaryOpExpr>(this);
// It's probably not worth optimizing this further (to not traverse the
// whole sub-tree under - it that would require a version of isMultipleOf
@ -184,10 +187,10 @@ bool AffineExpr::isMultipleOf(int64_t factor) {
(u = binExpr->getRHS()->getLargestKnownDivisor()) % factor == 0 ||
(l * u) % factor == 0;
}
case Kind::Add:
case Kind::FloorDiv:
case Kind::CeilDiv:
case Kind::Mod: {
case AffineExprKind::Add:
case AffineExprKind::FloorDiv:
case AffineExprKind::CeilDiv:
case AffineExprKind::Mod: {
binExpr = cast<AffineBinaryOpExpr>(this);
return llvm::GreatestCommonDivisor64(
binExpr->getLHS()->getLargestKnownDivisor(),
@ -200,17 +203,19 @@ bool AffineExpr::isMultipleOf(int64_t factor) {
MLIRContext *AffineExpr::getContext() { return context; }
///////////////////////////// Done with details ///////////////////////////////
template <> AffineExprRef AffineExprRef::operator+(int64_t v) const {
return AffineBinaryOpExpr::getAdd(expr, v, expr->getContext());
}
template <> AffineExprRef AffineExprRef::operator+(AffineExprRef other) const {
return AffineBinaryOpExpr::getAdd(expr, other, expr->getContext());
return AffineBinaryOpExpr::getAdd(expr, other.expr, expr->getContext());
}
template <> AffineExprRef AffineExprRef::operator*(int64_t v) const {
return AffineBinaryOpExpr::getMul(expr, v, expr->getContext());
}
template <> AffineExprRef AffineExprRef::operator*(AffineExprRef other) const {
return AffineBinaryOpExpr::getMul(expr, other, expr->getContext());
return AffineBinaryOpExpr::getMul(expr, other.expr, expr->getContext());
}
// Unary minus, delegate to operator*.
template <> AffineExprRef AffineExprRef::operator-() const {
@ -227,17 +232,17 @@ template <> AffineExprRef AffineExprRef::floorDiv(uint64_t v) const {
return AffineBinaryOpExpr::getFloorDiv(expr, v, expr->getContext());
}
template <> AffineExprRef AffineExprRef::floorDiv(AffineExprRef other) const {
return AffineBinaryOpExpr::getFloorDiv(expr, other, expr->getContext());
return AffineBinaryOpExpr::getFloorDiv(expr, other.expr, expr->getContext());
}
template <> AffineExprRef AffineExprRef::ceilDiv(uint64_t v) const {
return AffineBinaryOpExpr::getCeilDiv(expr, v, expr->getContext());
}
template <> AffineExprRef AffineExprRef::ceilDiv(AffineExprRef other) const {
return AffineBinaryOpExpr::getCeilDiv(expr, other, expr->getContext());
return AffineBinaryOpExpr::getCeilDiv(expr, other.expr, expr->getContext());
}
template <> AffineExprRef AffineExprRef::operator%(uint64_t v) const {
return AffineBinaryOpExpr::getMod(expr, v, expr->getContext());
}
template <> AffineExprRef AffineExprRef::operator%(AffineExprRef other) const {
return AffineBinaryOpExpr::getMod(expr, other, expr->getContext());
return AffineBinaryOpExpr::getMod(expr, other.expr, expr->getContext());
}

View File

@ -39,28 +39,28 @@ public:
/// failure.
IntegerAttr *constantFold(AffineExprRef expr) {
switch (expr->getKind()) {
case AffineExpr::Kind::Add:
case AffineExprKind::Add:
return constantFoldBinExpr(
expr, [](int64_t lhs, int64_t rhs) { return lhs + rhs; });
case AffineExpr::Kind::Mul:
case AffineExprKind::Mul:
return constantFoldBinExpr(
expr, [](int64_t lhs, int64_t rhs) { return lhs * rhs; });
case AffineExpr::Kind::Mod:
case AffineExprKind::Mod:
return constantFoldBinExpr(
expr, [](int64_t lhs, uint64_t rhs) { return mod(lhs, rhs); });
case AffineExpr::Kind::FloorDiv:
case AffineExprKind::FloorDiv:
return constantFoldBinExpr(
expr, [](int64_t lhs, uint64_t rhs) { return floorDiv(lhs, rhs); });
case AffineExpr::Kind::CeilDiv:
case AffineExprKind::CeilDiv:
return constantFoldBinExpr(
expr, [](int64_t lhs, uint64_t rhs) { return ceilDiv(lhs, rhs); });
case AffineExpr::Kind::Constant:
case AffineExprKind::Constant:
return IntegerAttr::get(expr.cast<AffineConstantExprRef>()->getValue(),
expr->getContext());
case AffineExpr::Kind::DimId:
case AffineExprKind::DimId:
return dyn_cast_or_null<IntegerAttr>(
operandConsts[expr.cast<AffineDimExprRef>()->getPosition()]);
case AffineExpr::Kind::SymbolId:
case AffineExprKind::SymbolId:
return dyn_cast_or_null<IntegerAttr>(
operandConsts[numDims +
expr.cast<AffineSymbolExprRef>()->getPosition()]);
@ -97,7 +97,7 @@ AffineMap::AffineMap(unsigned numDims, unsigned numSymbols, unsigned numResults,
/// Returns a single constant result affine map.
AffineMap *AffineMap::getConstantMap(int64_t val, MLIRContext *context) {
return get(/*dimCount=*/0, /*symbolCount=*/0,
{AffineConstantExpr::get(val, context)}, {}, context);
{getAffineConstantExpr(val, context)}, {}, context);
}
bool AffineMap::isIdentity() {
@ -123,184 +123,6 @@ int64_t AffineMap::getSingleConstantResult() {
AffineExprRef AffineMap::getResult(unsigned idx) { return results[idx]; }
/// Simplify add expression. Return nullptr if it can't be simplified.
AffineExprRef AffineBinaryOpExpr::simplifyAdd(AffineExprRef lhs,
AffineExprRef rhs,
MLIRContext *context) {
auto lhsConst = lhs.dyn_cast<AffineConstantExprRef>();
auto rhsConst = rhs.dyn_cast<AffineConstantExprRef>();
// Fold if both LHS, RHS are a constant.
if (lhsConst && rhsConst)
return AffineConstantExpr::get(lhsConst->getValue() + rhsConst->getValue(),
context);
// Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
// If only one of them is a symbolic expressions, make it the RHS.
if (lhs.isa<AffineConstantExprRef>() ||
(lhs->isSymbolicOrConstant() && !rhs->isSymbolicOrConstant())) {
return AffineBinaryOpExpr::getAdd(rhs, lhs, context);
}
// At this point, if there was a constant, it would be on the right.
// Addition with a zero is a noop, return the other input.
if (rhsConst) {
if (rhsConst->getValue() == 0)
return lhs;
}
// Fold successive additions like (d0 + 2) + 3 into d0 + 5.
auto lBin = lhs.dyn_cast<AffineBinaryOpExprRef>();
if (lBin && rhsConst && lBin->getKind() == Kind::Add) {
if (auto lrhs = lBin->getRHS().dyn_cast<AffineConstantExprRef>())
return lBin->getLHS() + (lrhs->getValue() + rhsConst->getValue());
}
// When doing successive additions, bring constant to the right: turn (d0 + 2)
// + d1 into (d0 + d1) + 2.
if (lBin && lBin->getKind() == Kind::Add) {
if (auto lrhs = lBin->getRHS().dyn_cast<AffineConstantExprRef>()) {
return lBin->getLHS() + rhs + lrhs;
}
}
return nullptr;
}
/// Simplify a multiply expression. Return nullptr if it can't be simplified.
AffineExprRef AffineBinaryOpExpr::simplifyMul(AffineExprRef lhs,
AffineExprRef rhs,
MLIRContext *context) {
auto lhsConst = lhs.dyn_cast<AffineConstantExprRef>();
auto rhsConst = rhs.dyn_cast<AffineConstantExprRef>();
if (lhsConst && rhsConst)
return AffineConstantExpr::get(lhsConst->getValue() * rhsConst->getValue(),
context);
assert(lhs->isSymbolicOrConstant() || rhs->isSymbolicOrConstant());
// Canonicalize the mul expression so that the constant/symbolic term is the
// RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
// constant. (Note that a constant is trivially symbolic).
if (!rhs->isSymbolicOrConstant() || lhs.isa<AffineConstantExprRef>()) {
// At least one of them has to be symbolic.
return AffineBinaryOpExpr::getMul(rhs, lhs, context);
}
// At this point, if there was a constant, it would be on the right.
// Multiplication with a one is a noop, return the other input.
if (rhsConst) {
if (rhsConst->getValue() == 1)
return lhs;
// Multiplication with zero.
if (rhsConst->getValue() == 0)
return rhsConst;
}
// Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
auto lBin = lhs.dyn_cast<AffineBinaryOpExprRef>();
if (lBin && rhsConst && lBin->getKind() == Kind::Mul) {
if (auto lrhs = lBin->getRHS().dyn_cast<AffineConstantExprRef>())
return lBin->getLHS() * (lrhs->getValue() * rhsConst->getValue());
}
// When doing successive multiplication, bring constant to the right: turn (d0
// * 2) * d1 into (d0 * d1) * 2.
if (lBin && lBin->getKind() == Kind::Mul) {
if (auto lrhs = lBin->getRHS().dyn_cast<AffineConstantExprRef>()) {
return (lBin->getLHS() * rhs) * lrhs;
}
}
return nullptr;
}
AffineExprRef AffineBinaryOpExpr::simplifyFloorDiv(AffineExprRef lhs,
AffineExprRef rhs,
MLIRContext *context) {
auto lhsConst = lhs.dyn_cast<AffineConstantExprRef>();
auto rhsConst = rhs.dyn_cast<AffineConstantExprRef>();
if (lhsConst && rhsConst)
return AffineConstantExpr::get(
floorDiv(lhsConst->getValue(), rhsConst->getValue()), context);
// Fold floordiv of a multiply with a constant that is a multiple of the
// divisor. Eg: (i * 128) floordiv 64 = i * 2.
if (rhsConst) {
if (rhsConst->getValue() == 1)
return lhs;
auto lBin = lhs.dyn_cast<AffineBinaryOpExprRef>();
if (lBin && lBin->getKind() == Kind::Mul) {
if (auto lrhs = lBin->getRHS().dyn_cast<AffineConstantExprRef>()) {
// rhsConst is known to be positive if a constant.
if (lrhs->getValue() % rhsConst->getValue() == 0)
return lBin->getLHS() * (lrhs->getValue() / rhsConst->getValue());
}
}
}
return nullptr;
}
AffineExprRef AffineBinaryOpExpr::simplifyCeilDiv(AffineExprRef lhs,
AffineExprRef rhs,
MLIRContext *context) {
auto lhsConst = lhs.dyn_cast<AffineConstantExprRef>();
auto rhsConst = rhs.dyn_cast<AffineConstantExprRef>();
if (lhsConst && rhsConst)
return AffineConstantExpr::get(
ceilDiv(lhsConst->getValue(), rhsConst->getValue()), context);
// Fold ceildiv of a multiply with a constant that is a multiple of the
// divisor. Eg: (i * 128) ceildiv 64 = i * 2.
if (rhsConst) {
if (rhsConst->getValue() == 1)
return lhs;
auto lBin = lhs.dyn_cast<AffineBinaryOpExprRef>();
if (lBin && lBin->getKind() == Kind::Mul) {
if (auto lrhs = lBin->getRHS().dyn_cast<AffineConstantExprRef>()) {
// rhsConst is known to be positive if a constant.
if (lrhs->getValue() % rhsConst->getValue() == 0)
return lBin->getLHS() * (lrhs->getValue() / rhsConst->getValue());
}
}
}
return nullptr;
}
AffineExprRef AffineBinaryOpExpr::simplifyMod(AffineExprRef lhs,
AffineExprRef rhs,
MLIRContext *context) {
auto lhsConst = lhs.dyn_cast<AffineConstantExprRef>();
auto rhsConst = rhs.dyn_cast<AffineConstantExprRef>();
if (lhsConst && rhsConst)
return AffineConstantExpr::get(
mod(lhsConst->getValue(), rhsConst->getValue()), context);
// Fold modulo of an expression that is known to be a multiple of a constant
// to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
// mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
if (rhsConst) {
// rhsConst is known to be positive if a constant.
if (lhs->getLargestKnownDivisor() % rhsConst->getValue() == 0)
return AffineConstantExpr::get(0, context);
}
return nullptr;
// TODO(bondhugula): In general, this can be simplified more by using the GCD
// test, or in general using quantifier elimination (add two new variables q
// and r, and eliminate all variables from the linear system other than r. All
// of this can be done through mlir/Analysis/'s FlatAffineConstraints.
}
/// Folds the results of the application of an affine map on the provided
/// operands to a constant if possible. Returns false if the folding happens,
/// true otherwise.

View File

@ -40,6 +40,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
using namespace mlir;
using namespace mlir::detail;
void Identifier::print(raw_ostream &os) const { os << str(); }
@ -578,28 +579,28 @@ void ModulePrinter::printAffineExprInternal(
AffineExprRef expr, BindingStrength enclosingTightness) {
const char *binopSpelling = nullptr;
switch (expr->getKind()) {
case AffineExpr::Kind::SymbolId:
case AffineExprKind::SymbolId:
os << 's' << expr.cast<AffineSymbolExprRef>()->getPosition();
return;
case AffineExpr::Kind::DimId:
case AffineExprKind::DimId:
os << 'd' << expr.cast<AffineDimExprRef>()->getPosition();
return;
case AffineExpr::Kind::Constant:
case AffineExprKind::Constant:
os << expr.cast<AffineConstantExprRef>()->getValue();
return;
case AffineExpr::Kind::Add:
case AffineExprKind::Add:
binopSpelling = " + ";
break;
case AffineExpr::Kind::Mul:
case AffineExprKind::Mul:
binopSpelling = " * ";
break;
case AffineExpr::Kind::FloorDiv:
case AffineExprKind::FloorDiv:
binopSpelling = " floordiv ";
break;
case AffineExpr::Kind::CeilDiv:
case AffineExprKind::CeilDiv:
binopSpelling = " ceildiv ";
break;
case AffineExpr::Kind::Mod:
case AffineExprKind::Mod:
binopSpelling = " mod ";
break;
}
@ -607,7 +608,7 @@ void ModulePrinter::printAffineExprInternal(
auto binOp = expr.cast<AffineBinaryOpExprRef>();
// Handle tightly binding binary operators.
if (binOp->getKind() != AffineExpr::Kind::Add) {
if (binOp->getKind() != AffineExprKind::Add) {
if (enclosingTightness == BindingStrength::Strong)
os << '(';
@ -628,7 +629,7 @@ void ModulePrinter::printAffineExprInternal(
// subtraction.
AffineExprRef rhsExpr = binOp->getRHS();
if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExprRef>()) {
if (rhs->getKind() == AffineExpr::Kind::Mul) {
if (rhs->getKind() == AffineExprKind::Mul) {
AffineExprRef rrhsExpr = rhs->getRHS();
if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExprRef>()) {
if (rrhs->getValue() == -1) {

View File

@ -155,72 +155,68 @@ AffineMap *Builder::getAffineMap(unsigned dimCount, unsigned symbolCount,
return AffineMap::get(dimCount, symbolCount, results, rangeSizes, context);
}
AffineExprRef Builder::getDimExpr(unsigned position) {
return AffineDimExpr::get(position, context);
AffineExprRef Builder::getAffineDimExpr(unsigned position) {
return mlir::getAffineDimExpr(position, context);
}
AffineExprRef Builder::getSymbolExpr(unsigned position) {
return AffineSymbolExpr::get(position, context);
AffineExprRef Builder::getAffineSymbolExpr(unsigned position) {
return mlir::getAffineSymbolExpr(position, context);
}
AffineExprRef Builder::getConstantExpr(int64_t constant) {
return AffineConstantExpr::get(constant, context);
AffineExprRef Builder::getAffineConstantExpr(int64_t constant) {
return mlir::getAffineConstantExpr(constant, context);
}
AffineExprRef Builder::getAddExpr(AffineExprRef lhs, AffineExprRef rhs) {
return AffineBinaryOpExpr::get(AffineExpr::Kind::Add, lhs, rhs, context);
return lhs + rhs;
}
AffineExprRef Builder::getAddExpr(AffineExprRef lhs, int64_t rhs) {
return AffineBinaryOpExpr::getAdd(lhs, rhs, context);
return lhs + rhs;
}
AffineExprRef Builder::getMulExpr(AffineExprRef lhs, AffineExprRef rhs) {
return AffineBinaryOpExpr::get(AffineExpr::Kind::Mul, lhs, rhs, context);
return lhs * rhs;
}
// Most multiply expressions are pure affine (rhs is a constant).
AffineExprRef Builder::getMulExpr(AffineExprRef lhs, int64_t rhs) {
return AffineBinaryOpExpr::get(AffineExpr::Kind::Mul, lhs,
getConstantExpr(rhs), context);
return lhs * rhs;
}
AffineExprRef Builder::getSubExpr(AffineExprRef lhs, AffineExprRef rhs) {
return getAddExpr(lhs, getMulExpr(rhs, getConstantExpr(-1)));
return lhs - rhs;
}
AffineExprRef Builder::getSubExpr(AffineExprRef lhs, int64_t rhs) {
return AffineBinaryOpExpr::getAdd(lhs, -rhs, context);
return lhs - rhs;
}
AffineExprRef Builder::getModExpr(AffineExprRef lhs, AffineExprRef rhs) {
return AffineBinaryOpExpr::get(AffineExpr::Kind::Mod, lhs, rhs, context);
return lhs % rhs;
}
// Most modulo expressions are pure affine.
AffineExprRef Builder::getModExpr(AffineExprRef lhs, uint64_t rhs) {
return AffineBinaryOpExpr::get(AffineExpr::Kind::Mod, lhs,
getConstantExpr(rhs), context);
return lhs % rhs;
}
AffineExprRef Builder::getFloorDivExpr(AffineExprRef lhs, AffineExprRef rhs) {
return AffineBinaryOpExpr::get(AffineExpr::Kind::FloorDiv, lhs, rhs, context);
return lhs.floorDiv(rhs);
}
// Most floordiv expressions are pure affine.
AffineExprRef Builder::getFloorDivExpr(AffineExprRef lhs, uint64_t rhs) {
return AffineBinaryOpExpr::get(AffineExpr::Kind::FloorDiv, lhs,
getConstantExpr(rhs), context);
return lhs.floorDiv(rhs);
}
AffineExprRef Builder::getCeilDivExpr(AffineExprRef lhs, AffineExprRef rhs) {
return AffineBinaryOpExpr::get(AffineExpr::Kind::CeilDiv, lhs, rhs, context);
return lhs.ceilDiv(rhs);
}
// Most ceildiv expressions are pure affine.
AffineExprRef Builder::getCeilDivExpr(AffineExprRef lhs, uint64_t rhs) {
return AffineBinaryOpExpr::get(AffineExpr::Kind::CeilDiv, lhs,
getConstantExpr(rhs), context);
return lhs.ceilDiv(rhs);
}
IntegerSet *Builder::getIntegerSet(unsigned dimCount, unsigned symbolCount,
@ -231,22 +227,22 @@ IntegerSet *Builder::getIntegerSet(unsigned dimCount, unsigned symbolCount,
AffineMap *Builder::getConstantAffineMap(int64_t val) {
return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
{getConstantExpr(val)}, {}, context);
{getAffineConstantExpr(val)}, {}, context);
}
AffineMap *Builder::getDimIdentityMap() {
return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, {getDimExpr(0)}, {},
context);
return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
{getAffineDimExpr(0)}, {}, context);
}
AffineMap *Builder::getSymbolIdentityMap() {
return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1, {getSymbolExpr(0)},
{}, context);
return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
{getAffineSymbolExpr(0)}, {}, context);
}
AffineMap *Builder::getSingleDimShiftAffineMap(int64_t shift) {
// expr = d0 + shift.
auto expr = getDimExpr(0) + shift;
auto expr = getAffineDimExpr(0) + shift;
return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, {expr}, {}, context);
}

View File

@ -27,6 +27,7 @@
#include "mlir/IR/OperationSet.h"
#include "mlir/IR/StandardOps.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/StringMap.h"
@ -34,6 +35,7 @@
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::detail;
using namespace llvm;
namespace {
@ -227,7 +229,7 @@ public:
DenseMap<std::tuple<unsigned, AffineExprRef, AffineExprRef>, AffineExprRef>
affineExprs;
// Uniqui'ing of AffineDimExpr, AffineSymbolExpr's by their position.
// Uniqui'ing of AffineDimExprRef, AffineSymbolExprRef's by their position.
std::vector<AffineDimExpr *> dimExprs;
std::vector<AffineSymbolExpr *> symbolExprs;
@ -833,12 +835,185 @@ AffineMap *AffineMap::get(unsigned dimCount, unsigned symbolCount,
return *existing.first = res;
}
/// Simplify add expression. Return nullptr if it can't be simplified.
static AffineExprRef simplifyAdd(AffineExprRef lhs, AffineExprRef rhs,
MLIRContext *context) {
auto lhsConst = lhs.dyn_cast<AffineConstantExprRef>();
auto rhsConst = rhs.dyn_cast<AffineConstantExprRef>();
// Fold if both LHS, RHS are a constant.
if (lhsConst && rhsConst)
return getAffineConstantExpr(lhsConst->getValue() + rhsConst->getValue(),
context);
// Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
// If only one of them is a symbolic expressions, make it the RHS.
if (lhs.isa<AffineConstantExprRef>() ||
(lhs->isSymbolicOrConstant() && !rhs->isSymbolicOrConstant())) {
return rhs + lhs;
}
// At this point, if there was a constant, it would be on the right.
// Addition with a zero is a noop, return the other input.
if (rhsConst) {
if (rhsConst->getValue() == 0)
return lhs;
}
// Fold successive additions like (d0 + 2) + 3 into d0 + 5.
auto lBin = lhs.dyn_cast<AffineBinaryOpExprRef>();
if (lBin && rhsConst && lBin->getKind() == AffineExprKind::Add) {
if (auto lrhs = lBin->getRHS().dyn_cast<AffineConstantExprRef>())
return lBin->getLHS() + (lrhs->getValue() + rhsConst->getValue());
}
// When doing successive additions, bring constant to the right: turn (d0 + 2)
// + d1 into (d0 + d1) + 2.
if (lBin && lBin->getKind() == AffineExprKind::Add) {
if (auto lrhs = lBin->getRHS().dyn_cast<AffineConstantExprRef>()) {
return lBin->getLHS() + rhs + lrhs;
}
}
return nullptr;
}
/// Simplify a multiply expression. Return nullptr if it can't be simplified.
static AffineExprRef simplifyMul(AffineExprRef lhs, AffineExprRef rhs,
MLIRContext *context) {
auto lhsConst = lhs.dyn_cast<AffineConstantExprRef>();
auto rhsConst = rhs.dyn_cast<AffineConstantExprRef>();
if (lhsConst && rhsConst)
return getAffineConstantExpr(lhsConst->getValue() * rhsConst->getValue(),
context);
assert(lhs->isSymbolicOrConstant() || rhs->isSymbolicOrConstant());
// Canonicalize the mul expression so that the constant/symbolic term is the
// RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
// constant. (Note that a constant is trivially symbolic).
if (!rhs->isSymbolicOrConstant() || lhs.isa<AffineConstantExprRef>()) {
// At least one of them has to be symbolic.
return rhs * lhs;
}
// At this point, if there was a constant, it would be on the right.
// Multiplication with a one is a noop, return the other input.
if (rhsConst) {
if (rhsConst->getValue() == 1)
return lhs;
// Multiplication with zero.
if (rhsConst->getValue() == 0)
return rhsConst;
}
// Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
auto lBin = lhs.dyn_cast<AffineBinaryOpExprRef>();
if (lBin && rhsConst && lBin->getKind() == AffineExprKind::Mul) {
if (auto lrhs = lBin->getRHS().dyn_cast<AffineConstantExprRef>())
return lBin->getLHS() * (lrhs->getValue() * rhsConst->getValue());
}
// When doing successive multiplication, bring constant to the right: turn (d0
// * 2) * d1 into (d0 * d1) * 2.
if (lBin && lBin->getKind() == AffineExprKind::Mul) {
if (auto lrhs = lBin->getRHS().dyn_cast<AffineConstantExprRef>()) {
return (lBin->getLHS() * rhs) * lrhs;
}
}
return nullptr;
}
static AffineExprRef simplifyFloorDiv(AffineExprRef lhs, AffineExprRef rhs,
MLIRContext *context) {
auto lhsConst = lhs.dyn_cast<AffineConstantExprRef>();
auto rhsConst = rhs.dyn_cast<AffineConstantExprRef>();
if (lhsConst && rhsConst)
return getAffineConstantExpr(
floorDiv(lhsConst->getValue(), rhsConst->getValue()), context);
// Fold floordiv of a multiply with a constant that is a multiple of the
// divisor. Eg: (i * 128) floordiv 64 = i * 2.
if (rhsConst) {
if (rhsConst->getValue() == 1)
return lhs;
auto lBin = lhs.dyn_cast<AffineBinaryOpExprRef>();
if (lBin && lBin->getKind() == AffineExprKind::Mul) {
if (auto lrhs = lBin->getRHS().dyn_cast<AffineConstantExprRef>()) {
// rhsConst is known to be positive if a constant.
if (lrhs->getValue() % rhsConst->getValue() == 0)
return lBin->getLHS() * (lrhs->getValue() / rhsConst->getValue());
}
}
}
return nullptr;
}
static AffineExprRef simplifyCeilDiv(AffineExprRef lhs, AffineExprRef rhs,
MLIRContext *context) {
auto lhsConst = lhs.dyn_cast<AffineConstantExprRef>();
auto rhsConst = rhs.dyn_cast<AffineConstantExprRef>();
if (lhsConst && rhsConst)
return getAffineConstantExpr(
ceilDiv(lhsConst->getValue(), rhsConst->getValue()), context);
// Fold ceildiv of a multiply with a constant that is a multiple of the
// divisor. Eg: (i * 128) ceildiv 64 = i * 2.
if (rhsConst) {
if (rhsConst->getValue() == 1)
return lhs;
auto lBin = lhs.dyn_cast<AffineBinaryOpExprRef>();
if (lBin && lBin->getKind() == AffineExprKind::Mul) {
if (auto lrhs = lBin->getRHS().dyn_cast<AffineConstantExprRef>()) {
// rhsConst is known to be positive if a constant.
if (lrhs->getValue() % rhsConst->getValue() == 0)
return lBin->getLHS() * (lrhs->getValue() / rhsConst->getValue());
}
}
}
return nullptr;
}
static AffineExprRef simplifyMod(AffineExprRef lhs, AffineExprRef rhs,
MLIRContext *context) {
auto lhsConst = lhs.dyn_cast<AffineConstantExprRef>();
auto rhsConst = rhs.dyn_cast<AffineConstantExprRef>();
if (lhsConst && rhsConst)
return getAffineConstantExpr(
mod(lhsConst->getValue(), rhsConst->getValue()), context);
// Fold modulo of an expression that is known to be a multiple of a constant
// to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
// mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
if (rhsConst) {
// rhsConst is known to be positive if a constant.
if (lhs->getLargestKnownDivisor() % rhsConst->getValue() == 0)
return getAffineConstantExpr(0, context);
}
return nullptr;
// TODO(bondhugula): In general, this can be simplified more by using the GCD
// test, or in general using quantifier elimination (add two new variables q
// and r, and eliminate all variables from the linear system other than r. All
// of this can be done through mlir/Analysis/'s FlatAffineConstraints.
}
/// Return a binary affine op expression with the specified op type and
/// operands: if it doesn't exist, create it and store it; if it is already
/// present, return from the list. The stored expressions are unique: they are
/// constructed and stored in a simplified/canonicalized form. The result after
/// simplification could be any form of affine expression.
AffineExprRef AffineBinaryOpExpr::get(AffineExpr::Kind kind, AffineExprRef lhs,
AffineExprRef AffineBinaryOpExpr::get(AffineExprKind kind, AffineExprRef lhs,
AffineExprRef rhs, MLIRContext *context) {
auto &impl = context->getImpl();
@ -846,25 +1021,25 @@ AffineExprRef AffineBinaryOpExpr::get(AffineExpr::Kind kind, AffineExprRef lhs,
auto keyValue = std::make_tuple((unsigned)kind, lhs, rhs);
auto cached = impl.affineExprs.find(keyValue);
if (cached != impl.affineExprs.end())
return cached->second;
return static_cast<AffineExpr *>(cached->second);
// Simplify the expression if possible.
AffineExprRef simplified(nullptr);
AffineExprRef simplified;
switch (kind) {
case Kind::Add:
simplified = AffineBinaryOpExpr::simplifyAdd(lhs, rhs, context);
case AffineExprKind::Add:
simplified = simplifyAdd(lhs, rhs, context);
break;
case Kind::Mul:
simplified = AffineBinaryOpExpr::simplifyMul(lhs, rhs, context);
case AffineExprKind::Mul:
simplified = simplifyMul(lhs, rhs, context);
break;
case Kind::FloorDiv:
simplified = AffineBinaryOpExpr::simplifyFloorDiv(lhs, rhs, context);
case AffineExprKind::FloorDiv:
simplified = simplifyFloorDiv(lhs, rhs, context);
break;
case Kind::CeilDiv:
simplified = AffineBinaryOpExpr::simplifyCeilDiv(lhs, rhs, context);
case AffineExprKind::CeilDiv:
simplified = simplifyCeilDiv(lhs, rhs, context);
break;
case Kind::Mod:
simplified = AffineBinaryOpExpr::simplifyMod(lhs, rhs, context);
case AffineExprKind::Mod:
simplified = simplifyMod(lhs, rhs, context);
break;
default:
llvm_unreachable("unexpected binary affine expr");
@ -872,7 +1047,7 @@ AffineExprRef AffineBinaryOpExpr::get(AffineExpr::Kind kind, AffineExprRef lhs,
// The simplified one would have already been cached; just return it.
if (simplified)
return simplified;
return static_cast<AffineExpr *>(simplified);
// An expression with these operands will already be in the
// simplified/canonical form. Create and store it.
@ -885,7 +1060,7 @@ AffineExprRef AffineBinaryOpExpr::get(AffineExpr::Kind kind, AffineExprRef lhs,
return result;
}
AffineExprRef AffineDimExpr::get(unsigned position, MLIRContext *context) {
AffineExprRef mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
auto &impl = context->getImpl();
// Check if we need to resize.
@ -902,7 +1077,8 @@ AffineExprRef AffineDimExpr::get(unsigned position, MLIRContext *context) {
return result;
}
AffineExprRef AffineSymbolExpr::get(unsigned position, MLIRContext *context) {
AffineExprRef mlir::getAffineSymbolExpr(unsigned position,
MLIRContext *context) {
auto &impl = context->getImpl();
// Check if we need to resize.
@ -919,7 +1095,8 @@ AffineExprRef AffineSymbolExpr::get(unsigned position, MLIRContext *context) {
return result;
}
AffineExprRef AffineConstantExpr::get(int64_t constant, MLIRContext *context) {
AffineExprRef mlir::getAffineConstantExpr(int64_t constant,
MLIRContext *context) {
auto &impl = context->getImpl();
auto *&result = impl.constExprs[constant];

View File

@ -908,7 +908,7 @@ AffineExprRef AffineParser::getBinaryAffineOpExpr(AffineLowPrecOp op,
return builder.getAddExpr(lhs, rhs);
case AffineLowPrecOp::Sub:
return builder.getAddExpr(
lhs, builder.getMulExpr(rhs, builder.getConstantExpr(-1)));
lhs, builder.getMulExpr(rhs, builder.getAffineConstantExpr(-1)));
case AffineLowPrecOp::LNoOp:
llvm_unreachable("can't create affine expression for null low prec op");
return nullptr;
@ -1021,7 +1021,7 @@ AffineExprRef AffineParser::parseNegateExpression(AffineExprRef lhs) {
// Extra error message although parseAffineOperandExpr would have
// complained. Leads to a better diagnostic.
return (emitError("missing operand of negation"), nullptr);
auto minusOne = builder.getConstantExpr(-1);
auto minusOne = builder.getAffineConstantExpr(-1);
return builder.getMulExpr(minusOne, operand);
}
@ -1052,7 +1052,7 @@ AffineExprRef AffineParser::parseIntegerExpr() {
return (emitError("constant too large for index"), nullptr);
consumeToken(Token::integer);
return builder.getConstantExpr((int64_t)val.getValue());
return builder.getAffineConstantExpr((int64_t)val.getValue());
}
/// Parses an expression that can be a valid operand of an affine expression.
@ -1196,7 +1196,7 @@ ParseResult AffineParser::parseIdentifierDefinition(AffineExprRef idExpr) {
ParseResult AffineParser::parseSymbolIdList(unsigned &numSymbols) {
consumeToken(Token::l_square);
auto parseElt = [&]() -> ParseResult {
auto symbol = AffineSymbolExpr::get(numSymbols++, getContext());
auto symbol = getAffineSymbolExpr(numSymbols++, getContext());
return parseIdentifierDefinition(symbol);
};
return parseCommaSeparatedListUntil(Token::r_square, parseElt);
@ -1209,7 +1209,7 @@ ParseResult AffineParser::parseDimIdList(unsigned &numDims) {
return ParseFailure;
auto parseElt = [&]() -> ParseResult {
auto dimension = AffineDimExpr::get(numDims++, getContext());
auto dimension = getAffineDimExpr(numDims++, getContext());
return parseIdentifierDefinition(dimension);
};
return parseCommaSeparatedListUntil(Token::r_paren, parseElt);

View File

@ -266,7 +266,7 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
// this unrolled instance.
if (!forStmt->use_empty()) {
// iv' = iv + 1/2/3...unrollFactor-1;
auto d0 = builder.getDimExpr(0);
auto d0 = builder.getAffineDimExpr(0);
auto *bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});
auto *ivUnroll =
builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt)

View File

@ -220,7 +220,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
// this unrolled instance.
if (!forStmt->use_empty()) {
// iv' = iv + i, i = 1 to unrollJamFactor-1.
auto d0 = builder.getDimExpr(0);
auto d0 = builder.getAffineDimExpr(0);
auto *bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});
auto *ivUnroll =
builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt)

View File

@ -116,7 +116,7 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
topBuilder.create<AllocOp>(forStmt->getLoc(), newMemRefType)
->getResult());
auto d0 = bInner.getDimExpr(0);
auto d0 = bInner.getAffineDimExpr(0);
auto *modTwoMap = bInner.getAffineMap(1, 0, {d0 % 2}, {});
auto ivModTwoOp =
bInner.create<AffineApplyOp>(forStmt->getLoc(), modTwoMap, forStmt);