diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h index ded953f7a7c0..58b4fbc3be11 100644 --- a/mlir/include/mlir/IR/AffineExpr.h +++ b/mlir/include/mlir/IR/AffineExpr.h @@ -189,7 +189,7 @@ public: /// A symbolic identifier appearing in an affine expression. class AffineSymbolExpr : public AffineExpr { public: - using ImplType = detail::AffineSymbolExprStorage; + using ImplType = detail::AffineDimExprStorage; /* implicit */ AffineSymbolExpr(AffineExpr::ImplType *ptr); unsigned getPosition() const; }; diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h index c74c814eb849..864fd960a7db 100644 --- a/mlir/include/mlir/IR/MLIRContext.h +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -82,6 +82,9 @@ public: /// Returns the diagnostic engine for this context. DiagnosticEngine &getDiagEngine(); + /// Returns the storage uniquer used for creating affine constructs. + StorageUniquer &getAffineUniquer(); + /// Returns the storage uniquer used for constructing type storage instances. /// This should not be used directly. StorageUniquer &getTypeUniquer(); diff --git a/mlir/include/mlir/Support/StorageUniquer.h b/mlir/include/mlir/Support/StorageUniquer.h index 5b408f3eeb17..fc20db2f65dd 100644 --- a/mlir/include/mlir/Support/StorageUniquer.h +++ b/mlir/include/mlir/Support/StorageUniquer.h @@ -35,9 +35,9 @@ struct StorageUniquerImpl; /// /// For non-parametric storage classes, i.e. those that are solely uniqued by /// their kind, nothing else is needed. Instances of these classes can be -/// queried with 'getSimple'. +/// created by calling `get` without trailing arguments. /// -/// Otherwise, the parametric storage classes may be queried with 'getComplex', +/// Otherwise, the parametric storage classes may be created with `get`, /// and must respect the following: /// - Define a type alias, KeyTy, to a type that uniquely identifies the /// instance of the storage class within its kind. diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp index 739be157469b..03dd4b81c454 100644 --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -20,18 +20,17 @@ #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/IntegerSet.h" +#include "mlir/Support/MathExtras.h" #include "mlir/Support/STLExtras.h" #include "llvm/ADT/STLExtras.h" using namespace mlir; using namespace mlir::detail; -MLIRContext *AffineExpr::getContext() const { - return expr->contextAndKind.getPointer(); -} +MLIRContext *AffineExpr::getContext() const { return expr->context; } AffineExprKind AffineExpr::getKind() const { - return expr->contextAndKind.getInt(); + return static_cast(expr->getKind()); } /// Walk all of the AffineExprs in this subgraph in postorder. @@ -51,6 +50,23 @@ void AffineExpr::walk(std::function callback) const { AffineExprWalker(callback).walkPostOrder(*this); } +// Dispatch affine expression construction based on kind. +AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, + AffineExpr rhs) { + if (kind == AffineExprKind::Add) + return lhs + rhs; + if (kind == AffineExprKind::Mul) + return lhs * rhs; + if (kind == AffineExprKind::FloorDiv) + return lhs.floorDiv(rhs); + if (kind == AffineExprKind::CeilDiv) + return lhs.ceilDiv(rhs); + if (kind == AffineExprKind::Mod) + return lhs % rhs; + + llvm_unreachable("unknown binary operation on affine expressions"); +} + /// This method substitutes any uses of dimensions and symbols (e.g. /// dim#0 with dimReplacements[0]) and returns the modified expression tree. AffineExpr @@ -231,65 +247,313 @@ unsigned AffineDimExpr::getPosition() const { return static_cast(expr)->position; } +static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position, + MLIRContext *context) { + auto assignCtx = [context](AffineDimExprStorage *storage) { + storage->context = context; + }; + + StorageUniquer &uniquer = context->getAffineUniquer(); + return uniquer.get( + assignCtx, static_cast(kind), position); +} + +AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) { + return getAffineDimOrSymbol(AffineExprKind::DimId, position, context); +} + AffineSymbolExpr::AffineSymbolExpr(AffineExpr::ImplType *ptr) : AffineExpr(ptr) {} unsigned AffineSymbolExpr::getPosition() const { return static_cast(expr)->position; } +AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) { + return getAffineDimOrSymbol(AffineExprKind::SymbolId, position, context); + ; +} + AffineConstantExpr::AffineConstantExpr(AffineExpr::ImplType *ptr) : AffineExpr(ptr) {} int64_t AffineConstantExpr::getValue() const { return static_cast(expr)->constant; } +AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) { + auto assignCtx = [context](AffineConstantExprStorage *storage) { + storage->context = context; + }; + + StorageUniquer &uniquer = context->getAffineUniquer(); + return uniquer.get( + assignCtx, static_cast(AffineExprKind::Constant), constant); +} + +/// Simplify add expression. Return nullptr if it can't be simplified. +static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) { + auto lhsConst = lhs.dyn_cast(); + auto rhsConst = rhs.dyn_cast(); + // Fold if both LHS, RHS are a constant. + if (lhsConst && rhsConst) + return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(), + lhs.getContext()); + + // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4). + // If only one of them is a symbolic expressions, make it the RHS. + if (lhs.isa() || + (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) { + return rhs + lhs; + } + + // At this point, if there was a constant, it would be on the right. + + // Addition with a zero is a noop, return the other input. + if (rhsConst) { + if (rhsConst.getValue() == 0) + return lhs; + } + // Fold successive additions like (d0 + 2) + 3 into d0 + 5. + auto lBin = lhs.dyn_cast(); + if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) { + if (auto lrhs = lBin.getRHS().dyn_cast()) + return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue()); + } + + // When doing successive additions, bring constant to the right: turn (d0 + 2) + // + d1 into (d0 + d1) + 2. + if (lBin && lBin.getKind() == AffineExprKind::Add) { + if (auto lrhs = lBin.getRHS().dyn_cast()) { + return lBin.getLHS() + rhs + lrhs; + } + } + + // Detect and transform "expr - c * (expr floordiv c)" to "expr mod c". This + // leads to a much more efficient form when 'c' is a power of two, and in + // general a more compact and readable form. + + // Process '(expr floordiv c) * (-c)'. + AffineBinaryOpExpr rBinOpExpr = rhs.dyn_cast(); + if (!rBinOpExpr) + return nullptr; + + auto lrhs = rBinOpExpr.getLHS(); + auto rrhs = rBinOpExpr.getRHS(); + + // Process lrhs, which is 'expr floordiv c'. + AffineBinaryOpExpr lrBinOpExpr = lrhs.dyn_cast(); + if (!lrBinOpExpr) + return nullptr; + + auto llrhs = lrBinOpExpr.getLHS(); + auto rlrhs = lrBinOpExpr.getRHS(); + + if (lhs == llrhs && rlrhs == -rrhs) { + return lhs % rlrhs; + } + return nullptr; +} + AffineExpr AffineExpr::operator+(int64_t v) const { - return AffineBinaryOpExprStorage::get(AffineExprKind::Add, expr, - getAffineConstantExpr(v, getContext())); + return *this + getAffineConstantExpr(v, getContext()); } AffineExpr AffineExpr::operator+(AffineExpr other) const { - return AffineBinaryOpExprStorage::get(AffineExprKind::Add, expr, other.expr); + if (auto simplified = simplifyAdd(*this, other)) + return simplified; + + StorageUniquer &uniquer = getContext()->getAffineUniquer(); + return uniquer.get( + /*initFn=*/{}, static_cast(AffineExprKind::Add), *this, other); } + +/// Simplify a multiply expression. Return nullptr if it can't be simplified. +static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) { + auto lhsConst = lhs.dyn_cast(); + auto rhsConst = rhs.dyn_cast(); + + if (lhsConst && rhsConst) + return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(), + lhs.getContext()); + + assert(lhs.isSymbolicOrConstant() || rhs.isSymbolicOrConstant()); + + // Canonicalize the mul expression so that the constant/symbolic term is the + // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a + // constant. (Note that a constant is trivially symbolic). + if (!rhs.isSymbolicOrConstant() || lhs.isa()) { + // At least one of them has to be symbolic. + return rhs * lhs; + } + + // At this point, if there was a constant, it would be on the right. + + // Multiplication with a one is a noop, return the other input. + if (rhsConst) { + if (rhsConst.getValue() == 1) + return lhs; + // Multiplication with zero. + if (rhsConst.getValue() == 0) + return rhsConst; + } + + // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6. + auto lBin = lhs.dyn_cast(); + if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) { + if (auto lrhs = lBin.getRHS().dyn_cast()) + return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue()); + } + + // When doing successive multiplication, bring constant to the right: turn (d0 + // * 2) * d1 into (d0 * d1) * 2. + if (lBin && lBin.getKind() == AffineExprKind::Mul) { + if (auto lrhs = lBin.getRHS().dyn_cast()) { + return (lBin.getLHS() * rhs) * lrhs; + } + } + + return nullptr; +} + AffineExpr AffineExpr::operator*(int64_t v) const { - return AffineBinaryOpExprStorage::get(AffineExprKind::Mul, expr, - getAffineConstantExpr(v, getContext())); + return *this * getAffineConstantExpr(v, getContext()); } AffineExpr AffineExpr::operator*(AffineExpr other) const { - return AffineBinaryOpExprStorage::get(AffineExprKind::Mul, expr, other.expr); + if (auto simplified = simplifyMul(*this, other)) + return simplified; + + StorageUniquer &uniquer = getContext()->getAffineUniquer(); + return uniquer.get( + /*initFn=*/{}, static_cast(AffineExprKind::Mul), *this, other); } + // Unary minus, delegate to operator*. AffineExpr AffineExpr::operator-() const { - return AffineBinaryOpExprStorage::get( - AffineExprKind::Mul, expr, getAffineConstantExpr(-1, getContext())); + return *this * getAffineConstantExpr(-1, getContext()); } + // Delegate to operator+. AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); } AffineExpr AffineExpr::operator-(AffineExpr other) const { return *this + (-other); } + +static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) { + auto lhsConst = lhs.dyn_cast(); + auto rhsConst = rhs.dyn_cast(); + + if (!rhsConst || rhsConst.getValue() < 1) + return nullptr; + + if (lhsConst) + return getAffineConstantExpr( + floorDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext()); + + // Fold floordiv of a multiply with a constant that is a multiple of the + // divisor. Eg: (i * 128) floordiv 64 = i * 2. + if (rhsConst.getValue() == 1) + return lhs; + + auto lBin = lhs.dyn_cast(); + if (lBin && lBin.getKind() == AffineExprKind::Mul) { + if (auto lrhs = lBin.getRHS().dyn_cast()) { + // rhsConst is known to be positive if a constant. + if (lrhs.getValue() % rhsConst.getValue() == 0) + return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue()); + } + } + + return nullptr; +} + AffineExpr AffineExpr::floorDiv(uint64_t v) const { - return AffineBinaryOpExprStorage::get(AffineExprKind::FloorDiv, expr, - getAffineConstantExpr(v, getContext())); + return floorDiv(getAffineConstantExpr(v, getContext())); } AffineExpr AffineExpr::floorDiv(AffineExpr other) const { - return AffineBinaryOpExprStorage::get(AffineExprKind::FloorDiv, expr, - other.expr); + if (auto simplified = simplifyFloorDiv(*this, other)) + return simplified; + + StorageUniquer &uniquer = getContext()->getAffineUniquer(); + return uniquer.get( + /*initFn=*/{}, static_cast(AffineExprKind::FloorDiv), *this, + other); } + +static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) { + auto lhsConst = lhs.dyn_cast(); + auto rhsConst = rhs.dyn_cast(); + + if (!rhsConst || rhsConst.getValue() < 1) + return nullptr; + + if (lhsConst) + return getAffineConstantExpr( + ceilDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext()); + + // Fold ceildiv of a multiply with a constant that is a multiple of the + // divisor. Eg: (i * 128) ceildiv 64 = i * 2. + if (rhsConst.getValue() == 1) + return lhs; + + auto lBin = lhs.dyn_cast(); + if (lBin && lBin.getKind() == AffineExprKind::Mul) { + if (auto lrhs = lBin.getRHS().dyn_cast()) { + // rhsConst is known to be positive if a constant. + if (lrhs.getValue() % rhsConst.getValue() == 0) + return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue()); + } + } + + return nullptr; +} + AffineExpr AffineExpr::ceilDiv(uint64_t v) const { - return AffineBinaryOpExprStorage::get(AffineExprKind::CeilDiv, expr, - getAffineConstantExpr(v, getContext())); + return ceilDiv(getAffineConstantExpr(v, getContext())); } AffineExpr AffineExpr::ceilDiv(AffineExpr other) const { - return AffineBinaryOpExprStorage::get(AffineExprKind::CeilDiv, expr, - other.expr); + if (auto simplified = simplifyCeilDiv(*this, other)) + return simplified; + + StorageUniquer &uniquer = getContext()->getAffineUniquer(); + return uniquer.get( + /*initFn=*/{}, static_cast(AffineExprKind::CeilDiv), *this, + other); } + +static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) { + auto lhsConst = lhs.dyn_cast(); + auto rhsConst = rhs.dyn_cast(); + + if (!rhsConst || rhsConst.getValue() < 1) + return nullptr; + + if (lhsConst) + return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()), + lhs.getContext()); + + // Fold modulo of an expression that is known to be a multiple of a constant + // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128) + // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0. + if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0) + return getAffineConstantExpr(0, lhs.getContext()); + + return nullptr; + // TODO(bondhugula): In general, this can be simplified more by using the GCD + // test, or in general using quantifier elimination (add two new variables q + // and r, and eliminate all variables from the linear system other than r. All + // of this can be done through mlir/Analysis/'s FlatAffineConstraints. +} + AffineExpr AffineExpr::operator%(uint64_t v) const { - return AffineBinaryOpExprStorage::get(AffineExprKind::Mod, expr, - getAffineConstantExpr(v, getContext())); + return *this % getAffineConstantExpr(v, getContext()); } AffineExpr AffineExpr::operator%(AffineExpr other) const { - return AffineBinaryOpExprStorage::get(AffineExprKind::Mod, expr, other.expr); + if (auto simplified = simplifyMod(*this, other)) + return simplified; + + StorageUniquer &uniquer = getContext()->getAffineUniquer(); + return uniquer.get( + /*initFn=*/{}, static_cast(AffineExprKind::Mod), *this, other); } + AffineExpr AffineExpr::compose(AffineMap map) const { SmallVector dimReplacements(map.getResults().begin(), map.getResults().end()); diff --git a/mlir/lib/IR/AffineExprDetail.h b/mlir/lib/IR/AffineExprDetail.h index bca0957bcb8f..214fee65056e 100644 --- a/mlir/lib/IR/AffineExprDetail.h +++ b/mlir/lib/IR/AffineExprDetail.h @@ -25,7 +25,7 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/MLIRContext.h" -#include "llvm/ADT/PointerIntPair.h" +#include "mlir/Support/StorageUniquer.h" namespace mlir { @@ -34,42 +34,61 @@ class MLIRContext; namespace detail { /// Base storage class appearing in an affine expression. -struct AffineExprStorage { - AffineExprStorage(AffineExprKind kind, MLIRContext *context) - : contextAndKind(context, kind) {} - llvm::PointerIntPair contextAndKind; +struct AffineExprStorage : public StorageUniquer::BaseStorage { + MLIRContext *context; }; /// A binary operation appearing in an affine expression. struct AffineBinaryOpExprStorage : public AffineExprStorage { - AffineBinaryOpExprStorage(AffineExprStorage base, AffineExpr lhs, - AffineExpr rhs) - : AffineExprStorage(base), lhs(lhs), rhs(rhs) {} - static AffineExpr get(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs); + using KeyTy = std::pair; + + bool operator==(const KeyTy &key) const { + return key.first == lhs && key.second == rhs; + } + + static AffineBinaryOpExprStorage * + construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) { + auto *result = allocator.allocate(); + result->lhs = key.first; + result->rhs = key.second; + result->context = result->lhs.getContext(); + return result; + } + AffineExpr lhs; AffineExpr rhs; }; -/// A dimensional identifier appearing in an affine expression. +/// A dimensional or symbolic identifier appearing in an affine expression. struct AffineDimExprStorage : public AffineExprStorage { - AffineDimExprStorage(AffineExprStorage base, unsigned position) - : AffineExprStorage(base), position(position) {} - /// Position of this identifier in the argument list. - unsigned position; -}; + using KeyTy = unsigned; -/// A symbolic identifier appearing in an affine expression. -struct AffineSymbolExprStorage : public AffineExprStorage { - AffineSymbolExprStorage(AffineExprStorage base, unsigned position) - : AffineExprStorage(base), position(position) {} - /// Position of this identifier in the symbol list. + bool operator==(const KeyTy &key) const { return position == key; } + + static AffineDimExprStorage * + construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) { + auto *result = allocator.allocate(); + result->position = key; + return result; + } + + /// Position of this identifier in the argument list. unsigned position; }; /// An integer constant appearing in affine expression. struct AffineConstantExprStorage : public AffineExprStorage { - AffineConstantExprStorage(AffineExprStorage base, int64_t constant) - : AffineExprStorage(base), constant(constant) {} + using KeyTy = int64_t; + + bool operator==(const KeyTy &key) const { return constant == key; } + + static AffineConstantExprStorage * + construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) { + auto *result = allocator.allocate(); + result->constant = key; + return result; + } + // The constant. int64_t constant; }; diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index ab0595454138..b54e9565e354 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -33,7 +33,6 @@ #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Location.h" #include "mlir/IR/Types.h" -#include "mlir/Support/MathExtras.h" #include "mlir/Support/STLExtras.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SetVector.h" @@ -352,17 +351,8 @@ public: using IntegerSets = DenseSet; IntegerSets integerSets; - // Affine binary op expression uniquing. Figure out uniquing of dimensional - // or symbolic identifiers. - DenseMap, AffineExpr> - affineExprs; - - // Uniqui'ing of AffineDimExpr, AffineSymbolExpr's by their position. - std::vector dimExprs; - std::vector symbolExprs; - - // Uniqui'ing of AffineConstantExprStorage using constant value as key. - DenseMap constExprs; + // Affine expression uniqui'ing. + StorageUniquer affineUniquer; //===--------------------------------------------------------------------===// // SDBM uniquing @@ -918,9 +908,13 @@ AttributeListStorage::get(ArrayRef attrs) { } //===----------------------------------------------------------------------===// -// AffineMap and AffineExpr uniquing +// AffineMap uniquing //===----------------------------------------------------------------------===// +StorageUniquer &MLIRContext::getAffineUniquer() { + return getImpl().affineUniquer; +} + AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, ArrayRef results, ArrayRef rangeSizes) { @@ -947,300 +941,6 @@ AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, }); } -/// Simplify add expression. Return nullptr if it can't be simplified. -static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) { - auto lhsConst = lhs.dyn_cast(); - auto rhsConst = rhs.dyn_cast(); - // Fold if both LHS, RHS are a constant. - if (lhsConst && rhsConst) - return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(), - lhs.getContext()); - - // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4). - // If only one of them is a symbolic expressions, make it the RHS. - if (lhs.isa() || - (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) { - return rhs + lhs; - } - - // At this point, if there was a constant, it would be on the right. - - // Addition with a zero is a noop, return the other input. - if (rhsConst) { - if (rhsConst.getValue() == 0) - return lhs; - } - // Fold successive additions like (d0 + 2) + 3 into d0 + 5. - auto lBin = lhs.dyn_cast(); - if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) { - if (auto lrhs = lBin.getRHS().dyn_cast()) - return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue()); - } - - // When doing successive additions, bring constant to the right: turn (d0 + 2) - // + d1 into (d0 + d1) + 2. - if (lBin && lBin.getKind() == AffineExprKind::Add) { - if (auto lrhs = lBin.getRHS().dyn_cast()) { - return lBin.getLHS() + rhs + lrhs; - } - } - - // Detect and transform "expr - c * (expr floordiv c)" to "expr mod c". This - // leads to a much more efficient form when 'c' is a power of two, and in - // general a more compact and readable form. - - // Process '(expr floordiv c) * (-c)'. - AffineBinaryOpExpr rBinOpExpr = rhs.dyn_cast(); - if (!rBinOpExpr) - return nullptr; - - auto lrhs = rBinOpExpr.getLHS(); - auto rrhs = rBinOpExpr.getRHS(); - - // Process lrhs, which is 'expr floordiv c'. - AffineBinaryOpExpr lrBinOpExpr = lrhs.dyn_cast(); - if (!lrBinOpExpr) - return nullptr; - - auto llrhs = lrBinOpExpr.getLHS(); - auto rlrhs = lrBinOpExpr.getRHS(); - - if (lhs == llrhs && rlrhs == -rrhs) { - return lhs % rlrhs; - } - return nullptr; -} - -/// Simplify a multiply expression. Return nullptr if it can't be simplified. -static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) { - auto lhsConst = lhs.dyn_cast(); - auto rhsConst = rhs.dyn_cast(); - - if (lhsConst && rhsConst) - return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(), - lhs.getContext()); - - assert(lhs.isSymbolicOrConstant() || rhs.isSymbolicOrConstant()); - - // Canonicalize the mul expression so that the constant/symbolic term is the - // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a - // constant. (Note that a constant is trivially symbolic). - if (!rhs.isSymbolicOrConstant() || lhs.isa()) { - // At least one of them has to be symbolic. - return rhs * lhs; - } - - // At this point, if there was a constant, it would be on the right. - - // Multiplication with a one is a noop, return the other input. - if (rhsConst) { - if (rhsConst.getValue() == 1) - return lhs; - // Multiplication with zero. - if (rhsConst.getValue() == 0) - return rhsConst; - } - - // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6. - auto lBin = lhs.dyn_cast(); - if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) { - if (auto lrhs = lBin.getRHS().dyn_cast()) - return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue()); - } - - // When doing successive multiplication, bring constant to the right: turn (d0 - // * 2) * d1 into (d0 * d1) * 2. - if (lBin && lBin.getKind() == AffineExprKind::Mul) { - if (auto lrhs = lBin.getRHS().dyn_cast()) { - return (lBin.getLHS() * rhs) * lrhs; - } - } - - return nullptr; -} - -static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) { - auto lhsConst = lhs.dyn_cast(); - auto rhsConst = rhs.dyn_cast(); - - if (!rhsConst || rhsConst.getValue() < 1) - return nullptr; - - if (lhsConst) - return getAffineConstantExpr( - floorDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext()); - - // Fold floordiv of a multiply with a constant that is a multiple of the - // divisor. Eg: (i * 128) floordiv 64 = i * 2. - if (rhsConst.getValue() == 1) - return lhs; - - auto lBin = lhs.dyn_cast(); - if (lBin && lBin.getKind() == AffineExprKind::Mul) { - if (auto lrhs = lBin.getRHS().dyn_cast()) { - // rhsConst is known to be positive if a constant. - if (lrhs.getValue() % rhsConst.getValue() == 0) - return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue()); - } - } - - return nullptr; -} - -static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) { - auto lhsConst = lhs.dyn_cast(); - auto rhsConst = rhs.dyn_cast(); - - if (!rhsConst || rhsConst.getValue() < 1) - return nullptr; - - if (lhsConst) - return getAffineConstantExpr( - ceilDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext()); - - // Fold ceildiv of a multiply with a constant that is a multiple of the - // divisor. Eg: (i * 128) ceildiv 64 = i * 2. - if (rhsConst.getValue() == 1) - return lhs; - - auto lBin = lhs.dyn_cast(); - if (lBin && lBin.getKind() == AffineExprKind::Mul) { - if (auto lrhs = lBin.getRHS().dyn_cast()) { - // rhsConst is known to be positive if a constant. - if (lrhs.getValue() % rhsConst.getValue() == 0) - return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue()); - } - } - - return nullptr; -} - -static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) { - auto lhsConst = lhs.dyn_cast(); - auto rhsConst = rhs.dyn_cast(); - - if (!rhsConst || rhsConst.getValue() < 1) - return nullptr; - - if (lhsConst) - return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()), - lhs.getContext()); - - // Fold modulo of an expression that is known to be a multiple of a constant - // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128) - // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0. - if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0) - return getAffineConstantExpr(0, lhs.getContext()); - - return nullptr; - // TODO(bondhugula): In general, this can be simplified more by using the GCD - // test, or in general using quantifier elimination (add two new variables q - // and r, and eliminate all variables from the linear system other than r. All - // of this can be done through mlir/Analysis/'s FlatAffineConstraints. -} - -/// Return a binary affine op expression with the specified op type and -/// operands: if it doesn't exist, create it and store it; if it is already -/// present, return from the list. The stored expressions are unique: they are -/// constructed and stored in a simplified/canonicalized form. The result after -/// simplification could be any form of affine expression. -AffineExpr AffineBinaryOpExprStorage::get(AffineExprKind kind, AffineExpr lhs, - AffineExpr rhs) { - auto &impl = lhs.getContext()->getImpl(); - - // Check if we already have this affine expression, and return it if we do. - auto keyValue = std::make_tuple((unsigned)kind, lhs, rhs); - - { // Check for an existing instance in read-only mode. - llvm::sys::SmartScopedReader affineLock(impl.affineMutex); - auto cached = impl.affineExprs.find(keyValue); - if (cached != impl.affineExprs.end()) - return cached->second; - } - - // Simplify the expression if possible. - AffineExpr simplified; - switch (kind) { - case AffineExprKind::Add: - simplified = simplifyAdd(lhs, rhs); - break; - case AffineExprKind::Mul: - simplified = simplifyMul(lhs, rhs); - break; - case AffineExprKind::FloorDiv: - simplified = simplifyFloorDiv(lhs, rhs); - break; - case AffineExprKind::CeilDiv: - simplified = simplifyCeilDiv(lhs, rhs); - break; - case AffineExprKind::Mod: - simplified = simplifyMod(lhs, rhs); - break; - default: - llvm_unreachable("unexpected binary affine expr"); - } - - // The simplified one would have already been cached; just return it. - if (simplified) - return simplified; - - // Aquire a writer-lock so that we can safely create the new instance. - llvm::sys::SmartScopedWriter affineLock(impl.affineMutex); - - // Check for an existing instance again here, because another writer thread - // may have already created one. - auto &result = impl.affineExprs.insert({keyValue, nullptr}).first->second; - if (!result) { - // An expression with these operands will already be in the - // simplified/canonical form. Create and store it. - result = new (impl.affineAllocator.Allocate()) - AffineBinaryOpExprStorage{{kind, lhs.getContext()}, lhs, rhs}; - } - return result; -} - -AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, - AffineExpr rhs) { - return AffineBinaryOpExprStorage::get(kind, lhs, rhs); -} - -AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) { - auto &impl = context->getImpl(); - - return safeGetOrCreate( - impl.dimExprs, position, impl.affineMutex, [&impl, context, position] { - auto *result = impl.affineAllocator.Allocate(); - // Initialize the memory using placement new. - new (result) - AffineDimExprStorage{{AffineExprKind::DimId, context}, position}; - return result; - }); -} - -AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) { - auto &impl = context->getImpl(); - - return safeGetOrCreate( - impl.symbolExprs, position, impl.affineMutex, [&impl, context, position] { - auto *result = impl.affineAllocator.Allocate(); - // Initialize the memory using placement new. - new (result) AffineSymbolExprStorage{ - {AffineExprKind::SymbolId, context}, position}; - return result; - }); -} - -AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) { - auto &impl = context->getImpl(); - - // Safely get or create an AffineConstantExpr instance. - return safeGetOrCreate(impl.constExprs, constant, impl.affineMutex, [&] { - auto *result = impl.affineAllocator.Allocate(); - return new (result) AffineConstantExprStorage{ - {AffineExprKind::Constant, context}, constant}; - }); -} - //===----------------------------------------------------------------------===// // Integer Sets: these are allocated into the bump pointer, and are immutable. // Unlike AffineMap's, these are uniqued only if they are small. diff --git a/mlir/unittests/IR/SDBMTest.cpp b/mlir/unittests/IR/SDBMTest.cpp index 307bb3869d43..26765134d089 100644 --- a/mlir/unittests/IR/SDBMTest.cpp +++ b/mlir/unittests/IR/SDBMTest.cpp @@ -321,8 +321,8 @@ TEST(SDBMExpr, MatchStripeMulPattern) { // pattern (x floordiv B) * B == x # B. auto cst = getAffineConstantExpr(42, ctx()); auto dim = getAffineDimExpr(0, ctx()); - auto floor = getAffineBinaryOpExpr(AffineExprKind::FloorDiv, dim, cst); - auto mul = getAffineBinaryOpExpr(AffineExprKind::Mul, cst, floor); + auto floor = dim.floorDiv(cst); + auto mul = cst * floor; Optional converted = SDBMStripeExpr::tryConvertAffineExpr(mul); ASSERT_TRUE(converted.hasValue()); EXPECT_TRUE(converted->isa()); @@ -331,10 +331,10 @@ TEST(SDBMExpr, MatchStripeMulPattern) { TEST(SDBMExpr, NonSDBM) { auto d0 = getAffineDimExpr(0, ctx()); auto d1 = getAffineDimExpr(1, ctx()); - auto sum = getAffineBinaryOpExpr(AffineExprKind::Add, d0, d1); + auto sum = d0 + d1; auto c2 = getAffineConstantExpr(2, ctx()); - auto prod = getAffineBinaryOpExpr(AffineExprKind::Mul, d0, c2); - auto ceildiv = getAffineBinaryOpExpr(AffineExprKind::CeilDiv, d1, c2); + auto prod = d0 * c2; + auto ceildiv = d1.ceilDiv(c2); // The following are not valid SDBM expressions: // - a sum of two variables