diff --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h index 3330f48f2eff..95e54141684d 100644 --- a/mlir/include/mlir/Analysis/AffineAnalysis.h +++ b/mlir/include/mlir/Analysis/AffineAnalysis.h @@ -16,8 +16,8 @@ // ============================================================================= // // This header file defines prototypes for methods that perform analysis -// involving affine structures (AffineExpr, AffineMap, IntegerSet, etc.) and -// other IR structures that in turn use these. +// involving affine structures (AffineExprClass, AffineMap, IntegerSet, etc.) +// and other IR structures that in turn use these. // //===----------------------------------------------------------------------===// @@ -31,11 +31,11 @@ namespace mlir { namespace detail { -class AffineExpr; +class AffineExprClass; } // namespace detail -template class AffineExprBaseRef; -using AffineExprRef = AffineExprBaseRef; +template class AffineExprBase; +using AffineExpr = AffineExprBase; class MLIRContext; class MLValue; class OperationStmt; @@ -44,8 +44,8 @@ class OperationStmt; /// simple analysis. This has complexity linear in the number of nodes in /// 'expr'. Returns the simplified expression, which is the same as the input // expression if it can't be simplified. -AffineExprRef simplifyAffineExpr(AffineExprRef expr, unsigned numDims, - unsigned numSymbols); +AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, + unsigned numSymbols); /// Returns the sequence of AffineApplyOp OperationStmts operation in /// 'affineApplyOps', which are reachable via a search starting from 'operands', diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index f112600e5328..8f5a8aa72a0d 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -44,8 +44,8 @@ struct MutableAffineMap { public: MutableAffineMap(AffineMap *map, MLIRContext *context); - AffineExprRef getResult(unsigned idx) const { return results[idx]; } - void setResult(unsigned idx, AffineExprRef result) { results[idx] = result; } + AffineExpr getResult(unsigned idx) const { return results[idx]; } + void setResult(unsigned idx, AffineExpr result) { results[idx] = result; } unsigned getNumResults() const { return results.size(); } unsigned getNumDims() const { return numDims; } void setNumDims(unsigned d) { numDims = d; } @@ -66,11 +66,12 @@ public: private: // Same meaning as AffineMap's fields. - SmallVector results; - SmallVector rangeSizes; + SmallVector results; + SmallVector rangeSizes; unsigned numDims; unsigned numSymbols; - /// A pointer to the IR's context to store all newly created AffineExpr's. + /// A pointer to the IR's context to store all newly created + /// AffineExprClass's. MLIRContext *context; }; @@ -96,9 +97,10 @@ private: unsigned numDims; unsigned numSymbols; - SmallVector constraints; + SmallVector constraints; SmallVector eqFlags; - /// A pointer to the IR's context to store all newly created AffineExpr's. + /// A pointer to the IR's context to store all newly created + /// AffineExprClass's. MLIRContext *context; }; @@ -283,7 +285,7 @@ public: return ArrayRef(&inequalities[idx * getNumCols()], getNumCols()); } - AffineExprRef toAffineExpr(unsigned idx, MLIRContext *context); + AffineExpr toAffineExpr(unsigned idx, MLIRContext *context); void addInequality(ArrayRef inEq); void addEquality(ArrayRef eq); diff --git a/mlir/include/mlir/Analysis/HyperRectangularSet.h b/mlir/include/mlir/Analysis/HyperRectangularSet.h index e1f2599359d8..ad2b2560dc4f 100644 --- a/mlir/include/mlir/Analysis/HyperRectangularSet.h +++ b/mlir/include/mlir/Analysis/HyperRectangularSet.h @@ -45,7 +45,7 @@ class HyperRectangleList; /// A list of affine bounds. // Not using a MutableAffineMap here since numSymbols is the same as the // containing HyperRectangularSet's numSymbols, and its numDims is 0. -typedef SmallVector AffineBoundExprList; +typedef SmallVector AffineBoundExprList; /// A HyperRectangularSet is a symbolic set of integer points contained in a /// hyper-rectangular space. It supports set manipulation operations @@ -93,9 +93,8 @@ public: getFromFlatAffineConstraints(const FlatAffineConstraints &cst); HyperRectangularSet(unsigned numDims, unsigned numSymbols, - ArrayRef> lbs, - ArrayRef> ubs, - MLIRContext *context, + ArrayRef> lbs, + ArrayRef> ubs, MLIRContext *context, IntegerSet *symbolContext = nullptr); unsigned getNumDims() const { return numDims; } @@ -128,10 +127,10 @@ public: bool empty() const; /// Add a lower bound expression to dimension position 'idx'. - void addLowerBoundExpr(unsigned idx, AffineExprRef expr); + void addLowerBoundExpr(unsigned idx, AffineExpr expr); /// Add an upper bound expression to dimension position 'idx'. - void addUpperBoundExpr(unsigned idx, AffineExprRef expr); + void addUpperBoundExpr(unsigned idx, AffineExpr expr); /// Clear this set's context, i.e., make it the universal set. void clearContext() { context.clear(); } diff --git a/mlir/include/mlir/Analysis/LoopAnalysis.h b/mlir/include/mlir/Analysis/LoopAnalysis.h index 04f826ce5a07..85ca5996ab51 100644 --- a/mlir/include/mlir/Analysis/LoopAnalysis.h +++ b/mlir/include/mlir/Analysis/LoopAnalysis.h @@ -28,17 +28,17 @@ namespace mlir { namespace detail { -class AffineExpr; +class AffineExprClass; } // namespace detail -template class AffineExprBaseRef; -using AffineExprRef = AffineExprBaseRef; +template class AffineExprBase; +using AffineExpr = AffineExprBase; class ForStmt; /// Returns the trip count of the loop as an affine expression if the latter is /// expressible as an affine expression, and nullptr otherwise. The trip count /// expression is simplified before returning. -AffineExprRef getTripCountExpr(const ForStmt &forStmt); +AffineExpr getTripCountExpr(const ForStmt &forStmt); /// Returns the trip count of the loop if it's a constant, None otherwise. This /// uses affine expression analysis and is able to determine constant trip count diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h index fdfb6b63defc..5e1359234d46 100644 --- a/mlir/include/mlir/IR/AffineExpr.h +++ b/mlir/include/mlir/IR/AffineExpr.h @@ -34,11 +34,11 @@ class MLIRContext; namespace detail { -class AffineExpr; -class AffineBinaryOpExpr; -class AffineDimExpr; -class AffineSymbolExpr; -class AffineConstantExpr; +class AffineExprClass; +class AffineBinaryOpExprClass; +class AffineDimExprClass; +class AffineSymbolExprClass; +class AffineConstantExprClass; } // namespace detail @@ -65,36 +65,34 @@ enum class AffineExprKind { SymbolId, }; -/// Helper structure to build AffineExpr with intuitive operators in order to -/// operate on chainable, lightweight, immutable value types instead of pointer -/// types. -/// TODO(ntv): Remove all redundant MLIRContext* arguments through the API -/// TODO(ntv): Add extra out-of-class operators for int op AffineExprBaseRef -/// TODO(ntv): Rename +/// Helper structure to build AffineExprClass with intuitive operators in order +/// to operate on chainable, lightweight, immutable value types instead of +/// pointer types. +/// TODO(ntv): Add extra out-of-class operators for int op AffineExprBase /// TODO(ntv): pointer pair -template class AffineExprBaseRef { +template class AffineExprBase { public: - typedef AffineExprBaseRef TemplateType; + typedef AffineExprBase TemplateType; typedef AffineExprType ImplType; - AffineExprBaseRef() : expr(nullptr) {} - /* implicit */ AffineExprBaseRef(const AffineExprType *expr) + AffineExprBase() : expr(nullptr) {} + /* implicit */ AffineExprBase(const AffineExprType *expr) : expr(const_cast(expr)) {} - AffineExprBaseRef(const AffineExprBaseRef &other) : expr(other.expr) {} - AffineExprBaseRef &operator=(AffineExprBaseRef other) { + AffineExprBase(const AffineExprBase &other) : expr(other.expr) {} + AffineExprBase &operator=(AffineExprBase other) { expr = other.expr; return *this; } - bool operator==(AffineExprBaseRef other) const { return expr == other.expr; } + bool operator==(AffineExprBase other) const { return expr == other.expr; } explicit operator AffineExprType *() const { return const_cast(expr); } - /* implicit */ operator AffineExprBaseRef() const { - return const_cast( - static_cast(expr)); + /* implicit */ operator AffineExprBase() const { + return const_cast( + static_cast(expr)); } explicit operator bool() const { return expr; } @@ -114,47 +112,51 @@ public: return U(llvm::cast(const_cast(this->expr))); } - AffineExprBaseRef operator+(int64_t v) const; - AffineExprBaseRef operator+(AffineExprBaseRef other) const; - AffineExprBaseRef operator-() const; - AffineExprBaseRef operator-(int64_t v) const; - AffineExprBaseRef operator-(AffineExprBaseRef other) const; - AffineExprBaseRef operator*(int64_t v) const; - AffineExprBaseRef operator*(AffineExprBaseRef other) const; - AffineExprBaseRef floorDiv(uint64_t v) const; - AffineExprBaseRef floorDiv(AffineExprBaseRef other) const; - AffineExprBaseRef ceilDiv(uint64_t v) const; - AffineExprBaseRef ceilDiv(AffineExprBaseRef other) const; - AffineExprBaseRef operator%(uint64_t v) const; - AffineExprBaseRef operator%(AffineExprBaseRef other) const; + AffineExprBase operator+(int64_t v) const; + AffineExprBase operator+(AffineExprBase other) const; + AffineExprBase operator-() const; + AffineExprBase operator-(int64_t v) const; + AffineExprBase operator-(AffineExprBase other) const; + AffineExprBase operator*(int64_t v) const; + AffineExprBase operator*(AffineExprBase other) const; + AffineExprBase floorDiv(uint64_t v) const; + AffineExprBase floorDiv(AffineExprBase other) const; + AffineExprBase ceilDiv(uint64_t v) const; + AffineExprBase ceilDiv(AffineExprBase other) const; + AffineExprBase operator%(uint64_t v) const; + AffineExprBase operator%(AffineExprBase other) const; - friend ::llvm::hash_code hash_value(AffineExprBaseRef arg); + friend ::llvm::hash_code hash_value(AffineExprBase arg); private: AffineExprType *expr; }; -using AffineExprRef = AffineExprBaseRef; -using AffineBinaryOpExprRef = AffineExprBaseRef; -using AffineDimExprRef = AffineExprBaseRef; -using AffineSymbolExprRef = AffineExprBaseRef; -using AffineConstantExprRef = AffineExprBaseRef; +using AffineExpr = AffineExprBase; +using AffineBinaryOpExpr = AffineExprBase; +using AffineDimExpr = AffineExprBase; +using AffineSymbolExpr = AffineExprBase; +using AffineConstantExpr = AffineExprBase; -// Make AffineExprRef hashable. -inline ::llvm::hash_code hash_value(AffineExprRef arg) { - return ::llvm::hash_value(static_cast(arg.expr)); +AffineExpr operator+(int64_t val, AffineExpr expr); +AffineExpr operator-(int64_t val, AffineExpr expr); +AffineExpr operator*(int64_t val, AffineExpr expr); + +// Make AffineExpr hashable. +inline ::llvm::hash_code hash_value(AffineExpr arg) { + return ::llvm::hash_value(static_cast(arg.expr)); } // These free functions allow clients of the API to not use classes in detail. -AffineExprRef getAffineDimExpr(unsigned position, MLIRContext *context); -AffineExprRef getAffineSymbolExpr(unsigned position, MLIRContext *context); -AffineExprRef getAffineConstantExpr(int64_t constant, MLIRContext *context); +AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context); +AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context); +AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context); namespace detail { /// A one-dimensional affine expression. /// AffineExpression's are immutable (like Type's) -class AffineExpr { +class AffineExprClass { public: /// Return the classification for this type. AffineExprKind getKind() { return kind; } @@ -179,20 +181,20 @@ public: MLIRContext *getContext(); protected: - explicit AffineExpr(AffineExprKind kind, MLIRContext *context) + explicit AffineExprClass(AffineExprKind kind, MLIRContext *context) : kind(kind), context(context) {} - ~AffineExpr() {} + ~AffineExprClass() {} private: - AffineExpr(const AffineExpr &) = delete; - void operator=(const AffineExpr &) = delete; + AffineExprClass(const AffineExprClass &) = delete; + void operator=(const AffineExprClass &) = delete; /// Classification of the subclass const AffineExprKind kind; MLIRContext *context; }; -inline raw_ostream &operator<<(raw_ostream &os, AffineExprRef &expr) { +inline raw_ostream &operator<<(raw_ostream &os, AffineExpr &expr) { expr->print(os); return os; } @@ -203,62 +205,50 @@ inline raw_ostream &operator<<(raw_ostream &os, AffineExprRef &expr) { /// constructed in a simplified form. For eg., the LHS and RHS operands can't /// both be constants. There are additional canonicalizing rules depending on /// the op type: see checks in the constructor. -class AffineBinaryOpExpr : public AffineExpr { +class AffineBinaryOpExprClass : public AffineExprClass { public: - static AffineExprRef get(AffineExprKind kind, AffineExprRef lhs, - AffineExprRef rhs, MLIRContext *context); - static AffineExprRef getAdd(AffineExprRef lhs, AffineExprRef rhs, - MLIRContext *context) { - return get(AffineExprKind::Add, lhs, rhs, context); + static AffineExpr get(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs); + static AffineExpr getAdd(AffineExpr lhs, AffineExpr rhs) { + return get(AffineExprKind::Add, lhs, rhs); } - static AffineExprRef getAdd(AffineExprRef expr, int64_t rhs, - MLIRContext *context); - static AffineExprRef getSub(AffineExprRef lhs, AffineExprRef rhs, - MLIRContext *context); + static AffineExpr getAdd(AffineExpr expr, int64_t rhs); + static AffineExpr getSub(AffineExpr lhs, AffineExpr rhs); - static AffineExprRef getMul(AffineExprRef lhs, AffineExprRef rhs, - MLIRContext *context) { - return get(AffineExprKind::Mul, lhs, rhs, context); + static AffineExpr getMul(AffineExpr lhs, AffineExpr rhs) { + return get(AffineExprKind::Mul, lhs, rhs); } - static AffineExprRef getMul(AffineExprRef expr, int64_t rhs, - MLIRContext *context); - static AffineExprRef getFloorDiv(AffineExprRef lhs, AffineExprRef rhs, - MLIRContext *context) { - return get(AffineExprKind::FloorDiv, lhs, rhs, context); + static AffineExpr getMul(AffineExpr expr, int64_t rhs); + static AffineExpr getFloorDiv(AffineExpr lhs, AffineExpr rhs) { + return get(AffineExprKind::FloorDiv, lhs, rhs); } - static AffineExprRef getFloorDiv(AffineExprRef lhs, uint64_t rhs, - MLIRContext *context); - static AffineExprRef getCeilDiv(AffineExprRef lhs, AffineExprRef rhs, - MLIRContext *context) { - return get(AffineExprKind::CeilDiv, lhs, rhs, context); + static AffineExpr getFloorDiv(AffineExpr lhs, uint64_t rhs); + static AffineExpr getCeilDiv(AffineExpr lhs, AffineExpr rhs) { + return get(AffineExprKind::CeilDiv, lhs, rhs); } - static AffineExprRef getCeilDiv(AffineExprRef lhs, uint64_t rhs, - MLIRContext *context); - static AffineExprRef getMod(AffineExprRef lhs, AffineExprRef rhs, - MLIRContext *context) { - return get(AffineExprKind::Mod, lhs, rhs, context); + static AffineExpr getCeilDiv(AffineExpr lhs, uint64_t rhs); + static AffineExpr getMod(AffineExpr lhs, AffineExpr rhs) { + return get(AffineExprKind::Mod, lhs, rhs); } - static AffineExprRef getMod(AffineExprRef lhs, uint64_t rhs, - MLIRContext *context); + static AffineExpr getMod(AffineExpr lhs, uint64_t rhs); - AffineExprRef getLHS() { return lhs; } - AffineExprRef getRHS() { return rhs; } + AffineExpr getLHS() { return lhs; } + AffineExpr getRHS() { return rhs; } /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(const AffineExpr *expr) { - return const_cast(expr)->getKind() <= + static bool classof(const AffineExprClass *expr) { + return const_cast(expr)->getKind() <= AffineExprKind::LAST_AFFINE_BINARY_OP; } protected: - explicit AffineBinaryOpExpr(AffineExprKind kind, AffineExprRef lhs, - AffineExprRef rhs, MLIRContext *context); + explicit AffineBinaryOpExprClass(AffineExprKind kind, AffineExpr lhs, + AffineExpr rhs); - const AffineExprRef lhs; - const AffineExprRef rhs; + const AffineExpr lhs; + const AffineExpr rhs; private: - ~AffineBinaryOpExpr() = delete; + ~AffineBinaryOpExprClass() = delete; }; /// A dimensional identifier appearing in an affine expression. @@ -266,25 +256,26 @@ private: /// This is a POD type of int size; so it should be passed around by /// value. The underlying data is owned by MLIRContext and is thus immortal for /// almost all clients. -class AffineDimExpr : public AffineExpr { +class AffineDimExprClass : public AffineExprClass { public: - static AffineExprBaseRef get(unsigned position, - MLIRContext *context); + static AffineExprBase get(unsigned position, + MLIRContext *context); unsigned getPosition() { return position; } /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(const AffineExpr *expr) { - return const_cast(expr)->getKind() == AffineExprKind::DimId; + static bool classof(const AffineExprClass *expr) { + return const_cast(expr)->getKind() == + AffineExprKind::DimId; } - friend AffineExprRef mlir::getAffineDimExpr(unsigned position, - MLIRContext *context); + friend AffineExpr mlir::getAffineDimExpr(unsigned position, + MLIRContext *context); private: - ~AffineDimExpr() = delete; - explicit AffineDimExpr(unsigned position, MLIRContext *context) - : AffineExpr(AffineExprKind::DimId, context), position(position) {} + ~AffineDimExprClass() = delete; + explicit AffineDimExprClass(unsigned position, MLIRContext *context) + : AffineExprClass(AffineExprKind::DimId, context), position(position) {} /// Position of this identifier in the argument list. unsigned position; @@ -295,52 +286,54 @@ private: /// This is a POD type of int size, so it should be passed around by /// value. The underlying data is owned by MLIRContext and is thus immortal for /// almost all clients. -class AffineSymbolExpr : public AffineExpr { +class AffineSymbolExprClass : public AffineExprClass { public: - static AffineExprBaseRef get(unsigned position, - MLIRContext *context); + static AffineExprBase get(unsigned position, + MLIRContext *context); unsigned getPosition() { return position; } /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(const AffineExpr *expr) { - return const_cast(expr)->getKind() == + static bool classof(const AffineExprClass *expr) { + return const_cast(expr)->getKind() == AffineExprKind::SymbolId; } - friend AffineExprRef mlir::getAffineSymbolExpr(unsigned position, - MLIRContext *context); + friend AffineExpr mlir::getAffineSymbolExpr(unsigned position, + MLIRContext *context); private: - ~AffineSymbolExpr() = delete; - explicit AffineSymbolExpr(unsigned position, MLIRContext *context) - : AffineExpr(AffineExprKind::SymbolId, context), position(position) {} + ~AffineSymbolExprClass() = delete; + explicit AffineSymbolExprClass(unsigned position, MLIRContext *context) + : AffineExprClass(AffineExprKind::SymbolId, context), position(position) { + } /// Position of this identifier in the symbol list. unsigned position; }; /// An integer constant appearing in affine expression. -class AffineConstantExpr : public AffineExpr { +class AffineConstantExprClass : public AffineExprClass { public: - static AffineExprBaseRef get(int64_t constant, - MLIRContext *context); + static AffineExprBase get(int64_t constant, + MLIRContext *context); int64_t getValue() { return constant; } /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(const AffineExpr *expr) { - return const_cast(expr)->getKind() == + static bool classof(const AffineExprClass *expr) { + return const_cast(expr)->getKind() == AffineExprKind::Constant; } - friend AffineExprRef mlir::getAffineConstantExpr(int64_t constant, - MLIRContext *context); + friend AffineExpr mlir::getAffineConstantExpr(int64_t constant, + MLIRContext *context); private: - ~AffineConstantExpr() = delete; - explicit AffineConstantExpr(int64_t constant, MLIRContext *context) - : AffineExpr(AffineExprKind::Constant, context), constant(constant) {} + ~AffineConstantExprClass() = delete; + explicit AffineConstantExprClass(int64_t constant, MLIRContext *context) + : AffineExprClass(AffineExprKind::Constant, context), constant(constant) { + } // The constant. int64_t constant; @@ -351,22 +344,20 @@ private: namespace llvm { -// AffineExprRef hash just like pointers -template <> struct DenseMapInfo { - static mlir::AffineExprRef getEmptyKey() { +// AffineExpr hash just like pointers +template <> struct DenseMapInfo { + static mlir::AffineExpr getEmptyKey() { auto pointer = llvm::DenseMapInfo::getEmptyKey(); - return mlir::AffineExprRef( - static_cast(pointer)); + return mlir::AffineExpr(static_cast(pointer)); } - static mlir::AffineExprRef getTombstoneKey() { + static mlir::AffineExpr getTombstoneKey() { auto pointer = llvm::DenseMapInfo::getTombstoneKey(); - return mlir::AffineExprRef( - static_cast(pointer)); + return mlir::AffineExpr(static_cast(pointer)); } - static unsigned getHashValue(mlir::AffineExprRef val) { + static unsigned getHashValue(mlir::AffineExpr val) { return mlir::hash_value(val); } - static bool isEqual(mlir::AffineExprRef LHS, mlir::AffineExprRef RHS) { + static bool isEqual(mlir::AffineExpr LHS, mlir::AffineExpr RHS) { return LHS == RHS; } }; diff --git a/mlir/include/mlir/IR/AffineExprVisitor.h b/mlir/include/mlir/IR/AffineExprVisitor.h index 9234df0d2eb4..d121093c2e23 100644 --- a/mlir/include/mlir/IR/AffineExprVisitor.h +++ b/mlir/include/mlir/IR/AffineExprVisitor.h @@ -1,4 +1,4 @@ -//===- AffineExprVisitor.h - MLIR AffineExpr Visitor Class ------*- C++ -*-===// +//===- AffineExprVisitor.h - MLIR AffineExprClass Visitor Class -*- C++ -*-===// // // Copyright 2019 The MLIR Authors. // @@ -15,7 +15,7 @@ // limitations under the License. // ============================================================================= // -// This file defines the AffineExpr visitor class. +// This file defines the AffineExprClass visitor class. // //===----------------------------------------------------------------------===// @@ -26,9 +26,9 @@ namespace mlir { -/// Base class for AffineExpr visitors/walkers. +/// Base class for AffineExprClass visitors/walkers. /// -/// AffineExpr visitors are used when you want to perform different actions +/// AffineExprClass visitors are used when you want to perform different actions /// for different kinds of AffineExprs without having to use lots of casts /// and a big switch statement. /// @@ -46,7 +46,7 @@ namespace mlir { /// struct DimExprCounter : public AffineExprVisitor { /// unsigned numDimExprs; /// DimExprCounter() : numDimExprs(0) {} -/// void visitAffineDimExpr(AffineDimExprRef expr) { ++numDimExprs; } +/// void visitAffineDimExpr(AffineDimExpr expr) { ++numDimExprs; } /// }; /// /// And this class would be used like this: @@ -56,13 +56,14 @@ namespace mlir { /// /// AffineExprVisitor provides visit methods for the following binary affine /// op expressions: -/// AffineBinaryAddOpExpr, AffineBinaryMulOpExpr, AffineBinaryModOpExpr, -/// AffineBinaryFloorDivOpExpr, AffineBinaryCeilDivOpExpr. -/// Note that default implementations of these methods will call the general -/// AffineBinaryOpExpr method. +/// AffineBinaryAddOpExprClass, AffineBinaryMulOpExprClass, +/// AffineBinaryModOpExprClass, AffineBinaryFloorDivOpExprClass, +/// AffineBinaryCeilDivOpExpr. Note that default implementations of these +/// methods will call the general AffineBinaryOpExprClass method. /// /// In addition, visit methods are provided for the following affine -// expressions: AffineConstantExpr, AffineDimExpr, and AffineSymbolExpr. +// expressions: AffineConstantExprClass, AffineDimExprClass, and +// AffineSymbolExpr. /// /// Note that if you don't implement visitXXX for some affine expression type, /// the visitXXX method for Statement superclass will be invoked. @@ -77,88 +78,88 @@ template class AffineExprVisitor { // Interface code - This is the public interface of the AffineExprVisitor // that you use to visit affine expressions... public: - // Function to walk an AffineExpr (in post order). - RetTy walkPostOrder(AffineExprRef expr) { + // Function to walk an AffineExprClass (in post order). + RetTy walkPostOrder(AffineExpr expr) { static_assert(std::is_base_of::value, "Must instantiate with a derived type of AffineExprVisitor"); switch (expr->getKind()) { case AffineExprKind::Add: { - auto binOpExpr = expr.cast(); + auto binOpExpr = expr.cast(); walkOperandsPostOrder(binOpExpr); return static_cast(this)->visitAddExpr(binOpExpr); } case AffineExprKind::Mul: { - auto binOpExpr = expr.cast(); + auto binOpExpr = expr.cast(); walkOperandsPostOrder(binOpExpr); return static_cast(this)->visitMulExpr(binOpExpr); } case AffineExprKind::Mod: { - auto binOpExpr = expr.cast(); + auto binOpExpr = expr.cast(); walkOperandsPostOrder(binOpExpr); return static_cast(this)->visitModExpr(binOpExpr); } case AffineExprKind::FloorDiv: { - auto binOpExpr = expr.cast(); + auto binOpExpr = expr.cast(); walkOperandsPostOrder(binOpExpr); return static_cast(this)->visitFloorDivExpr(binOpExpr); } case AffineExprKind::CeilDiv: { - auto binOpExpr = expr.cast(); + auto binOpExpr = expr.cast(); walkOperandsPostOrder(binOpExpr); return static_cast(this)->visitCeilDivExpr(binOpExpr); } case AffineExprKind::Constant: return static_cast(this)->visitConstantExpr( - expr.cast()); + expr.cast()); case AffineExprKind::DimId: return static_cast(this)->visitDimExpr( - expr.cast()); + expr.cast()); case AffineExprKind::SymbolId: return static_cast(this)->visitSymbolExpr( - expr.cast()); + expr.cast()); } } // Function to visit an AffineExpr. - RetTy visit(AffineExprRef expr) { + RetTy visit(AffineExpr expr) { static_assert(std::is_base_of::value, "Must instantiate with a derived type of AffineExprVisitor"); switch (expr->getKind()) { case AffineExprKind::Add: { - auto binOpExpr = expr.cast(); + auto binOpExpr = expr.cast(); return static_cast(this)->visitAddExpr(binOpExpr); } case AffineExprKind::Mul: { - auto binOpExpr = expr.cast(); + auto binOpExpr = expr.cast(); return static_cast(this)->visitMulExpr(binOpExpr); } case AffineExprKind::Mod: { - auto binOpExpr = expr.cast(); + auto binOpExpr = expr.cast(); return static_cast(this)->visitModExpr(binOpExpr); } case AffineExprKind::FloorDiv: { - auto binOpExpr = expr.cast(); + auto binOpExpr = expr.cast(); return static_cast(this)->visitFloorDivExpr(binOpExpr); } case AffineExprKind::CeilDiv: { - auto binOpExpr = expr.cast(); + auto binOpExpr = expr.cast(); return static_cast(this)->visitCeilDivExpr(binOpExpr); } case AffineExprKind::Constant: return static_cast(this)->visitConstantExpr( - expr.cast()); + expr.cast()); case AffineExprKind::DimId: return static_cast(this)->visitDimExpr( - expr.cast()); + expr.cast()); case AffineExprKind::SymbolId: return static_cast(this)->visitSymbolExpr( - expr.cast()); + expr.cast()); } } @@ -172,29 +173,29 @@ public: // Default visit methods. Note that the default op-specific binary op visit // methods call the general visitAffineBinaryOpExpr visit method. - void visitAffineBinaryOpExpr(AffineBinaryOpExprRef expr) {} - void visitAddExpr(AffineBinaryOpExprRef expr) { + void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {} + void visitAddExpr(AffineBinaryOpExpr expr) { static_cast(this)->visitAffineBinaryOpExpr(expr); } - void visitMulExpr(AffineBinaryOpExprRef expr) { + void visitMulExpr(AffineBinaryOpExpr expr) { static_cast(this)->visitAffineBinaryOpExpr(expr); } - void visitModExpr(AffineBinaryOpExprRef expr) { + void visitModExpr(AffineBinaryOpExpr expr) { static_cast(this)->visitAffineBinaryOpExpr(expr); } - void visitFloorDivExpr(AffineBinaryOpExprRef expr) { + void visitFloorDivExpr(AffineBinaryOpExpr expr) { static_cast(this)->visitAffineBinaryOpExpr(expr); } - void visitCeilDivExpr(AffineBinaryOpExprRef expr) { + void visitCeilDivExpr(AffineBinaryOpExpr expr) { static_cast(this)->visitAffineBinaryOpExpr(expr); } - void visitConstantExpr(AffineConstantExprRef expr) {} - void visitAffineDimExpr(AffineDimExprRef expr) {} - void visitAffineSymbolExpr(AffineSymbolExprRef expr) {} + void visitConstantExpr(AffineConstantExpr expr) {} + void visitAffineDimExpr(AffineDimExpr expr) {} + void visitAffineSymbolExpr(AffineSymbolExpr expr) {} private: // Walk the operands - each operand is itself walked in post order. - void walkOperandsPostOrder(AffineBinaryOpExprRef expr) { + void walkOperandsPostOrder(AffineBinaryOpExpr expr) { walkPostOrder(expr->getLHS()); walkPostOrder(expr->getRHS()); } diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h index 91dfe022c44e..d91dbe266cc6 100644 --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -32,11 +32,11 @@ namespace mlir { namespace detail { -class AffineExpr; +class AffineExprClass; } // namespace detail -template class AffineExprBaseRef; -using AffineExprRef = AffineExprBaseRef; +template class AffineExprBase; +using AffineExpr = AffineExprBase; class Attribute; class MLIRContext; @@ -48,9 +48,8 @@ class MLIRContext; class AffineMap { public: static AffineMap *get(unsigned dimCount, unsigned symbolCount, - ArrayRef results, - ArrayRef rangeSizes, - MLIRContext *context); + ArrayRef results, + ArrayRef rangeSizes); /// Returns a single constant result affine map. static AffineMap *getConstantMap(int64_t val, MLIRContext *context); @@ -81,11 +80,11 @@ public: unsigned getNumResults() { return numResults; } unsigned getNumInputs() { return numDims + numSymbols; } - ArrayRef getResults() { return results; } + ArrayRef getResults() { return results; } - AffineExprRef getResult(unsigned idx); + AffineExpr getResult(unsigned idx); - ArrayRef getRangeSizes() { return rangeSizes; } + ArrayRef getRangeSizes() { return rangeSizes; } /// Folds the results of the application of an affine map on the provided /// operands to a constant if possible. Returns false if the folding happens, @@ -95,8 +94,7 @@ public: private: AffineMap(unsigned numDims, unsigned numSymbols, unsigned numResults, - ArrayRef results, - ArrayRef rangeSizes); + ArrayRef results, ArrayRef rangeSizes); AffineMap(const AffineMap &) = delete; void operator=(const AffineMap &) = delete; @@ -107,11 +105,11 @@ private: /// The affine expressions for this (multi-dimensional) map. /// TODO: use trailing objects for this. - ArrayRef results; + ArrayRef results; /// The extents along each of the range dimensions if the map is bounded, /// nullptr otherwise. - ArrayRef rangeSizes; + ArrayRef rangeSizes; }; } // end namespace mlir diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 1051b8022435..5c06182c6bdd 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -26,12 +26,12 @@ namespace mlir { namespace detail { -class AffineExpr; +class AffineExprClass; } // namespace detail -template class AffineExprBaseRef; -using AffineExprRef = AffineExprBaseRef; +template class AffineExprBase; +using AffineExpr = AffineExprBase; class MLIRContext; class Module; class UnknownLoc; @@ -108,25 +108,25 @@ public: FunctionAttr *getFunctionAttr(const Function *value); // Affine expressions and affine maps. - AffineExprRef getAffineDimExpr(unsigned position); - AffineExprRef getAffineSymbolExpr(unsigned position); - AffineExprRef getAffineConstantExpr(int64_t constant); - AffineExprRef getAddExpr(AffineExprRef lhs, AffineExprRef rhs); - AffineExprRef getAddExpr(AffineExprRef lhs, int64_t rhs); - AffineExprRef getSubExpr(AffineExprRef lhs, AffineExprRef rhs); - AffineExprRef getSubExpr(AffineExprRef lhs, int64_t rhs); - AffineExprRef getMulExpr(AffineExprRef lhs, AffineExprRef rhs); - AffineExprRef getMulExpr(AffineExprRef lhs, int64_t rhs); - AffineExprRef getModExpr(AffineExprRef lhs, AffineExprRef rhs); - AffineExprRef getModExpr(AffineExprRef lhs, uint64_t rhs); - AffineExprRef getFloorDivExpr(AffineExprRef lhs, AffineExprRef rhs); - AffineExprRef getFloorDivExpr(AffineExprRef lhs, uint64_t rhs); - AffineExprRef getCeilDivExpr(AffineExprRef lhs, AffineExprRef rhs); - AffineExprRef getCeilDivExpr(AffineExprRef lhs, uint64_t rhs); + AffineExpr getAffineDimExpr(unsigned position); + AffineExpr getAffineSymbolExpr(unsigned position); + AffineExpr getAffineConstantExpr(int64_t constant); + AffineExpr getAddExpr(AffineExpr lhs, AffineExpr rhs); + AffineExpr getAddExpr(AffineExpr lhs, int64_t rhs); + AffineExpr getSubExpr(AffineExpr lhs, AffineExpr rhs); + AffineExpr getSubExpr(AffineExpr lhs, int64_t rhs); + AffineExpr getMulExpr(AffineExpr lhs, AffineExpr rhs); + AffineExpr getMulExpr(AffineExpr lhs, int64_t rhs); + AffineExpr getModExpr(AffineExpr lhs, AffineExpr rhs); + AffineExpr getModExpr(AffineExpr lhs, uint64_t rhs); + AffineExpr getFloorDivExpr(AffineExpr lhs, AffineExpr rhs); + AffineExpr getFloorDivExpr(AffineExpr lhs, uint64_t rhs); + AffineExpr getCeilDivExpr(AffineExpr lhs, AffineExpr rhs); + AffineExpr getCeilDivExpr(AffineExpr lhs, uint64_t rhs); AffineMap *getAffineMap(unsigned dimCount, unsigned symbolCount, - ArrayRef results, - ArrayRef rangeSizes); + ArrayRef results, + ArrayRef rangeSizes); // Special cases of affine maps and integer sets /// Returns a single constant result affine map with 0 dimensions and 0 @@ -151,7 +151,7 @@ public: // Integer set. IntegerSet *getIntegerSet(unsigned dimCount, unsigned symbolCount, - ArrayRef constraints, + ArrayRef constraints, ArrayRef isEq); // TODO: Helpers for affine map/exprs, etc. protected: diff --git a/mlir/include/mlir/IR/IntegerSet.h b/mlir/include/mlir/IR/IntegerSet.h index a7922583ba40..cb7eec841ced 100644 --- a/mlir/include/mlir/IR/IntegerSet.h +++ b/mlir/include/mlir/IR/IntegerSet.h @@ -47,7 +47,7 @@ class MLIRContext; class IntegerSet { public: static IntegerSet *get(unsigned dimCount, unsigned symbolCount, - ArrayRef constraints, + ArrayRef constraints, ArrayRef eqFlags, MLIRContext *context); unsigned getNumDims() { return dimCount; } @@ -55,9 +55,9 @@ public: unsigned getNumOperands() { return dimCount + symbolCount; } unsigned getNumConstraints() { return numConstraints; } - ArrayRef getConstraints() { return constraints; } + ArrayRef getConstraints() { return constraints; } - AffineExprRef getConstraint(unsigned idx) { return getConstraints()[idx]; } + AffineExpr getConstraint(unsigned idx) { return getConstraints()[idx]; } /// Returns the equality bits, which specify whether each of the constraints /// is an equality or inequality. @@ -72,7 +72,7 @@ public: private: IntegerSet(unsigned dimCount, unsigned symbolCount, unsigned numConstraints, - ArrayRef constraints, ArrayRef eqFlags); + ArrayRef constraints, ArrayRef eqFlags); ~IntegerSet() = delete; @@ -82,7 +82,7 @@ private: /// Array of affine constraints: a constaint is either an equality /// (affine_expr == 0) or an inequality (affine_expr >= 0). - ArrayRef constraints; + ArrayRef constraints; // Bits to check whether a constraint is an equality or an inequality. ArrayRef eqFlags; diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 84d6eed7e8f8..7eb3be3bfe3a 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -31,11 +31,11 @@ namespace mlir { namespace detail { -class AffineExpr; +class AffineExprClass; } // namespace detail -template class AffineExprBaseRef; -using AffineExprRef = AffineExprBaseRef; +template class AffineExprBase; +using AffineExpr = AffineExprBase; class AffineMap; class Builder; class Function; @@ -76,7 +76,7 @@ public: virtual void printFunctionReference(const Function *func) = 0; virtual void printAttribute(const Attribute *attr) = 0; virtual void printAffineMap(AffineMap *map) = 0; - virtual void printAffineExpr(AffineExprRef expr) = 0; + virtual void printAffineExpr(AffineExpr expr) = 0; /// If the specified operation has attributes, print out an attribute /// dictionary with their values. elidedAttrs allows the client to ignore diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 2d314e0771a8..fa2541a4fd71 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -33,14 +33,14 @@ using namespace llvm; /// Constructs an affine expression from a flat ArrayRef. If there are local /// identifiers (neither dimensional nor symbolic) that appear in the sum of -/// products expression, 'localExprs' is expected to have the AffineExpr for it, -/// and is substituted into. The ArrayRef 'eq' is expected to be in the format -/// [dims, symbols, locals, constant term]. +/// products expression, 'localExprs' is expected to have the AffineExprClass +/// for it, and is substituted into. The ArrayRef 'eq' is expected to be in the +/// format [dims, symbols, locals, constant term]. // TODO(bondhugula): refactor getAddMulPureAffineExpr to reuse it from here. -static AffineExprRef toAffineExpr(ArrayRef eq, unsigned numDims, - unsigned numSymbols, - ArrayRef localExprs, - MLIRContext *context) { +static AffineExpr toAffineExpr(ArrayRef eq, unsigned numDims, + unsigned numSymbols, + ArrayRef localExprs, + MLIRContext *context) { // Assert expected numLocals = eq.size() - numDims - numSymbols - 1 assert(eq.size() - numDims - numSymbols - 1 == localExprs.size() && "unexpected number of local expressions"); @@ -74,7 +74,7 @@ static AffineExprRef toAffineExpr(ArrayRef eq, unsigned numDims, namespace { -// This class is used to flatten a pure affine expression (AffineExprRef, +// This class is used to flatten a pure affine expression (AffineExpr, // which is in a tree form) into a sum of products (w.r.t constants) when // possible, and in that process simplifying the expression. The simplification // performed includes the accumulation of contributions for each dimensional and @@ -127,14 +127,14 @@ public: // Number of newly introduced identifiers to flatten mod/floordiv/ceildiv // expressions that could not be simplified. unsigned numLocals; - // AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for + // AffineExprClass's corresponding to the floordiv/ceildiv/mod expressions for // which new identifiers were introduced; if the latter do not get canceled - // out, these expressions are needed to reconstruct the AffineExprRef / tree + // out, these expressions are needed to reconstruct the AffineExpr / tree // form. Note that these expressions themselves would have been simplified // (recursively) by this pass. Eg. d0 + (d0 + 2*d1 + d0) ceildiv 4 will be // simplified to d0 + q, where q = (d0 + d1) ceildiv 2. (d0 + d1) ceildiv 2 // would be the local expression stored for q. - SmallVector localExprs; + SmallVector localExprs; MLIRContext *context; AffineExprFlattener(unsigned numDims, unsigned numSymbols, @@ -144,10 +144,10 @@ public: operandExprStack.reserve(8); } - void visitMulExpr(AffineBinaryOpExprRef expr) { + void visitMulExpr(AffineBinaryOpExpr expr) { assert(operandExprStack.size() >= 2); // This is a pure affine expr; the RHS will be a constant. - assert(expr->getRHS().isa()); + assert(expr->getRHS().isa()); // Get the RHS constant. auto rhsConst = operandExprStack.back()[getConstantIndex()]; operandExprStack.pop_back(); @@ -158,7 +158,7 @@ public: } } - void visitAddExpr(AffineBinaryOpExprRef expr) { + void visitAddExpr(AffineBinaryOpExpr expr) { assert(operandExprStack.size() >= 2); const auto &rhs = operandExprStack.back(); auto &lhs = operandExprStack[operandExprStack.size() - 2]; @@ -171,10 +171,10 @@ public: operandExprStack.pop_back(); } - void visitModExpr(AffineBinaryOpExprRef expr) { + void visitModExpr(AffineBinaryOpExpr expr) { assert(operandExprStack.size() >= 2); // This is a pure affine expr; the RHS will be a constant. - assert(expr->getRHS().isa()); + assert(expr->getRHS().isa()); auto rhsConst = operandExprStack.back()[getConstantIndex()]; operandExprStack.pop_back(); auto &lhs = operandExprStack.back(); @@ -200,32 +200,32 @@ public: addLocalId(a.floorDiv(b)); lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst; } - void visitCeilDivExpr(AffineBinaryOpExprRef expr) { + void visitCeilDivExpr(AffineBinaryOpExpr expr) { visitDivExpr(expr, /*isCeil=*/true); } - void visitFloorDivExpr(AffineBinaryOpExprRef expr) { + void visitFloorDivExpr(AffineBinaryOpExpr expr) { visitDivExpr(expr, /*isCeil=*/false); } - void visitDimExpr(AffineDimExprRef expr) { + void visitDimExpr(AffineDimExpr expr) { operandExprStack.emplace_back(SmallVector(getNumCols(), 0)); auto &eq = operandExprStack.back(); eq[getDimStartIndex() + expr->getPosition()] = 1; } - void visitSymbolExpr(AffineSymbolExprRef expr) { + void visitSymbolExpr(AffineSymbolExpr expr) { operandExprStack.emplace_back(SmallVector(getNumCols(), 0)); auto &eq = operandExprStack.back(); eq[getSymbolStartIndex() + expr->getPosition()] = 1; } - void visitConstantExpr(AffineConstantExprRef expr) { + void visitConstantExpr(AffineConstantExpr expr) { operandExprStack.emplace_back(SmallVector(getNumCols(), 0)); auto &eq = operandExprStack.back(); eq[getConstantIndex()] = expr->getValue(); } private: - void visitDivExpr(AffineBinaryOpExprRef expr, bool isCeil) { + void visitDivExpr(AffineBinaryOpExpr expr, bool isCeil) { assert(operandExprStack.size() >= 2); - assert(expr->getRHS().isa()); + assert(expr->getRHS().isa()); // This is a pure affine expr; the RHS is a positive constant. auto rhsConst = operandExprStack.back()[getConstantIndex()]; // TODO(bondhugula): handle division by zero at the same time the issue is @@ -266,9 +266,9 @@ private: } // Add an existential quantifier (used to flatten a mod, floordiv, ceildiv - // expr). localExpr is the simplified tree expression (AffineExprRef) + // expr). localExpr is the simplified tree expression (AffineExpr) // corresponding to the quantifier. - void addLocalId(AffineExprRef localExpr) { + void addLocalId(AffineExpr localExpr) { for (auto &subExpr : operandExprStack) { subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0); } @@ -284,8 +284,8 @@ private: } // end anonymous namespace -AffineExprRef mlir::simplifyAffineExpr(AffineExprRef expr, unsigned numDims, - unsigned numSymbols) { +AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims, + unsigned numSymbols) { // TODO(bondhugula): only pure affine for now. The simplification here can be // extended to semi-affine maps in the future. if (!expr->isPureAffine()) diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index ac0ba3b0bbf8..463a64b3b2c6 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -47,7 +47,7 @@ namespace { struct AffineMapCompositionUpdate { using PositionMap = DenseMap; - explicit AffineMapCompositionUpdate(ArrayRef inputResults) + explicit AffineMapCompositionUpdate(ArrayRef inputResults) : inputResults(inputResults), outputNumDims(0), outputNumSymbols(0) {} // Map from 'curr' affine map dim position to 'output' affine map @@ -65,7 +65,7 @@ struct AffineMapCompositionUpdate { // symbol position. PositionMap inputSymbolMap; // Results of 'input' affine map. - ArrayRef inputResults; + ArrayRef inputResults; // Number of dimension operands for 'output' affine map. unsigned outputNumDims; // Number of symbol operands for 'output' affine map. @@ -80,29 +80,29 @@ public: AffineExprComposer(const AffineMapCompositionUpdate &mapUpdate) : mapUpdate(mapUpdate), walkingInputMap(false) {} - AffineExprRef walk(AffineExprRef expr) { + AffineExpr walk(AffineExpr expr) { switch (expr->getKind()) { case AffineExprKind::Add: return walkBinExpr( - expr, [](AffineExprRef lhs, AffineExprRef rhs) { return lhs + rhs; }); + expr, [](AffineExpr lhs, AffineExpr rhs) { return lhs + rhs; }); case AffineExprKind::Mul: return walkBinExpr( - expr, [](AffineExprRef lhs, AffineExprRef rhs) { return lhs * rhs; }); + expr, [](AffineExpr lhs, AffineExpr rhs) { return lhs * rhs; }); case AffineExprKind::Mod: return walkBinExpr( - expr, [](AffineExprRef lhs, AffineExprRef rhs) { return lhs % rhs; }); + expr, [](AffineExpr lhs, AffineExpr rhs) { return lhs % rhs; }); case AffineExprKind::FloorDiv: - return walkBinExpr(expr, [](AffineExprRef lhs, AffineExprRef rhs) { + return walkBinExpr(expr, [](AffineExpr lhs, AffineExpr rhs) { return lhs.floorDiv(rhs); }); case AffineExprKind::CeilDiv: - return walkBinExpr(expr, [](AffineExprRef lhs, AffineExprRef rhs) { + return walkBinExpr(expr, [](AffineExpr lhs, AffineExpr rhs) { return lhs.ceilDiv(rhs); }); case AffineExprKind::Constant: return expr; case AffineExprKind::DimId: { - unsigned dimPosition = expr.cast()->getPosition(); + unsigned dimPosition = expr.cast()->getPosition(); if (walkingInputMap) { return getAffineDimExpr(mapUpdate.inputDimMap.lookup(dimPosition), expr->getContext()); @@ -123,7 +123,7 @@ public: return composer.walk(mapUpdate.inputResults[inputResultIndex]); } case AffineExprKind::SymbolId: - unsigned symbolPosition = expr.cast()->getPosition(); + unsigned symbolPosition = expr.cast()->getPosition(); if (walkingInputMap) { return getAffineSymbolExpr( mapUpdate.inputSymbolMap.lookup(symbolPosition), @@ -139,10 +139,9 @@ private: bool walkingInputMap) : mapUpdate(mapUpdate), walkingInputMap(walkingInputMap) {} - AffineExprRef - walkBinExpr(AffineExprRef expr, - std::function op) { - auto binOpExpr = expr.cast(); + AffineExpr walkBinExpr(AffineExpr expr, + std::function op) { + auto binOpExpr = expr.cast(); return op(walk(binOpExpr->getLHS()), walk(binOpExpr->getRHS())); } @@ -197,7 +196,7 @@ void MutableAffineMap::simplify() { } AffineMap *MutableAffineMap::getAffineMap() { - return AffineMap::get(numDims, numSymbols, results, rangeSizes, context); + return AffineMap::get(numDims, numSymbols, results, rangeSizes); } MutableIntegerSet::MutableIntegerSet(IntegerSet *set, MLIRContext *context) @@ -295,10 +294,10 @@ void AffineValueMap::fwdSubstitute(const AffineApplyOp &inputOp) { DenseSet *positions; AffineExprPositionGatherer(unsigned numDims, DenseSet *positions) : numDims(numDims), positions(positions) {} - void visitDimExpr(AffineDimExprRef expr) { + void visitDimExpr(AffineDimExpr expr) { positions->insert(expr->getPosition()); } - void visitSymbolExpr(AffineSymbolExprRef expr) { + void visitSymbolExpr(AffineSymbolExpr expr) { positions->insert(numDims + expr->getPosition()); } }; diff --git a/mlir/lib/Analysis/HyperRectangularSet.cpp b/mlir/lib/Analysis/HyperRectangularSet.cpp index 4d7280895ef0..7fc5b2932779 100644 --- a/mlir/lib/Analysis/HyperRectangularSet.cpp +++ b/mlir/lib/Analysis/HyperRectangularSet.cpp @@ -38,7 +38,7 @@ getReducedConstBound(const HyperRectangularSet &set, unsigned *idx, unsigned j = 0; AffineBoundExprList::const_iterator it, e; for (it = ubs.begin(), e = ubs.end(); it != e; it++, j++) { - if (auto cExpr = it->dyn_cast()) { + if (auto cExpr = it->dyn_cast()) { if (val == None) { val = cExpr->getValue(); *idx = j; @@ -52,8 +52,9 @@ getReducedConstBound(const HyperRectangularSet &set, unsigned *idx, return val; } -// Merge the two lists of AffineExpr's into a single one, avoiding duplicates. -// lb specifies whether the bound lists are for a lower bound or an upper bound. +// Merge the two lists of AffineExprClass's into a single one, avoiding +// duplicates. lb specifies whether the bound lists are for a lower bound or an +// upper bound. // TODO(bondhugula): clean this code up. static void mergeBounds(const HyperRectangularSet &set, AffineBoundExprList &lhsList, @@ -68,7 +69,7 @@ static void mergeBounds(const HyperRectangularSet &set, } if (it == lhsList.end()) { // There can only be one constant affine expr in this bound list. - if (auto cExpr = expr.dyn_cast()) { + if (auto cExpr = expr.dyn_cast()) { unsigned idx; if (lb) { auto cb = getReducedConstBound( @@ -105,8 +106,8 @@ static void mergeBounds(const HyperRectangularSet &set, } HyperRectangularSet::HyperRectangularSet(unsigned numDims, unsigned numSymbols, - ArrayRef> lbs, - ArrayRef> ubs, + ArrayRef> lbs, + ArrayRef> ubs, MLIRContext *context, IntegerSet *symbolContext) : context(symbolContext ? MutableIntegerSet(symbolContext, context) diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 522720eb362e..b3e3afef19ba 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -32,7 +32,7 @@ using namespace mlir; /// Returns the trip count of the loop as an affine expression if the latter is /// expressible as an affine expression, and nullptr otherwise. The trip count /// expression is simplified before returning. -AffineExprRef mlir::getTripCountExpr(const ForStmt &forStmt) { +AffineExpr mlir::getTripCountExpr(const ForStmt &forStmt) { // upper_bound - lower_bound + 1 int64_t loopSpan; @@ -56,12 +56,12 @@ AffineExprRef mlir::getTripCountExpr(const ForStmt &forStmt) { return nullptr; // ub_expr - lb_expr + 1 - AffineExprRef lbExpr(lbMap->getResult(0)); - AffineExprRef ubExpr(ubMap->getResult(0)); + AffineExpr lbExpr(lbMap->getResult(0)); + AffineExpr ubExpr(ubMap->getResult(0)); auto loopSpanExpr = simplifyAffineExpr( ubExpr - lbExpr + 1, std::max(lbMap->getNumDims(), ubMap->getNumDims()), std::max(lbMap->getNumSymbols(), ubMap->getNumSymbols())); - auto cExpr = loopSpanExpr.dyn_cast(); + auto cExpr = loopSpanExpr.dyn_cast(); if (!cExpr) return loopSpanExpr.ceilDiv(step); loopSpan = cExpr->getValue(); @@ -84,7 +84,7 @@ llvm::Optional mlir::getConstantTripCount(const ForStmt &forStmt) { if (!tripCountExpr) return None; - if (auto constExpr = tripCountExpr.dyn_cast()) + if (auto constExpr = tripCountExpr.dyn_cast()) return constExpr->getValue(); return None; @@ -99,7 +99,7 @@ uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) { if (!tripCountExpr) return 1; - if (auto constExpr = tripCountExpr.dyn_cast()) { + if (auto constExpr = tripCountExpr.dyn_cast()) { uint64_t tripCount = constExpr->getValue(); // 0 iteration loops (greatest divisor is 2^64 - 1). diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp index d03b802fc169..204bbcb21bb5 100644 --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -24,7 +24,7 @@ using namespace mlir::detail; /// Returns true if this expression is made out of only symbols and /// constants (no dimensional identifiers). -bool AffineExpr::isSymbolicOrConstant() { +bool AffineExprClass::isSymbolicOrConstant() { switch (getKind()) { case AffineExprKind::Constant: return true; @@ -38,7 +38,7 @@ bool AffineExpr::isSymbolicOrConstant() { case AffineExprKind::FloorDiv: case AffineExprKind::CeilDiv: case AffineExprKind::Mod: { - auto *expr = cast(this); + auto *expr = cast(this); return expr->getLHS()->isSymbolicOrConstant() && expr->getRHS()->isSymbolicOrConstant(); } @@ -47,111 +47,104 @@ bool AffineExpr::isSymbolicOrConstant() { ////////////////////////////////// Details ///////////////////////////////////// -AffineBinaryOpExpr::AffineBinaryOpExpr(AffineExprKind kind, AffineExprRef lhs, - AffineExprRef rhs, MLIRContext *context) - : AffineExpr(kind, context), lhs(lhs), rhs(rhs) { +AffineBinaryOpExprClass::AffineBinaryOpExprClass(AffineExprKind kind, + AffineExpr lhs, AffineExpr rhs) + : AffineExprClass(kind, lhs->getContext()), lhs(lhs), rhs(rhs) { // We verify affine op expr forms at construction time. switch (kind) { case AffineExprKind::Add: - assert(!lhs.isa()); + assert(!lhs.isa()); break; case AffineExprKind::Mul: - assert(!lhs.isa()); - assert(AffineExprRef(rhs)->isSymbolicOrConstant()); + assert(!lhs.isa()); + assert(AffineExpr(rhs)->isSymbolicOrConstant()); break; case AffineExprKind::FloorDiv: - assert(AffineExprRef(rhs)->isSymbolicOrConstant()); + assert(AffineExpr(rhs)->isSymbolicOrConstant()); break; case AffineExprKind::CeilDiv: - assert(AffineExprRef(rhs)->isSymbolicOrConstant()); + assert(AffineExpr(rhs)->isSymbolicOrConstant()); break; case AffineExprKind::Mod: - assert(AffineExprRef(rhs)->isSymbolicOrConstant()); + assert(AffineExpr(rhs)->isSymbolicOrConstant()); break; default: llvm_unreachable("unexpected binary affine expr"); } } -AffineExprRef AffineBinaryOpExpr::getSub(AffineExprRef lhs, AffineExprRef rhs, - MLIRContext *context) { - return getAdd(lhs, getMul(rhs, getAffineConstantExpr(-1, context), context), - context); +AffineExpr AffineBinaryOpExprClass::getSub(AffineExpr lhs, AffineExpr rhs) { + return getAdd(lhs, getMul(rhs, getAffineConstantExpr(-1, lhs->getContext()))); } -AffineExprRef AffineBinaryOpExpr::getAdd(AffineExprRef expr, int64_t rhs, - MLIRContext *context) { - return get(AffineExprKind::Add, expr, getAffineConstantExpr(rhs, context), - context); +AffineExpr AffineBinaryOpExprClass::getAdd(AffineExpr expr, int64_t rhs) { + return get(AffineExprKind::Add, expr, + getAffineConstantExpr(rhs, expr->getContext())); } -AffineExprRef AffineBinaryOpExpr::getMul(AffineExprRef expr, int64_t rhs, - MLIRContext *context) { - return get(AffineExprKind::Mul, expr, getAffineConstantExpr(rhs, context), - context); +AffineExpr AffineBinaryOpExprClass::getMul(AffineExpr expr, int64_t rhs) { + return get(AffineExprKind::Mul, expr, + getAffineConstantExpr(rhs, expr->getContext())); } -AffineExprRef AffineBinaryOpExpr::getFloorDiv(AffineExprRef lhs, uint64_t rhs, - MLIRContext *context) { - return get(AffineExprKind::FloorDiv, lhs, getAffineConstantExpr(rhs, context), - context); +AffineExpr AffineBinaryOpExprClass::getFloorDiv(AffineExpr lhs, uint64_t rhs) { + return get(AffineExprKind::FloorDiv, lhs, + getAffineConstantExpr(rhs, lhs->getContext())); } -AffineExprRef AffineBinaryOpExpr::getCeilDiv(AffineExprRef lhs, uint64_t rhs, - MLIRContext *context) { - return get(AffineExprKind::CeilDiv, lhs, getAffineConstantExpr(rhs, context), - context); +AffineExpr AffineBinaryOpExprClass::getCeilDiv(AffineExpr lhs, uint64_t rhs) { + return get(AffineExprKind::CeilDiv, lhs, + getAffineConstantExpr(rhs, lhs->getContext())); } -AffineExprRef AffineBinaryOpExpr::getMod(AffineExprRef lhs, uint64_t rhs, - MLIRContext *context) { - return get(AffineExprKind::Mod, lhs, getAffineConstantExpr(rhs, context), - context); +AffineExpr AffineBinaryOpExprClass::getMod(AffineExpr lhs, uint64_t rhs) { + return get(AffineExprKind::Mod, lhs, + getAffineConstantExpr(rhs, lhs->getContext())); } /// Returns true if this is a pure affine expression, i.e., multiplication, /// floordiv, ceildiv, and mod is only allowed w.r.t constants. -bool AffineExpr::isPureAffine() { +bool AffineExprClass::isPureAffine() { switch (getKind()) { case AffineExprKind::SymbolId: case AffineExprKind::DimId: case AffineExprKind::Constant: return true; case AffineExprKind::Add: { - auto *op = cast(this); + auto *op = cast(this); return op->getLHS()->isPureAffine() && op->getRHS()->isPureAffine(); } case AffineExprKind::Mul: { // TODO: Canonicalize the constants in binary operators to the RHS when // possible, allowing this to merge into the next case. - auto *op = cast(this); + auto *op = cast(this); return op->getLHS()->isPureAffine() && op->getRHS()->isPureAffine() && - (op->getLHS().isa() || - op->getRHS().isa()); + (op->getLHS().isa() || + op->getRHS().isa()); } case AffineExprKind::FloorDiv: case AffineExprKind::CeilDiv: case AffineExprKind::Mod: { - auto *op = cast(this); + auto *op = cast(this); return op->getLHS()->isPureAffine() && - op->getRHS().isa(); + op->getRHS().isa(); } } } /// Returns the greatest known integral divisor of this affine expression. -uint64_t AffineExpr::getLargestKnownDivisor() { - AffineBinaryOpExprRef binExpr; +uint64_t AffineExprClass::getLargestKnownDivisor() { + AffineBinaryOpExpr binExpr; switch (getKind()) { case AffineExprKind::SymbolId: LLVM_FALLTHROUGH; case AffineExprKind::DimId: return 1; case AffineExprKind::Constant: - return std::abs(cast(this)->getValue()); + return std::abs(cast(this)->getValue()); case AffineExprKind::Mul: { - binExpr = cast(this); + binExpr = cast(this); return binExpr->getLHS()->getLargestKnownDivisor() * binExpr->getRHS()->getLargestKnownDivisor(); } @@ -160,7 +153,7 @@ uint64_t AffineExpr::getLargestKnownDivisor() { case AffineExprKind::FloorDiv: case AffineExprKind::CeilDiv: case AffineExprKind::Mod: { - binExpr = cast(this); + binExpr = cast(this); return llvm::GreatestCommonDivisor64( binExpr->getLHS()->getLargestKnownDivisor(), binExpr->getRHS()->getLargestKnownDivisor()); @@ -168,8 +161,8 @@ uint64_t AffineExpr::getLargestKnownDivisor() { } } -bool AffineExpr::isMultipleOf(int64_t factor) { - AffineBinaryOpExpr *binExpr; +bool AffineExprClass::isMultipleOf(int64_t factor) { + AffineBinaryOpExprClass *binExpr; uint64_t l, u; switch (getKind()) { case AffineExprKind::SymbolId: @@ -177,9 +170,9 @@ bool AffineExpr::isMultipleOf(int64_t factor) { case AffineExprKind::DimId: return factor * factor == 1; case AffineExprKind::Constant: - return cast(this)->getValue() % factor == 0; + return cast(this)->getValue() % factor == 0; case AffineExprKind::Mul: { - binExpr = cast(this); + binExpr = cast(this); // It's probably not worth optimizing this further (to not traverse the // whole sub-tree under - it that would require a version of isMultipleOf // that on a 'false' return also returns the largest known divisor). @@ -191,7 +184,7 @@ bool AffineExpr::isMultipleOf(int64_t factor) { case AffineExprKind::FloorDiv: case AffineExprKind::CeilDiv: case AffineExprKind::Mod: { - binExpr = cast(this); + binExpr = cast(this); return llvm::GreatestCommonDivisor64( binExpr->getLHS()->getLargestKnownDivisor(), binExpr->getRHS()->getLargestKnownDivisor()) % @@ -201,48 +194,56 @@ bool AffineExpr::isMultipleOf(int64_t factor) { } } -MLIRContext *AffineExpr::getContext() { return context; } +MLIRContext *AffineExprClass::getContext() { return context; } ///////////////////////////// Done with details /////////////////////////////// -template <> AffineExprRef AffineExprRef::operator+(int64_t v) const { - return AffineBinaryOpExpr::getAdd(expr, v, expr->getContext()); +template <> AffineExpr AffineExpr::operator+(int64_t v) const { + return AffineBinaryOpExprClass::getAdd(expr, v); } -template <> AffineExprRef AffineExprRef::operator+(AffineExprRef other) const { - return AffineBinaryOpExpr::getAdd(expr, other.expr, expr->getContext()); +template <> AffineExpr AffineExpr::operator+(AffineExpr other) const { + return AffineBinaryOpExprClass::getAdd(expr, other.expr); } -template <> AffineExprRef AffineExprRef::operator*(int64_t v) const { - return AffineBinaryOpExpr::getMul(expr, v, expr->getContext()); +template <> AffineExpr AffineExpr::operator*(int64_t v) const { + return AffineBinaryOpExprClass::getMul(expr, v); } -template <> AffineExprRef AffineExprRef::operator*(AffineExprRef other) const { - return AffineBinaryOpExpr::getMul(expr, other.expr, expr->getContext()); +template <> AffineExpr AffineExpr::operator*(AffineExpr other) const { + return AffineBinaryOpExprClass::getMul(expr, other.expr); } // Unary minus, delegate to operator*. -template <> AffineExprRef AffineExprRef::operator-() const { - return AffineBinaryOpExpr::getMul(expr, -1, expr->getContext()); +template <> AffineExpr AffineExpr::operator-() const { + return AffineBinaryOpExprClass::getMul(expr, -1); } // Delegate to operator+. -template <> AffineExprRef AffineExprRef::operator-(int64_t v) const { +template <> AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); } -template <> AffineExprRef AffineExprRef::operator-(AffineExprRef other) const { +template <> AffineExpr AffineExpr::operator-(AffineExpr other) const { return *this + (-other); } -template <> AffineExprRef AffineExprRef::floorDiv(uint64_t v) const { - return AffineBinaryOpExpr::getFloorDiv(expr, v, expr->getContext()); +template <> AffineExpr AffineExpr::floorDiv(uint64_t v) const { + return AffineBinaryOpExprClass::getFloorDiv(expr, v); } -template <> AffineExprRef AffineExprRef::floorDiv(AffineExprRef other) const { - return AffineBinaryOpExpr::getFloorDiv(expr, other.expr, expr->getContext()); +template <> AffineExpr AffineExpr::floorDiv(AffineExpr other) const { + return AffineBinaryOpExprClass::getFloorDiv(expr, other.expr); } -template <> AffineExprRef AffineExprRef::ceilDiv(uint64_t v) const { - return AffineBinaryOpExpr::getCeilDiv(expr, v, expr->getContext()); +template <> AffineExpr AffineExpr::ceilDiv(uint64_t v) const { + return AffineBinaryOpExprClass::getCeilDiv(expr, v); } -template <> AffineExprRef AffineExprRef::ceilDiv(AffineExprRef other) const { - return AffineBinaryOpExpr::getCeilDiv(expr, other.expr, expr->getContext()); +template <> AffineExpr AffineExpr::ceilDiv(AffineExpr other) const { + return AffineBinaryOpExprClass::getCeilDiv(expr, other.expr); } -template <> AffineExprRef AffineExprRef::operator%(uint64_t v) const { - return AffineBinaryOpExpr::getMod(expr, v, expr->getContext()); +template <> AffineExpr AffineExpr::operator%(uint64_t v) const { + return AffineBinaryOpExprClass::getMod(expr, v); } -template <> AffineExprRef AffineExprRef::operator%(AffineExprRef other) const { - return AffineBinaryOpExpr::getMod(expr, other.expr, expr->getContext()); +template <> AffineExpr AffineExpr::operator%(AffineExpr other) const { + return AffineBinaryOpExprClass::getMod(expr, other.expr); +} + +AffineExpr operator+(int64_t val, AffineExpr expr) { + return expr + val; // AffineBinaryOpExpr asserts !lhs.isa +} +AffineExpr operator-(int64_t val, AffineExpr expr) { return expr * (-1) + val; } +AffineExpr operator*(int64_t val, AffineExpr expr) { + return expr * val; // AffineBinaryOpExpr asserts !lhs.isa } diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp index 3842d9dbe949..e6939562dbbf 100644 --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -37,7 +37,7 @@ public: /// Attempt to constant fold the specified affine expr, or return null on /// failure. - IntegerAttr *constantFold(AffineExprRef expr) { + IntegerAttr *constantFold(AffineExpr expr) { switch (expr->getKind()) { case AffineExprKind::Add: return constantFoldBinExpr( @@ -55,23 +55,23 @@ public: return constantFoldBinExpr( expr, [](int64_t lhs, uint64_t rhs) { return ceilDiv(lhs, rhs); }); case AffineExprKind::Constant: - return IntegerAttr::get(expr.cast()->getValue(), + return IntegerAttr::get(expr.cast()->getValue(), expr->getContext()); case AffineExprKind::DimId: return dyn_cast_or_null( - operandConsts[expr.cast()->getPosition()]); + operandConsts[expr.cast()->getPosition()]); case AffineExprKind::SymbolId: return dyn_cast_or_null( operandConsts[numDims + - expr.cast()->getPosition()]); + expr.cast()->getPosition()]); } } private: IntegerAttr * - constantFoldBinExpr(AffineExprRef expr, + constantFoldBinExpr(AffineExpr expr, std::function op) { - auto binOpExpr = expr.cast(); + auto binOpExpr = expr.cast(); auto *lhs = constantFold(binOpExpr->getLHS()); auto *rhs = constantFold(binOpExpr->getRHS()); if (!lhs || !rhs) @@ -89,23 +89,23 @@ private: } // end anonymous namespace AffineMap::AffineMap(unsigned numDims, unsigned numSymbols, unsigned numResults, - ArrayRef results, - ArrayRef rangeSizes) + ArrayRef results, + ArrayRef rangeSizes) : numDims(numDims), numSymbols(numSymbols), numResults(numResults), results(results), rangeSizes(rangeSizes) {} /// Returns a single constant result affine map. AffineMap *AffineMap::getConstantMap(int64_t val, MLIRContext *context) { return get(/*dimCount=*/0, /*symbolCount=*/0, - {getAffineConstantExpr(val, context)}, {}, context); + {getAffineConstantExpr(val, context)}, {}); } bool AffineMap::isIdentity() { if (getNumDims() != getNumResults()) return false; - ArrayRef results = getResults(); + ArrayRef results = getResults(); for (unsigned i = 0, numDims = getNumDims(); i < numDims; ++i) { - auto expr = results[i].dyn_cast(); + auto expr = results[i].dyn_cast(); if (!expr || expr->getPosition() != i) return false; } @@ -113,15 +113,15 @@ bool AffineMap::isIdentity() { } bool AffineMap::isSingleConstant() { - return getNumResults() == 1 && getResult(0).isa(); + return getNumResults() == 1 && getResult(0).isa(); } int64_t AffineMap::getSingleConstantResult() { assert(isSingleConstant() && "map must have a single constant result"); - return getResult(0).cast()->getValue(); + return getResult(0).cast()->getValue(); } -AffineExprRef AffineMap::getResult(unsigned idx) { return results[idx]; } +AffineExpr AffineMap::getResult(unsigned idx) { return results[idx]; } /// Folds the results of the application of an affine map on the provided /// operands to a constant if possible. Returns false if the folding happens, @@ -132,7 +132,7 @@ bool AffineMap::constantFold(ArrayRef operandConstants, // Fold each of the result expressions. AffineExprConstantFolder exprFolder(getNumDims(), operandConstants); - // Constant fold each AffineExpr in AffineMap and add to 'results'. + // Constant fold each AffineExprClass in AffineMap and add to 'results'. for (auto expr : getResults()) { auto *folded = exprFolder.constantFold(expr); // If we didn't fold to a constant, then folding fails. diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 8eb3d38d4a19..a491959f645b 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -108,8 +108,8 @@ private: // Check if the affine map is single dim id or single symbol identity - // (i)->(i) or ()[s]->(i) return boundMap->getNumInputs() == 1 && boundMap->getNumResults() == 1 && - (boundMap->getResult(0).isa() || - boundMap->getResult(0).isa()); + (boundMap->getResult(0).isa() || + boundMap->getResult(0).isa()); } // Visit functions. @@ -275,8 +275,8 @@ public: void print(const MLFunction *fn); void printAffineMap(AffineMap *map); - void printAffineExpr(AffineExprRef expr); - void printAffineConstraint(AffineExprRef expr, bool isEq); + void printAffineExpr(AffineExpr expr); + void printAffineConstraint(AffineExpr expr, bool isEq); void printIntegerSet(IntegerSet *set); protected: @@ -294,13 +294,13 @@ protected: void printIntegerSetReference(IntegerSet *integerSet); /// This enum is used to represent the binding stength of the enclosing - /// context that an AffineExpr is being printed in, so we can intelligently - /// produce parens. + /// context that an AffineExprClass is being printed in, so we can + /// intelligently produce parens. enum class BindingStrength { Weak, // + and - Strong, // All other binary operators. }; - void printAffineExprInternal(AffineExprRef expr, + void printAffineExprInternal(AffineExpr expr, BindingStrength enclosingTightness); }; } // end anonymous namespace @@ -571,22 +571,22 @@ void ModulePrinter::printType(const Type *type) { // Affine expressions and maps //===----------------------------------------------------------------------===// -void ModulePrinter::printAffineExpr(AffineExprRef expr) { +void ModulePrinter::printAffineExpr(AffineExpr expr) { printAffineExprInternal(expr, BindingStrength::Weak); } void ModulePrinter::printAffineExprInternal( - AffineExprRef expr, BindingStrength enclosingTightness) { + AffineExpr expr, BindingStrength enclosingTightness) { const char *binopSpelling = nullptr; switch (expr->getKind()) { case AffineExprKind::SymbolId: - os << 's' << expr.cast()->getPosition(); + os << 's' << expr.cast()->getPosition(); return; case AffineExprKind::DimId: - os << 'd' << expr.cast()->getPosition(); + os << 'd' << expr.cast()->getPosition(); return; case AffineExprKind::Constant: - os << expr.cast()->getValue(); + os << expr.cast()->getValue(); return; case AffineExprKind::Add: binopSpelling = " + "; @@ -605,7 +605,7 @@ void ModulePrinter::printAffineExprInternal( break; } - auto binOp = expr.cast(); + auto binOp = expr.cast(); // Handle tightly binding binary operators. if (binOp->getKind() != AffineExprKind::Add) { @@ -627,11 +627,11 @@ void ModulePrinter::printAffineExprInternal( // Pretty print addition to a product that has a negative operand as a // subtraction. - AffineExprRef rhsExpr = binOp->getRHS(); - if (auto rhs = rhsExpr.dyn_cast()) { + AffineExpr rhsExpr = binOp->getRHS(); + if (auto rhs = rhsExpr.dyn_cast()) { if (rhs->getKind() == AffineExprKind::Mul) { - AffineExprRef rrhsExpr = rhs->getRHS(); - if (auto rrhs = rrhsExpr.dyn_cast()) { + AffineExpr rrhsExpr = rhs->getRHS(); + if (auto rrhs = rrhsExpr.dyn_cast()) { if (rrhs->getValue() == -1) { printAffineExprInternal(binOp->getLHS(), BindingStrength::Weak); os << " - "; @@ -656,7 +656,7 @@ void ModulePrinter::printAffineExprInternal( } // Pretty print addition to a negative number as a subtraction. - if (auto rhs = rhsExpr.dyn_cast()) { + if (auto rhs = rhsExpr.dyn_cast()) { if (rhs->getValue() < 0) { printAffineExprInternal(binOp->getLHS(), BindingStrength::Weak); os << " - " << -rhs->getValue(); @@ -674,7 +674,7 @@ void ModulePrinter::printAffineExprInternal( os << ')'; } -void ModulePrinter::printAffineConstraint(AffineExprRef expr, bool isEq) { +void ModulePrinter::printAffineConstraint(AffineExpr expr, bool isEq) { printAffineExprInternal(expr, BindingStrength::Weak); isEq ? os << " == 0" : os << " >= 0"; } @@ -703,7 +703,7 @@ void ModulePrinter::printAffineMap(AffineMap *map) { // Result affine expressions. os << " -> ("; interleaveComma(map->getResults(), - [&](AffineExprRef expr) { printAffineExpr(expr); }); + [&](AffineExpr expr) { printAffineExpr(expr); }); os << ')'; if (!map->isBounded()) { @@ -713,7 +713,7 @@ void ModulePrinter::printAffineMap(AffineMap *map) { // Print range sizes for bounded affine maps. os << " size ("; interleaveComma(map->getRangeSizes(), - [&](AffineExprRef expr) { printAffineExpr(expr); }); + [&](AffineExpr expr) { printAffineExpr(expr); }); os << ')'; } @@ -858,7 +858,7 @@ public: void printIntegerSet(IntegerSet *set) { return ModulePrinter::printIntegerSetReference(set); } - void printAffineExpr(AffineExprRef expr) { + void printAffineExpr(AffineExpr expr) { return ModulePrinter::printAffineExpr(expr); } void printFunctionReference(const Function *func) { @@ -1432,11 +1432,11 @@ void MLFunctionPrinter::printBound(AffineBound bound, const char *prefix) { // Therefore, short-hand parsing and printing is only supported for // zero-operand constant maps and single symbol operand identity maps. if (map->getNumResults() == 1) { - AffineExprRef expr = map->getResult(0); + AffineExpr expr = map->getResult(0); // Print constant bound. if (map->getNumDims() == 0 && map->getNumSymbols() == 0) { - if (auto constExpr = expr.dyn_cast()) { + if (auto constExpr = expr.dyn_cast()) { os << constExpr->getValue(); return; } @@ -1445,7 +1445,7 @@ void MLFunctionPrinter::printBound(AffineBound bound, const char *prefix) { // Print bound that consists of a single SSA symbol if the map is over a // single symbol. if (map->getNumDims() == 0 && map->getNumSymbols() == 1) { - if (auto symExpr = expr.dyn_cast()) { + if (auto symExpr = expr.dyn_cast()) { printOperand(bound.getOperand(0)); return; } @@ -1502,7 +1502,7 @@ void AffineMap::dump() { llvm::errs() << "\n"; } -void AffineExpr::dump() { +void AffineExprClass::dump() { print(llvm::errs()); llvm::errs() << "\n"; } @@ -1512,7 +1512,7 @@ void IntegerSet::dump() { llvm::errs() << "\n"; } -void AffineExpr::print(raw_ostream &os) { +void AffineExprClass::print(raw_ostream &os) { ModuleState state(/*no context is known*/ nullptr); ModulePrinter(os, state).printAffineExpr(this); } diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 5ab0fec2d430..7c7ad380aad6 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -150,119 +150,118 @@ FunctionAttr *Builder::getFunctionAttr(const Function *value) { //===----------------------------------------------------------------------===// AffineMap *Builder::getAffineMap(unsigned dimCount, unsigned symbolCount, - ArrayRef results, - ArrayRef rangeSizes) { - return AffineMap::get(dimCount, symbolCount, results, rangeSizes, context); + ArrayRef results, + ArrayRef rangeSizes) { + return AffineMap::get(dimCount, symbolCount, results, rangeSizes); } -AffineExprRef Builder::getAffineDimExpr(unsigned position) { +AffineExpr Builder::getAffineDimExpr(unsigned position) { return mlir::getAffineDimExpr(position, context); } -AffineExprRef Builder::getAffineSymbolExpr(unsigned position) { +AffineExpr Builder::getAffineSymbolExpr(unsigned position) { return mlir::getAffineSymbolExpr(position, context); } -AffineExprRef Builder::getAffineConstantExpr(int64_t constant) { +AffineExpr Builder::getAffineConstantExpr(int64_t constant) { return mlir::getAffineConstantExpr(constant, context); } -AffineExprRef Builder::getAddExpr(AffineExprRef lhs, AffineExprRef rhs) { +AffineExpr Builder::getAddExpr(AffineExpr lhs, AffineExpr rhs) { return lhs + rhs; } -AffineExprRef Builder::getAddExpr(AffineExprRef lhs, int64_t rhs) { +AffineExpr Builder::getAddExpr(AffineExpr lhs, int64_t rhs) { return lhs + rhs; } -AffineExprRef Builder::getMulExpr(AffineExprRef lhs, AffineExprRef rhs) { +AffineExpr Builder::getMulExpr(AffineExpr lhs, AffineExpr rhs) { return lhs * rhs; } // Most multiply expressions are pure affine (rhs is a constant). -AffineExprRef Builder::getMulExpr(AffineExprRef lhs, int64_t rhs) { +AffineExpr Builder::getMulExpr(AffineExpr lhs, int64_t rhs) { return lhs * rhs; } -AffineExprRef Builder::getSubExpr(AffineExprRef lhs, AffineExprRef rhs) { +AffineExpr Builder::getSubExpr(AffineExpr lhs, AffineExpr rhs) { return lhs - rhs; } -AffineExprRef Builder::getSubExpr(AffineExprRef lhs, int64_t rhs) { +AffineExpr Builder::getSubExpr(AffineExpr lhs, int64_t rhs) { return lhs - rhs; } -AffineExprRef Builder::getModExpr(AffineExprRef lhs, AffineExprRef rhs) { +AffineExpr Builder::getModExpr(AffineExpr lhs, AffineExpr rhs) { return lhs % rhs; } // Most modulo expressions are pure affine. -AffineExprRef Builder::getModExpr(AffineExprRef lhs, uint64_t rhs) { +AffineExpr Builder::getModExpr(AffineExpr lhs, uint64_t rhs) { return lhs % rhs; } -AffineExprRef Builder::getFloorDivExpr(AffineExprRef lhs, AffineExprRef rhs) { +AffineExpr Builder::getFloorDivExpr(AffineExpr lhs, AffineExpr rhs) { return lhs.floorDiv(rhs); } // Most floordiv expressions are pure affine. -AffineExprRef Builder::getFloorDivExpr(AffineExprRef lhs, uint64_t rhs) { +AffineExpr Builder::getFloorDivExpr(AffineExpr lhs, uint64_t rhs) { return lhs.floorDiv(rhs); } -AffineExprRef Builder::getCeilDivExpr(AffineExprRef lhs, AffineExprRef rhs) { +AffineExpr Builder::getCeilDivExpr(AffineExpr lhs, AffineExpr rhs) { return lhs.ceilDiv(rhs); } // Most ceildiv expressions are pure affine. -AffineExprRef Builder::getCeilDivExpr(AffineExprRef lhs, uint64_t rhs) { +AffineExpr Builder::getCeilDivExpr(AffineExpr lhs, uint64_t rhs) { return lhs.ceilDiv(rhs); } IntegerSet *Builder::getIntegerSet(unsigned dimCount, unsigned symbolCount, - ArrayRef constraints, + ArrayRef constraints, ArrayRef isEq) { return IntegerSet::get(dimCount, symbolCount, constraints, isEq, context); } AffineMap *Builder::getConstantAffineMap(int64_t val) { return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0, - {getAffineConstantExpr(val)}, {}, context); + {getAffineConstantExpr(val)}, {}); } AffineMap *Builder::getDimIdentityMap() { return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, - {getAffineDimExpr(0)}, {}, context); + {getAffineDimExpr(0)}, {}); } AffineMap *Builder::getDimIdentityMap(unsigned rank) { - SmallVector dimExprs; + SmallVector dimExprs; dimExprs.reserve(rank); for (unsigned i = 0; i < rank; ++i) dimExprs.push_back(getAffineDimExpr(i)); - return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs, {}, - context); + return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs, {}); } AffineMap *Builder::getSymbolIdentityMap() { return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1, - {getAffineSymbolExpr(0)}, {}, context); + {getAffineSymbolExpr(0)}, {}); } AffineMap *Builder::getSingleDimShiftAffineMap(int64_t shift) { // expr = d0 + shift. auto expr = getAffineDimExpr(0) + shift; - return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, {expr}, {}, context); + return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, {expr}, {}); } AffineMap *Builder::getShiftedAffineMap(AffineMap *map, int64_t shift) { - SmallVector shiftedResults; + SmallVector shiftedResults; shiftedResults.reserve(map->getNumResults()); for (auto resultExpr : map->getResults()) { shiftedResults.push_back(getAddExpr(resultExpr, shift)); } return AffineMap::get(map->getNumDims(), map->getNumSymbols(), shiftedResults, - map->getRangeSizes(), context); + map->getRangeSizes()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/IntegerSet.cpp b/mlir/lib/IR/IntegerSet.cpp index 38c3d5138bca..cfd838516dbe 100644 --- a/mlir/lib/IR/IntegerSet.cpp +++ b/mlir/lib/IR/IntegerSet.cpp @@ -23,8 +23,7 @@ using namespace mlir; IntegerSet::IntegerSet(unsigned dimCount, unsigned symbolCount, unsigned numConstraints, - ArrayRef constraints, - ArrayRef eqFlags) + ArrayRef constraints, ArrayRef eqFlags) : dimCount(dimCount), symbolCount(symbolCount), numConstraints(numConstraints), constraints(constraints), eqFlags(eqFlags) {} diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index e3c4f8ee3980..b74f1c8c2c94 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -61,8 +61,8 @@ struct FunctionTypeKeyInfo : DenseMapInfo { struct AffineMapKeyInfo : DenseMapInfo { // Affine maps are uniqued based on their dim/symbol counts and affine // expressions. - using KeyTy = std::tuple, - ArrayRef>; + using KeyTy = std::tuple, + ArrayRef>; using DenseMapInfo::getHashValue; using DenseMapInfo::isEqual; @@ -226,15 +226,15 @@ public: // Affine binary op expression uniquing. Figure out uniquing of dimensional // or symbolic identifiers. - DenseMap, AffineExprRef> + DenseMap, AffineExpr> affineExprs; - // Uniqui'ing of AffineDimExprRef, AffineSymbolExprRef's by their position. - std::vector dimExprs; - std::vector symbolExprs; + // Uniqui'ing of AffineDimExpr, AffineSymbolExpr's by their position. + std::vector dimExprs; + std::vector symbolExprs; - // Uniqui'ing of AffineConstantExpr using constant value as key. - DenseMap constExprs; + // Uniqui'ing of AffineConstantExprClass using constant value as key. + DenseMap constExprs; /// Integer type uniquing. DenseMap integers; @@ -802,15 +802,14 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef attrs, //===----------------------------------------------------------------------===// AffineMap *AffineMap::get(unsigned dimCount, unsigned symbolCount, - ArrayRef results, - ArrayRef rangeSizes, - MLIRContext *context) { + ArrayRef results, + ArrayRef rangeSizes) { // The number of results can't be zero. assert(!results.empty()); assert(rangeSizes.empty() || results.size() == rangeSizes.size()); - auto &impl = context->getImpl(); + auto &impl = results[0]->getContext()->getImpl(); // Check if we already have this affine map. auto key = std::make_tuple(dimCount, symbolCount, results, rangeSizes); @@ -836,19 +835,18 @@ AffineMap *AffineMap::get(unsigned dimCount, unsigned symbolCount, } /// Simplify add expression. Return nullptr if it can't be simplified. -static AffineExprRef simplifyAdd(AffineExprRef lhs, AffineExprRef rhs, - MLIRContext *context) { - auto lhsConst = lhs.dyn_cast(); - auto rhsConst = rhs.dyn_cast(); +static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) { + auto lhsConst = lhs.dyn_cast(); + auto rhsConst = rhs.dyn_cast(); // Fold if both LHS, RHS are a constant. if (lhsConst && rhsConst) return getAffineConstantExpr(lhsConst->getValue() + rhsConst->getValue(), - context); + lhs->getContext()); // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4). // If only one of them is a symbolic expressions, make it the RHS. - if (lhs.isa() || + if (lhs.isa() || (lhs->isSymbolicOrConstant() && !rhs->isSymbolicOrConstant())) { return rhs + lhs; } @@ -861,16 +859,16 @@ static AffineExprRef simplifyAdd(AffineExprRef lhs, AffineExprRef rhs, return lhs; } // Fold successive additions like (d0 + 2) + 3 into d0 + 5. - auto lBin = lhs.dyn_cast(); + auto lBin = lhs.dyn_cast(); if (lBin && rhsConst && lBin->getKind() == AffineExprKind::Add) { - if (auto lrhs = lBin->getRHS().dyn_cast()) + if (auto lrhs = lBin->getRHS().dyn_cast()) return lBin->getLHS() + (lrhs->getValue() + rhsConst->getValue()); } // When doing successive additions, bring constant to the right: turn (d0 + 2) // + d1 into (d0 + d1) + 2. if (lBin && lBin->getKind() == AffineExprKind::Add) { - if (auto lrhs = lBin->getRHS().dyn_cast()) { + if (auto lrhs = lBin->getRHS().dyn_cast()) { return lBin->getLHS() + rhs + lrhs; } } @@ -879,21 +877,20 @@ static AffineExprRef simplifyAdd(AffineExprRef lhs, AffineExprRef rhs, } /// Simplify a multiply expression. Return nullptr if it can't be simplified. -static AffineExprRef simplifyMul(AffineExprRef lhs, AffineExprRef rhs, - MLIRContext *context) { - auto lhsConst = lhs.dyn_cast(); - auto rhsConst = rhs.dyn_cast(); +static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) { + auto lhsConst = lhs.dyn_cast(); + auto rhsConst = rhs.dyn_cast(); if (lhsConst && rhsConst) return getAffineConstantExpr(lhsConst->getValue() * rhsConst->getValue(), - context); + lhs->getContext()); assert(lhs->isSymbolicOrConstant() || rhs->isSymbolicOrConstant()); // Canonicalize the mul expression so that the constant/symbolic term is the // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a // constant. (Note that a constant is trivially symbolic). - if (!rhs->isSymbolicOrConstant() || lhs.isa()) { + if (!rhs->isSymbolicOrConstant() || lhs.isa()) { // At least one of them has to be symbolic. return rhs * lhs; } @@ -910,16 +907,16 @@ static AffineExprRef simplifyMul(AffineExprRef lhs, AffineExprRef rhs, } // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6. - auto lBin = lhs.dyn_cast(); + auto lBin = lhs.dyn_cast(); if (lBin && rhsConst && lBin->getKind() == AffineExprKind::Mul) { - if (auto lrhs = lBin->getRHS().dyn_cast()) + if (auto lrhs = lBin->getRHS().dyn_cast()) return lBin->getLHS() * (lrhs->getValue() * rhsConst->getValue()); } // When doing successive multiplication, bring constant to the right: turn (d0 // * 2) * d1 into (d0 * d1) * 2. if (lBin && lBin->getKind() == AffineExprKind::Mul) { - if (auto lrhs = lBin->getRHS().dyn_cast()) { + if (auto lrhs = lBin->getRHS().dyn_cast()) { return (lBin->getLHS() * rhs) * lrhs; } } @@ -927,14 +924,14 @@ static AffineExprRef simplifyMul(AffineExprRef lhs, AffineExprRef rhs, return nullptr; } -static AffineExprRef simplifyFloorDiv(AffineExprRef lhs, AffineExprRef rhs, - MLIRContext *context) { - auto lhsConst = lhs.dyn_cast(); - auto rhsConst = rhs.dyn_cast(); +static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) { + auto lhsConst = lhs.dyn_cast(); + auto rhsConst = rhs.dyn_cast(); if (lhsConst && rhsConst) return getAffineConstantExpr( - floorDiv(lhsConst->getValue(), rhsConst->getValue()), context); + floorDiv(lhsConst->getValue(), rhsConst->getValue()), + lhs->getContext()); // Fold floordiv of a multiply with a constant that is a multiple of the // divisor. Eg: (i * 128) floordiv 64 = i * 2. @@ -942,9 +939,9 @@ static AffineExprRef simplifyFloorDiv(AffineExprRef lhs, AffineExprRef rhs, if (rhsConst->getValue() == 1) return lhs; - auto lBin = lhs.dyn_cast(); + auto lBin = lhs.dyn_cast(); if (lBin && lBin->getKind() == AffineExprKind::Mul) { - if (auto lrhs = lBin->getRHS().dyn_cast()) { + if (auto lrhs = lBin->getRHS().dyn_cast()) { // rhsConst is known to be positive if a constant. if (lrhs->getValue() % rhsConst->getValue() == 0) return lBin->getLHS() * (lrhs->getValue() / rhsConst->getValue()); @@ -955,14 +952,13 @@ static AffineExprRef simplifyFloorDiv(AffineExprRef lhs, AffineExprRef rhs, return nullptr; } -static AffineExprRef simplifyCeilDiv(AffineExprRef lhs, AffineExprRef rhs, - MLIRContext *context) { - auto lhsConst = lhs.dyn_cast(); - auto rhsConst = rhs.dyn_cast(); +static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) { + auto lhsConst = lhs.dyn_cast(); + auto rhsConst = rhs.dyn_cast(); if (lhsConst && rhsConst) return getAffineConstantExpr( - ceilDiv(lhsConst->getValue(), rhsConst->getValue()), context); + ceilDiv(lhsConst->getValue(), rhsConst->getValue()), lhs->getContext()); // Fold ceildiv of a multiply with a constant that is a multiple of the // divisor. Eg: (i * 128) ceildiv 64 = i * 2. @@ -970,9 +966,9 @@ static AffineExprRef simplifyCeilDiv(AffineExprRef lhs, AffineExprRef rhs, if (rhsConst->getValue() == 1) return lhs; - auto lBin = lhs.dyn_cast(); + auto lBin = lhs.dyn_cast(); if (lBin && lBin->getKind() == AffineExprKind::Mul) { - if (auto lrhs = lBin->getRHS().dyn_cast()) { + if (auto lrhs = lBin->getRHS().dyn_cast()) { // rhsConst is known to be positive if a constant. if (lrhs->getValue() % rhsConst->getValue() == 0) return lBin->getLHS() * (lrhs->getValue() / rhsConst->getValue()); @@ -983,14 +979,13 @@ static AffineExprRef simplifyCeilDiv(AffineExprRef lhs, AffineExprRef rhs, return nullptr; } -static AffineExprRef simplifyMod(AffineExprRef lhs, AffineExprRef rhs, - MLIRContext *context) { - auto lhsConst = lhs.dyn_cast(); - auto rhsConst = rhs.dyn_cast(); +static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) { + auto lhsConst = lhs.dyn_cast(); + auto rhsConst = rhs.dyn_cast(); if (lhsConst && rhsConst) return getAffineConstantExpr( - mod(lhsConst->getValue(), rhsConst->getValue()), context); + mod(lhsConst->getValue(), rhsConst->getValue()), lhs->getContext()); // Fold modulo of an expression that is known to be a multiple of a constant // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128) @@ -998,7 +993,7 @@ static AffineExprRef simplifyMod(AffineExprRef lhs, AffineExprRef rhs, if (rhsConst) { // rhsConst is known to be positive if a constant. if (lhs->getLargestKnownDivisor() % rhsConst->getValue() == 0) - return getAffineConstantExpr(0, context); + return getAffineConstantExpr(0, lhs->getContext()); } return nullptr; @@ -1013,33 +1008,33 @@ static AffineExprRef simplifyMod(AffineExprRef lhs, AffineExprRef rhs, /// present, return from the list. The stored expressions are unique: they are /// constructed and stored in a simplified/canonicalized form. The result after /// simplification could be any form of affine expression. -AffineExprRef AffineBinaryOpExpr::get(AffineExprKind kind, AffineExprRef lhs, - AffineExprRef rhs, MLIRContext *context) { - auto &impl = context->getImpl(); +AffineExpr AffineBinaryOpExprClass::get(AffineExprKind kind, AffineExpr lhs, + AffineExpr rhs) { + auto &impl = lhs->getContext()->getImpl(); // Check if we already have this affine expression, and return it if we do. auto keyValue = std::make_tuple((unsigned)kind, lhs, rhs); auto cached = impl.affineExprs.find(keyValue); if (cached != impl.affineExprs.end()) - return static_cast(cached->second); + return static_cast(cached->second); // Simplify the expression if possible. - AffineExprRef simplified; + AffineExpr simplified; switch (kind) { case AffineExprKind::Add: - simplified = simplifyAdd(lhs, rhs, context); + simplified = simplifyAdd(lhs, rhs); break; case AffineExprKind::Mul: - simplified = simplifyMul(lhs, rhs, context); + simplified = simplifyMul(lhs, rhs); break; case AffineExprKind::FloorDiv: - simplified = simplifyFloorDiv(lhs, rhs, context); + simplified = simplifyFloorDiv(lhs, rhs); break; case AffineExprKind::CeilDiv: - simplified = simplifyCeilDiv(lhs, rhs, context); + simplified = simplifyCeilDiv(lhs, rhs); break; case AffineExprKind::Mod: - simplified = simplifyMod(lhs, rhs, context); + simplified = simplifyMod(lhs, rhs); break; default: llvm_unreachable("unexpected binary affine expr"); @@ -1047,20 +1042,20 @@ AffineExprRef AffineBinaryOpExpr::get(AffineExprKind kind, AffineExprRef lhs, // The simplified one would have already been cached; just return it. if (simplified) - return static_cast(simplified); + return static_cast(simplified); // An expression with these operands will already be in the // simplified/canonical form. Create and store it. - auto *result = impl.allocator.Allocate(); + auto *result = impl.allocator.Allocate(); // Initialize the memory using placement new. - new (result) AffineBinaryOpExpr(kind, lhs, rhs, context); + new (result) AffineBinaryOpExprClass(kind, lhs, rhs); bool inserted = impl.affineExprs.insert({keyValue, result}).second; assert(inserted && "the expression shouldn't already exist in the map"); (void)inserted; return result; } -AffineExprRef mlir::getAffineDimExpr(unsigned position, MLIRContext *context) { +AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) { auto &impl = context->getImpl(); // Check if we need to resize. @@ -1071,14 +1066,13 @@ AffineExprRef mlir::getAffineDimExpr(unsigned position, MLIRContext *context) { if (result) return result; - result = impl.allocator.Allocate(); + result = impl.allocator.Allocate(); // Initialize the memory using placement new. - new (result) AffineDimExpr(position, context); + new (result) AffineDimExprClass(position, context); return result; } -AffineExprRef mlir::getAffineSymbolExpr(unsigned position, - MLIRContext *context) { +AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) { auto &impl = context->getImpl(); // Check if we need to resize. @@ -1089,23 +1083,22 @@ AffineExprRef mlir::getAffineSymbolExpr(unsigned position, if (result) return result; - result = impl.allocator.Allocate(); + result = impl.allocator.Allocate(); // Initialize the memory using placement new. - new (result) AffineSymbolExpr(position, context); + new (result) AffineSymbolExprClass(position, context); return result; } -AffineExprRef mlir::getAffineConstantExpr(int64_t constant, - MLIRContext *context) { +AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) { auto &impl = context->getImpl(); auto *&result = impl.constExprs[constant]; if (result) return result; - result = impl.allocator.Allocate(); + result = impl.allocator.Allocate(); // Initialize the memory using placement new. - new (result) AffineConstantExpr(constant, context); + new (result) AffineConstantExprClass(constant, context); return result; } @@ -1115,7 +1108,7 @@ AffineExprRef mlir::getAffineConstantExpr(int64_t constant, //===----------------------------------------------------------------------===// IntegerSet *IntegerSet::get(unsigned dimCount, unsigned symbolCount, - ArrayRef constraints, + ArrayRef constraints, ArrayRef eqFlags, MLIRContext *context) { assert(eqFlags.size() == constraints.size()); @@ -1125,7 +1118,7 @@ IntegerSet *IntegerSet::get(unsigned dimCount, unsigned symbolCount, auto *res = impl.allocator.Allocate(); // Copy the equalities and inequalities into the bump pointer. - constraints = impl.copyInto(ArrayRef(constraints)); + constraints = impl.copyInto(ArrayRef(constraints)); eqFlags = impl.copyInto(ArrayRef(eqFlags)); // Initialize the memory using placement new. diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index f560841684e5..e6cfc1ec75dd 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -831,38 +831,35 @@ private: // Identifier lists for polyhedral structures. ParseResult parseDimIdList(unsigned &numDims); ParseResult parseSymbolIdList(unsigned &numSymbols); - ParseResult parseIdentifierDefinition(AffineExprRef idExpr); + ParseResult parseIdentifierDefinition(AffineExpr idExpr); - AffineExprRef parseAffineExpr(); - AffineExprRef parseParentheticalExpr(); - AffineExprRef parseNegateExpression(AffineExprRef lhs); - AffineExprRef parseIntegerExpr(); - AffineExprRef parseBareIdExpr(); + AffineExpr parseAffineExpr(); + AffineExpr parseParentheticalExpr(); + AffineExpr parseNegateExpression(AffineExpr lhs); + AffineExpr parseIntegerExpr(); + AffineExpr parseBareIdExpr(); - AffineExprRef getBinaryAffineOpExpr(AffineHighPrecOp op, AffineExprRef lhs, - AffineExprRef rhs, SMLoc opLoc); - AffineExprRef getBinaryAffineOpExpr(AffineLowPrecOp op, AffineExprRef lhs, - AffineExprRef rhs); - AffineExprRef parseAffineOperandExpr(AffineExprRef lhs); - AffineExprRef parseAffineLowPrecOpExpr(AffineExprRef llhs, - AffineLowPrecOp llhsOp); - AffineExprRef parseAffineHighPrecOpExpr(AffineExprRef llhs, - AffineHighPrecOp llhsOp, - SMLoc llhsOpLoc); - AffineExprRef parseAffineConstraint(bool *isEq); + AffineExpr getBinaryAffineOpExpr(AffineHighPrecOp op, AffineExpr lhs, + AffineExpr rhs, SMLoc opLoc); + AffineExpr getBinaryAffineOpExpr(AffineLowPrecOp op, AffineExpr lhs, + AffineExpr rhs); + AffineExpr parseAffineOperandExpr(AffineExpr lhs); + AffineExpr parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp); + AffineExpr parseAffineHighPrecOpExpr(AffineExpr llhs, AffineHighPrecOp llhsOp, + SMLoc llhsOpLoc); + AffineExpr parseAffineConstraint(bool *isEq); private: - SmallVector, 4> dimsAndSymbols; + SmallVector, 4> dimsAndSymbols; }; } // end anonymous namespace /// Create an affine binary high precedence op expression (mul's, div's, mod). /// opLoc is the location of the op token to be used to report errors /// for non-conforming expressions. -AffineExprRef AffineParser::getBinaryAffineOpExpr(AffineHighPrecOp op, - AffineExprRef lhs, - AffineExprRef rhs, - SMLoc opLoc) { +AffineExpr AffineParser::getBinaryAffineOpExpr(AffineHighPrecOp op, + AffineExpr lhs, AffineExpr rhs, + SMLoc opLoc) { // TODO: make the error location info accurate. switch (op) { case Mul: @@ -900,9 +897,8 @@ AffineExprRef AffineParser::getBinaryAffineOpExpr(AffineHighPrecOp op, } /// Create an affine binary low precedence op expression (add, sub). -AffineExprRef AffineParser::getBinaryAffineOpExpr(AffineLowPrecOp op, - AffineExprRef lhs, - AffineExprRef rhs) { +AffineExpr AffineParser::getBinaryAffineOpExpr(AffineLowPrecOp op, + AffineExpr lhs, AffineExpr rhs) { switch (op) { case AffineLowPrecOp::Add: return builder.getAddExpr(lhs, rhs); @@ -960,10 +956,10 @@ AffineHighPrecOp AffineParser::consumeIfHighPrecOp() { /// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is /// null. llhsOpLoc is the location of the llhsOp token that will be used to /// report an error for non-conforming expressions. -AffineExprRef AffineParser::parseAffineHighPrecOpExpr(AffineExprRef llhs, - AffineHighPrecOp llhsOp, - SMLoc llhsOpLoc) { - AffineExprRef lhs = parseAffineOperandExpr(llhs); +AffineExpr AffineParser::parseAffineHighPrecOpExpr(AffineExpr llhs, + AffineHighPrecOp llhsOp, + SMLoc llhsOpLoc) { + AffineExpr lhs = parseAffineOperandExpr(llhs); if (!lhs) return nullptr; @@ -971,7 +967,7 @@ AffineExprRef AffineParser::parseAffineHighPrecOpExpr(AffineExprRef llhs, auto opLoc = getToken().getLoc(); if (AffineHighPrecOp op = consumeIfHighPrecOp()) { if (llhs) { - AffineExprRef expr = getBinaryAffineOpExpr(llhsOp, llhs, lhs, opLoc); + AffineExpr expr = getBinaryAffineOpExpr(llhsOp, llhs, lhs, opLoc); if (!expr) return nullptr; return parseAffineHighPrecOpExpr(expr, op, opLoc); @@ -991,7 +987,7 @@ AffineExprRef AffineParser::parseAffineHighPrecOpExpr(AffineExprRef llhs, /// Parse an affine expression inside parentheses. /// /// affine-expr ::= `(` affine-expr `)` -AffineExprRef AffineParser::parseParentheticalExpr() { +AffineExpr AffineParser::parseParentheticalExpr() { if (parseToken(Token::l_paren, "expected '('")) return nullptr; if (getToken().is(Token::r_paren)) @@ -1009,11 +1005,11 @@ AffineExprRef AffineParser::parseParentheticalExpr() { /// Parse the negation expression. /// /// affine-expr ::= `-` affine-expr -AffineExprRef AffineParser::parseNegateExpression(AffineExprRef lhs) { +AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) { if (parseToken(Token::minus, "expected '-'")) return nullptr; - AffineExprRef operand = parseAffineOperandExpr(lhs); + AffineExpr operand = parseAffineOperandExpr(lhs); // Since negation has the highest precedence of all ops (including high // precedence ops) but lower than parentheses, we are only going to use // parseAffineOperandExpr instead of parseAffineExpr here. @@ -1028,7 +1024,7 @@ AffineExprRef AffineParser::parseNegateExpression(AffineExprRef lhs) { /// Parse a bare id that may appear in an affine expression. /// /// affine-expr ::= bare-id -AffineExprRef AffineParser::parseBareIdExpr() { +AffineExpr AffineParser::parseBareIdExpr() { if (getToken().isNot(Token::bare_identifier)) return (emitError("expected bare identifier"), nullptr); @@ -1046,7 +1042,7 @@ AffineExprRef AffineParser::parseBareIdExpr() { /// Parse a positive integral constant appearing in an affine expression. /// /// affine-expr ::= integer-literal -AffineExprRef AffineParser::parseIntegerExpr() { +AffineExpr AffineParser::parseIntegerExpr() { auto val = getToken().getUInt64IntegerValue(); if (!val.hasValue() || (int64_t)val.getValue() < 0) return (emitError("constant too large for index"), nullptr); @@ -1064,7 +1060,7 @@ AffineExprRef AffineParser::parseIntegerExpr() { // operand expression, it's an op expression and will be parsed via // parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and -l // are valid operands that will be parsed by this function. -AffineExprRef AffineParser::parseAffineOperandExpr(AffineExprRef lhs) { +AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) { switch (getToken().getKind()) { case Token::bare_identifier: return parseBareIdExpr(); @@ -1114,16 +1110,16 @@ AffineExprRef AffineParser::parseAffineOperandExpr(AffineExprRef lhs) { /// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function /// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where (e2*e3) /// will be parsed using parseAffineHighPrecOpExpr(). -AffineExprRef AffineParser::parseAffineLowPrecOpExpr(AffineExprRef llhs, - AffineLowPrecOp llhsOp) { - AffineExprRef lhs; +AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs, + AffineLowPrecOp llhsOp) { + AffineExpr lhs; if (!(lhs = parseAffineOperandExpr(llhs))) return nullptr; // Found an LHS. Deal with the ops. if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) { if (llhs) { - AffineExprRef sum = getBinaryAffineOpExpr(llhsOp, llhs, lhs); + AffineExpr sum = getBinaryAffineOpExpr(llhsOp, llhs, lhs); return parseAffineLowPrecOpExpr(sum, lOp); } // No LLHS, get RHS and form the expression. @@ -1133,13 +1129,13 @@ AffineExprRef AffineParser::parseAffineLowPrecOpExpr(AffineExprRef llhs, if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) { // We have a higher precedence op here. Get the rhs operand for the llhs // through parseAffineHighPrecOpExpr. - AffineExprRef highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc); + AffineExpr highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc); if (!highRes) return nullptr; // If llhs is null, the product forms the first operand of the yet to be // found expression. If non-null, the op to associate with llhs is llhsOp. - AffineExprRef expr = + AffineExpr expr = llhs ? getBinaryAffineOpExpr(llhsOp, llhs, highRes) : highRes; // Recurse for subsequent low prec op's after the affine high prec op @@ -1170,14 +1166,14 @@ AffineExprRef AffineParser::parseAffineLowPrecOpExpr(AffineExprRef llhs, /// Additional conditions are checked depending on the production. For eg., one /// of the operands for `*` has to be either constant/symbolic; the second /// operand for floordiv, ceildiv, and mod has to be a positive integer. -AffineExprRef AffineParser::parseAffineExpr() { +AffineExpr AffineParser::parseAffineExpr() { return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp); } /// Parse a dim or symbol from the lists appearing before the actual expressions /// of the affine map. Update our state to store the dimensional/symbolic /// identifier. -ParseResult AffineParser::parseIdentifierDefinition(AffineExprRef idExpr) { +ParseResult AffineParser::parseIdentifierDefinition(AffineExpr idExpr) { if (getToken().isNot(Token::bare_identifier)) return emitError("expected bare identifier"); @@ -1239,7 +1235,7 @@ AffineMap *AffineParser::parseAffineMapInline() { parseToken(Token::l_paren, "expected '(' at start of affine map range")) return nullptr; - SmallVector exprs; + SmallVector exprs; auto parseElt = [&]() -> ParseResult { auto elt = parseAffineExpr(); ParseResult res = elt ? ParseSuccess : ParseFailure; @@ -1258,7 +1254,7 @@ AffineMap *AffineParser::parseAffineMapInline() { // dim-size ::= affine-expr | `min` `(` affine-expr (`,` affine-expr)+ `)` // TODO(bondhugula): support for min of several affine expressions. // TODO: check if sizes are non-negative whenever they are constant. - SmallVector rangeSizes; + SmallVector rangeSizes; if (consumeIf(Token::kw_size)) { // Location of the l_paren token (if it exists) for error reporting later. auto loc = getToken().getLoc(); @@ -2446,8 +2442,8 @@ ParseResult MLFunctionParser::parseBound(SmallVectorImpl &operands, /// isEq is set to true if the parsed constraint is an equality, false if it is /// an inequality (greater than or equal). /// -AffineExprRef AffineParser::parseAffineConstraint(bool *isEq) { - AffineExprRef expr = parseAffineExpr(); +AffineExpr AffineParser::parseAffineConstraint(bool *isEq) { + AffineExpr expr = parseAffineExpr(); if (!expr) return nullptr; @@ -2501,7 +2497,7 @@ IntegerSet *AffineParser::parseIntegerSetInline() { "expected '(' at start of integer set constraint list")) return nullptr; - SmallVector constraints; + SmallVector constraints; SmallVector isEqs; auto parseElt = [&]() -> ParseResult { bool isEq; diff --git a/mlir/lib/Transforms/LoopUtils.cpp b/mlir/lib/Transforms/LoopUtils.cpp index 79dcb2ddfeb6..89e59fe40fb0 100644 --- a/mlir/lib/Transforms/LoopUtils.cpp +++ b/mlir/lib/Transforms/LoopUtils.cpp @@ -49,7 +49,7 @@ AffineMap *mlir::getUnrolledLoopUpperBound(const ForStmt &forStmt, if (!tripCount) return nullptr; - AffineExprRef lb(lbMap->getResult(0)); + AffineExpr lb(lbMap->getResult(0)); unsigned step = forStmt.getStep(); auto newUb = lb + (tripCount - tripCount % unrollFactor - 1) * step; @@ -71,11 +71,11 @@ AffineMap *mlir::getCleanupLoopLowerBound(const ForStmt &forStmt, return nullptr; // Sometimes the trip count cannot be expressed as an affine expression. - AffineExprRef tripCount(getTripCountExpr(forStmt)); + AffineExpr tripCount(getTripCountExpr(forStmt)); if (!tripCount) return nullptr; - AffineExprRef lb(lbMap->getResult(0)); + AffineExpr lb(lbMap->getResult(0)); unsigned step = forStmt.getStep(); auto newLb = lb + (tripCount - tripCount % unrollFactor) * step; return builder->getAffineMap(lbMap->getNumDims(), lbMap->getNumSymbols(),