diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h index 6ebb0aef2a34..92bae8a3115e 100644 --- a/mlir/include/mlir/IR/AffineExpr.h +++ b/mlir/include/mlir/IR/AffineExpr.h @@ -35,10 +35,13 @@ class AffineExpr { public: enum class Kind { Add, - Sub, + // RHS of mul is always a constant or a symbolic expression. Mul, + // RHS of mod is always a constant or a symbolic expression. Mod, + // RHS of floordiv is always a constant or a symbolic expression. FloorDiv, + // RHS of ceildiv is always a constant or a symbolic expression. CeilDiv, /// This is a marker for the last affine binary op. The range of binary @@ -83,9 +86,17 @@ inline raw_ostream &operator<<(raw_ostream &os, const AffineExpr &expr) { return os; } -/// Binary affine expression. +/// Affine binary operation expression. An affine binary operation could be an +/// add, mul, floordiv, ceildiv, or a modulo operation. (Subtraction is +/// represented through a multiply by -1 and add.) These expressions are always +/// constructed in a simplified form. For eg., the LHS and RHS operands can't +/// both be constants. There are additional canonicalizing rules depending on +/// the op type: see checks in the constructor. class AffineBinaryOpExpr : public AffineExpr { public: + static AffineExpr *get(Kind kind, AffineExpr *lhs, AffineExpr *rhs, + MLIRContext *context); + AffineExpr *getLHS() const { return lhs; } AffineExpr *getRHS() const { return rhs; } @@ -94,10 +105,9 @@ public: return expr->getKind() <= Kind::LAST_AFFINE_BINARY_OP; } -protected: - static AffineExpr *get(Kind kind, AffineExpr *lhs, AffineExpr *rhs, - MLIRContext *context); + void print(raw_ostream &os) const; +protected: explicit AffineBinaryOpExpr(Kind kind, AffineExpr *lhs, AffineExpr *rhs); AffineExpr *const lhs; @@ -107,8 +117,6 @@ private: // Simplification prior to construction of binary affine op expressions. static AffineExpr *simplifyAdd(AffineExpr *lhs, AffineExpr *rhs, MLIRContext *context); - static AffineExpr *simplifySub(AffineExpr *lhs, AffineExpr *rhs, - MLIRContext *context); static AffineExpr *simplifyMul(AffineExpr *lhs, AffineExpr *rhs, MLIRContext *context); static AffineExpr *simplifyFloorDiv(AffineExpr *lhs, AffineExpr *rhs, @@ -119,102 +127,6 @@ private: MLIRContext *context); }; -/// Binary affine add expression. -class AffineAddExpr : public AffineBinaryOpExpr { -public: - static AffineExpr *get(AffineExpr *lhs, AffineExpr *rhs, - MLIRContext *context); - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(const AffineExpr *expr) { - return expr->getKind() == Kind::Add; - } - void print(raw_ostream &os) const; - -private: - // No constructor; use AffineBinaryOpExpr::get - AffineAddExpr(AffineExpr *lhs, AffineExpr *rhs) = delete; -}; - -/// Binary affine subtract expression. -class AffineSubExpr : public AffineBinaryOpExpr { -public: - static AffineExpr *get(AffineExpr *lhs, AffineExpr *rhs, - MLIRContext *context); - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(const AffineExpr *expr) { - return expr->getKind() == Kind::Sub; - } - void print(raw_ostream &os) const; - -private: - AffineSubExpr(AffineExpr *lhs, AffineExpr *rhs) = delete; -}; - -/// Binary affine multiplication expression. -class AffineMulExpr : public AffineBinaryOpExpr { -public: - static AffineExpr *get(AffineExpr *lhs, AffineExpr *rhs, - MLIRContext *context); - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(const AffineExpr *expr) { - return expr->getKind() == Kind::Mul; - } - void print(raw_ostream &os) const; - -private: - AffineMulExpr(AffineExpr *lhs, AffineExpr *rhs) = delete; -}; - -/// Binary affine modulo operation expression. -class AffineModExpr : public AffineBinaryOpExpr { -public: - static AffineExpr *get(AffineExpr *lhs, AffineExpr *rhs, - MLIRContext *context); - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(const AffineExpr *expr) { - return expr->getKind() == Kind::Mod; - } - void print(raw_ostream &os) const; - -private: - AffineModExpr(AffineExpr *lhs, AffineExpr *rhs) = delete; -}; - -/// Binary affine floordiv expression. -class AffineFloorDivExpr : public AffineBinaryOpExpr { -public: - static AffineExpr *get(AffineExpr *lhs, AffineExpr *rhs, - MLIRContext *context); - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(const AffineExpr *expr) { - return expr->getKind() == Kind::FloorDiv; - } - void print(raw_ostream &os) const; - -private: - AffineFloorDivExpr(AffineExpr *lhs, AffineExpr *rhs) = delete; -}; - -/// Binary affine ceildiv expression. -class AffineCeilDivExpr : public AffineBinaryOpExpr { -public: - static AffineExpr *get(AffineExpr *lhs, AffineExpr *rhs, - MLIRContext *context); - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(const AffineExpr *expr) { - return expr->getKind() == Kind::CeilDiv; - } - void print(raw_ostream &os) const; - -private: - AffineCeilDivExpr(AffineExpr *lhs, AffineExpr *rhs) = delete; -}; /// A dimensional identifier appearing in an affine expression. /// diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp index 8d0ee3dfca5c..6bfbaf538b29 100644 --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -28,27 +28,19 @@ AffineBinaryOpExpr::AffineBinaryOpExpr(Kind kind, AffineExpr *lhs, switch (kind) { case Kind::Add: assert(!isa(lhs)); - // TODO (more verification) - break; - case Kind::Sub: - // TODO (verification) break; case Kind::Mul: assert(!isa(lhs)); assert(rhs->isSymbolicOrConstant()); - // TODO (more verification) break; case Kind::FloorDiv: assert(rhs->isSymbolicOrConstant()); - // TODO (more verification) break; case Kind::CeilDiv: assert(rhs->isSymbolicOrConstant()); - // TODO (more verification) break; case Kind::Mod: assert(rhs->isSymbolicOrConstant()); - // TODO (more verification) break; default: llvm_unreachable("unexpected binary affine expr"); @@ -67,7 +59,6 @@ bool AffineExpr::isSymbolicOrConstant() const { return true; case Kind::Add: - case Kind::Sub: case Kind::Mul: case Kind::FloorDiv: case Kind::CeilDiv: @@ -87,16 +78,15 @@ bool AffineExpr::isPureAffine() const { case Kind::DimId: case Kind::Constant: return true; - case Kind::Add: - case Kind::Sub: { - auto op = cast(this); + case Kind::Add: { + auto *op = cast(this); return op->getLHS()->isPureAffine() && op->getRHS()->isPureAffine(); } case Kind::Mul: { // TODO: Canonicalize the constants in binary operators to the RHS when // possible, allowing this to merge into the next case. - auto op = cast(this); + auto *op = cast(this); return op->getLHS()->isPureAffine() && op->getRHS()->isPureAffine() && (isa(op->getLHS()) || isa(op->getRHS())); @@ -104,7 +94,7 @@ bool AffineExpr::isPureAffine() const { case Kind::FloorDiv: case Kind::CeilDiv: case Kind::Mod: { - auto op = cast(this); + auto *op = cast(this); return op->getLHS()->isPureAffine() && isa(op->getRHS()); } diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp index d8e09b5b2296..8972510af192 100644 --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -38,7 +38,7 @@ AffineExpr *AffineBinaryOpExpr::simplifyAdd(AffineExpr *lhs, AffineExpr *rhs, if (isa(lhs) || (lhs->isSymbolicOrConstant() && !rhs->isSymbolicOrConstant())) - return AffineAddExpr::get(rhs, lhs, context); + return AffineBinaryOpExpr::get(Kind::Add, rhs, lhs, context); return nullptr; // TODO(someone): implement more simplification like x + 0 -> x; (x + 2) + 4 @@ -46,16 +46,6 @@ AffineExpr *AffineBinaryOpExpr::simplifyAdd(AffineExpr *lhs, AffineExpr *rhs, // simplifications as opposed to incremental hacks. } -AffineExpr *AffineBinaryOpExpr::simplifySub(AffineExpr *lhs, AffineExpr *rhs, - MLIRContext *context) { - if (auto *l = dyn_cast(lhs)) - if (auto *r = dyn_cast(rhs)) - return AffineConstantExpr::get(l->getValue() - r->getValue(), context); - - return nullptr; - // TODO(someone): implement more simplification like mentioned for add. -} - /// Simplify a multiply expression. Fold it to a constant when possible, and /// make the symbolic/constant operand the RHS. AffineExpr *AffineBinaryOpExpr::simplifyMul(AffineExpr *lhs, AffineExpr *rhs, @@ -71,7 +61,7 @@ AffineExpr *AffineBinaryOpExpr::simplifyMul(AffineExpr *lhs, AffineExpr *rhs, // constant. (Note that a constant is trivially symbolic). if (!rhs->isSymbolicOrConstant() || isa(lhs)) { // At least one of them has to be symbolic. - return AffineMulExpr::get(rhs, lhs, context); + return AffineBinaryOpExpr::get(Kind::Mul, rhs, lhs, context); } return nullptr; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 570ae4923e51..03d9b1d31bc5 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -661,38 +661,62 @@ void AffineExpr::dump() const { llvm::errs() << "\n"; } -void AffineAddExpr::print(raw_ostream &os) const { - os << "(" << *getLHS() << " + " << *getRHS() << ")"; -} - -void AffineSubExpr::print(raw_ostream &os) const { - os << "(" << *getLHS() << " - " << *getRHS() << ")"; -} - -void AffineMulExpr::print(raw_ostream &os) const { - os << "(" << *getLHS() << " * " << *getRHS() << ")"; -} - -void AffineModExpr::print(raw_ostream &os) const { - os << "(" << *getLHS() << " mod " << *getRHS() << ")"; -} - -void AffineFloorDivExpr::print(raw_ostream &os) const { - os << "(" << *getLHS() << " floordiv " << *getRHS() << ")"; -} - -void AffineCeilDivExpr::print(raw_ostream &os) const { - os << "(" << *getLHS() << " ceildiv " << *getRHS() << ")"; -} - void AffineSymbolExpr::print(raw_ostream &os) const { - os << "s" << getPosition(); + os << 's' << getPosition(); } -void AffineDimExpr::print(raw_ostream &os) const { os << "d" << getPosition(); } +void AffineDimExpr::print(raw_ostream &os) const { os << 'd' << getPosition(); } void AffineConstantExpr::print(raw_ostream &os) const { os << getValue(); } +static void printAdd(const AffineBinaryOpExpr *addExpr, raw_ostream &os) { + os << '(' << *addExpr->getLHS(); + + // Pretty print addition to a product that has a negative operand as a + // subtraction. + if (auto *rhs = dyn_cast(addExpr->getRHS())) { + if (rhs->getKind() == AffineExpr::Kind::Mul) { + if (auto *rrhs = dyn_cast(rhs->getRHS())) { + if (rrhs->getValue() < 0) { + os << " - (" << *rhs->getLHS() << " * " << -rrhs->getValue() << "))"; + return; + } + } + } + } + + // Pretty print addition to a negative number as a subtraction. + if (auto *rhs = dyn_cast(addExpr->getRHS())) { + if (rhs->getValue() < 0) { + os << " - " << -rhs->getValue() << ")"; + return; + } + } + + os << " + " << *addExpr->getRHS() << ")"; +} + +void AffineBinaryOpExpr::print(raw_ostream &os) const { + switch (getKind()) { + case Kind::Add: + return printAdd(this, os); + case Kind::Mul: + os << "(" << *getLHS() << " * " << *getRHS() << ")"; + return; + case Kind::FloorDiv: + os << "(" << *getLHS() << " floordiv " << *getRHS() << ")"; + return; + case Kind::CeilDiv: + os << "(" << *getLHS() << " ceildiv " << *getRHS() << ")"; + return; + case Kind::Mod: + os << "(" << *getLHS() << " mod " << *getRHS() << ")"; + return; + default: + llvm_unreachable("unexpected affine binary op expression"); + } +} + void AffineExpr::print(raw_ostream &os) const { switch (getKind()) { case Kind::SymbolId: @@ -702,17 +726,11 @@ void AffineExpr::print(raw_ostream &os) const { case Kind::Constant: return cast(this)->print(os); case Kind::Add: - return cast(this)->print(os); - case Kind::Sub: - return cast(this)->print(os); case Kind::Mul: - return cast(this)->print(os); case Kind::FloorDiv: - return cast(this)->print(os); case Kind::CeilDiv: - return cast(this)->print(os); case Kind::Mod: - return cast(this)->print(os); + return cast(this)->print(os); } } diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 3d7e023c0ac6..dc5b8e2433d7 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -121,27 +121,23 @@ AffineConstantExpr *Builder::getConstantExpr(int64_t constant) { } AffineExpr *Builder::getAddExpr(AffineExpr *lhs, AffineExpr *rhs) { - return AffineAddExpr::get(lhs, rhs, context); -} - -AffineExpr *Builder::getSubExpr(AffineExpr *lhs, AffineExpr *rhs) { - return AffineSubExpr::get(lhs, rhs, context); + return AffineBinaryOpExpr::get(AffineExpr::Kind::Add, lhs, rhs, context); } AffineExpr *Builder::getMulExpr(AffineExpr *lhs, AffineExpr *rhs) { - return AffineMulExpr::get(lhs, rhs, context); + return AffineBinaryOpExpr::get(AffineExpr::Kind::Mul, lhs, rhs, context); } AffineExpr *Builder::getModExpr(AffineExpr *lhs, AffineExpr *rhs) { - return AffineModExpr::get(lhs, rhs, context); + return AffineBinaryOpExpr::get(AffineExpr::Kind::Mod, lhs, rhs, context); } AffineExpr *Builder::getFloorDivExpr(AffineExpr *lhs, AffineExpr *rhs) { - return AffineFloorDivExpr::get(lhs, rhs, context); + return AffineBinaryOpExpr::get(AffineExpr::Kind::FloorDiv, lhs, rhs, context); } AffineExpr *Builder::getCeilDivExpr(AffineExpr *lhs, AffineExpr *rhs) { - return AffineCeilDivExpr::get(lhs, rhs, context); + return AffineBinaryOpExpr::get(AffineExpr::Kind::CeilDiv, lhs, rhs, context); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index df3d01aad8d3..8d2de1097613 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -684,9 +684,6 @@ AffineExpr *AffineBinaryOpExpr::get(AffineExpr::Kind kind, AffineExpr *lhs, case Kind::Add: simplified = AffineBinaryOpExpr::simplifyAdd(lhs, rhs, context); break; - case Kind::Sub: - simplified = AffineBinaryOpExpr::simplifySub(lhs, rhs, context); - break; case Kind::Mul: simplified = AffineBinaryOpExpr::simplifyMul(lhs, rhs, context); break; @@ -720,36 +717,6 @@ AffineExpr *AffineBinaryOpExpr::get(AffineExpr::Kind kind, AffineExpr *lhs, return result; } -AffineExpr *AffineAddExpr::get(AffineExpr *lhs, AffineExpr *rhs, - MLIRContext *context) { - return AffineBinaryOpExpr::get(Kind::Add, lhs, rhs, context); -} - -AffineExpr *AffineSubExpr::get(AffineExpr *lhs, AffineExpr *rhs, - MLIRContext *context) { - return AffineBinaryOpExpr::get(Kind::Sub, lhs, rhs, context); -} - -AffineExpr *AffineMulExpr::get(AffineExpr *lhs, AffineExpr *rhs, - MLIRContext *context) { - return AffineBinaryOpExpr::get(Kind::Mul, lhs, rhs, context); -} - -AffineExpr *AffineFloorDivExpr::get(AffineExpr *lhs, AffineExpr *rhs, - MLIRContext *context) { - return AffineBinaryOpExpr::get(Kind::FloorDiv, lhs, rhs, context); -} - -AffineExpr *AffineCeilDivExpr::get(AffineExpr *lhs, AffineExpr *rhs, - MLIRContext *context) { - return AffineBinaryOpExpr::get(Kind::CeilDiv, lhs, rhs, context); -} - -AffineExpr *AffineModExpr::get(AffineExpr *lhs, AffineExpr *rhs, - MLIRContext *context) { - return AffineBinaryOpExpr::get(Kind::Mod, lhs, rhs, context); -} - AffineDimExpr *AffineDimExpr::get(unsigned position, MLIRContext *context) { // TODO(bondhugula): complete this // FIXME: this should be POD diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 98fd7162d2d5..7b7ee89f28c6 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -758,7 +758,8 @@ AffineExpr *AffineMapParser::getBinaryAffineOpExpr(AffineLowPrecOp op, case AffineLowPrecOp::Add: return builder.getAddExpr(lhs, rhs); case AffineLowPrecOp::Sub: - return builder.getSubExpr(lhs, rhs); + return builder.getAddExpr( + lhs, builder.getMulExpr(rhs, builder.getConstantExpr(-1))); case AffineLowPrecOp::LNoOp: llvm_unreachable("can't create affine expression for null low prec op"); return nullptr; diff --git a/mlir/test/IR/parser-affine-map.mlir b/mlir/test/IR/parser-affine-map.mlir index 50f2bd731419..030b86dae9ce 100644 --- a/mlir/test/IR/parser-affine-map.mlir +++ b/mlir/test/IR/parser-affine-map.mlir @@ -81,10 +81,10 @@ // CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, (d1 ceildiv 5)) #map26 = (i, j) [s0, s1] -> (i, j ceildiv 5) -// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, ((d0 - d1) - 5)) +// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, ((d0 - (d1 * 1)) - 5)) #map29 = (i, j) [s0, s1] -> (i, i - j - 5) -// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, ((d0 - (d1 * s1)) + 2)) +// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, ((d0 - ((d1 * s1) * 1)) + 2)) #map30 = (i, j) [M, N] -> (i, i - N*j + 2) // CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> ((d0 * -5), (d1 * -3), -2, ((d0 + d1) * -1), (s0 * -1)) @@ -238,4 +238,4 @@ extfunc @f40(memref<2x4xi8, #map40, 1>) extfunc @f41(memref<2x4xi8, #map41, 1>) // CHECK: extfunc @f42(memref<2x4xi8, #map{{[0-9]+}}, 1>) -extfunc @f42(memref<2x4xi8, #map42, 1>) \ No newline at end of file +extfunc @f42(memref<2x4xi8, #map42, 1>)