diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h index d511f628c3c2..b9def6cb24f7 100644 --- a/mlir/include/mlir/AffineOps/AffineOps.h +++ b/mlir/include/mlir/AffineOps/AffineOps.h @@ -28,13 +28,230 @@ #include "mlir/IR/StandardTypes.h" namespace mlir { +class AffineBound; class AffineOpsDialect : public Dialect { public: AffineOpsDialect(MLIRContext *context); }; -/// The "if" operation represents an if–then–else construct for conditionally +/// The "for" instruction represents an affine loop nest, defining an SSA value +/// for its induction variable. The induction variable is represented as a +/// BlockArgument to the entry block of the body. The body and induction +/// variable can be created automatically for new "for" ops with 'createBody'. +/// This SSA value always has type index, which is the size of the machine word. +/// The stride, represented by step, is a positive constant integer which +/// defaults to "1" if not present. The lower and upper bounds specify a +/// half-open range: the range includes the lower bound but does not include the +/// upper bound. +/// +/// The lower and upper bounds of a for operation are represented as an +/// application of an affine mapping to a list of SSA values passed to the map. +/// The same restrictions hold for these SSA values as for all bindings of SSA +/// values to dimensions and symbols. The affine mappings for the bounds may +/// return multiple results, in which case the max/min keywords are required +/// (for the lower/upper bound respectively), and the bound is the +/// maximum/minimum of the returned values. +/// +/// Example: +/// +/// for %i = 1 to 10 { +/// ... +/// } +/// +class AffineForOp + : public Op { +public: + // Hooks to customize behavior of this op. + static void build(Builder *builder, OperationState *result, + ArrayRef lbOperands, AffineMap lbMap, + ArrayRef ubOperands, AffineMap ubMap, + int64_t step = 1); + static void build(Builder *builder, OperationState *result, int64_t lb, + int64_t ub, int64_t step = 1); + bool verify() const; + static bool parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p) const; + + static StringRef getOperationName() { return "for"; } + static StringRef getStepAttrName() { return "step"; } + static StringRef getLowerBoundAttrName() { return "lower_bound"; } + static StringRef getUpperBoundAttrName() { return "upper_bound"; } + + /// Generate a body block for this AffineForOp. The operation must not already + /// have a body. The operation must contain a parent function. + Block *createBody(); + + /// Get the body of the AffineForOp. + Block *getBody() { return &getBlockList().front(); } + const Block *getBody() const { return &getBlockList().front(); } + + /// Get the blocklist containing the body. + BlockList &getBlockList() { return getInstruction()->getBlockList(0); } + const BlockList &getBlockList() const { + return getInstruction()->getBlockList(0); + } + + /// Returns the induction variable for this loop. 
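As an illustrative aside (outside the patch itself), two loop forms the custom parser and printer added below accept; the SSA names %lb and %ub are hypothetical, and bound maps with multiple results additionally require the max/min prefixes described above:

```mlir
// Constant bounds and an explicit step.
for %i = 0 to 128 step 4 {
  ...
}

// Bounds given by single symbol operands (parsed and printed as identity symbol maps).
for %j = %lb to %ub {
  ...
}
```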
+ Value *getInductionVar(); + const Value *getInductionVar() const { + return const_cast(this)->getInductionVar(); + } + + //===--------------------------------------------------------------------===// + // Bounds and step + //===--------------------------------------------------------------------===// + + using operand_range = llvm::iterator_range; + using const_operand_range = llvm::iterator_range; + + // TODO: provide iterators for the lower and upper bound operands + // if the current access via getLowerBound(), getUpperBound() is too slow. + + /// Returns operands for the lower bound map. + operand_range getLowerBoundOperands(); + const_operand_range getLowerBoundOperands() const; + + /// Returns operands for the upper bound map. + operand_range getUpperBoundOperands(); + const_operand_range getUpperBoundOperands() const; + + /// Returns information about the lower bound as a single object. + const AffineBound getLowerBound() const; + + /// Returns information about the upper bound as a single object. + const AffineBound getUpperBound() const; + + /// Returns loop step. + int64_t getStep() const { + return getAttr(getStepAttrName()).cast().getInt(); + } + + /// Returns affine map for the lower bound. + AffineMap getLowerBoundMap() const { + return getAttr(getLowerBoundAttrName()).cast().getValue(); + } + /// Returns affine map for the upper bound. The upper bound is exclusive. + AffineMap getUpperBoundMap() const { + return getAttr(getUpperBoundAttrName()).cast().getValue(); + } + + /// Set lower bound. The new bound must have the same number of operands as + /// the current bound map. Otherwise, 'replaceForLowerBound' should be used. + void setLowerBound(ArrayRef operands, AffineMap map); + /// Set upper bound. The new bound must not have more operands than the + /// current bound map. Otherwise, 'replaceForUpperBound' should be used. + void setUpperBound(ArrayRef operands, AffineMap map); + + /// Set the lower bound map without changing operands. + void setLowerBoundMap(AffineMap map); + + /// Set the upper bound map without changing operands. + void setUpperBoundMap(AffineMap map); + + /// Set loop step. + void setStep(int64_t step) { + assert(step > 0 && "step has to be a positive integer constant"); + auto *context = getLowerBoundMap().getContext(); + setAttr(Identifier::get(getStepAttrName(), context), + IntegerAttr::get(IndexType::get(context), step)); + } + + /// Returns true if the lower bound is constant. + bool hasConstantLowerBound() const; + /// Returns true if the upper bound is constant. + bool hasConstantUpperBound() const; + /// Returns true if both bounds are constant. + bool hasConstantBounds() const { + return hasConstantLowerBound() && hasConstantUpperBound(); + } + /// Returns the value of the constant lower bound. + /// Fails assertion if the bound is non-constant. + int64_t getConstantLowerBound() const; + /// Returns the value of the constant upper bound. The upper bound is + /// exclusive. Fails assertion if the bound is non-constant. + int64_t getConstantUpperBound() const; + /// Sets the lower bound to the given constant value. + void setConstantLowerBound(int64_t value); + /// Sets the upper bound to the given constant value. + void setConstantUpperBound(int64_t value); + + /// Returns true if both the lower and upper bound have the same operand lists + /// (same operands in the same order). + bool matchingBoundOperandList() const; + + /// Walk the operation instructions in the 'for' instruction in preorder, + /// calling the callback for each operation. 
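A minimal sketch (outside the patch itself) of how the constant-bound accessors above compose; the helper name is hypothetical, and the general-purpose analysis remains getConstantTripCount in LoopAnalysis:

```cpp
// Trip count of a unit-step loop with constant bounds. The upper bound is
// exclusive, so the count is just the difference of the two constants.
llvm::Optional<uint64_t> simpleConstantTripCount(ConstOpPointer<AffineForOp> forOp) {
  if (!forOp->hasConstantBounds() || forOp->getStep() != 1)
    return llvm::None;
  int64_t lb = forOp->getConstantLowerBound();
  int64_t ub = forOp->getConstantUpperBound();
  return ub <= lb ? 0 : static_cast<uint64_t>(ub - lb);
}
```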
+ void walkOps(std::function callback); + + /// Walk the operation instructions in the 'for' instruction in postorder, + /// calling the callback for each operation. + void walkOpsPostOrder(std::function callback); + +private: + friend class OperationInst; + explicit AffineForOp(const OperationInst *state) : Op(state) {} +}; + +/// Returns if the provided value is the induction variable of a AffineForOp. +bool isForInductionVar(const Value *val); + +/// Returns the loop parent of an induction variable. If the provided value is +/// not an induction variable, then return nullptr. +OpPointer getForInductionVarOwner(Value *val); +ConstOpPointer getForInductionVarOwner(const Value *val); + +/// Extracts the induction variables from a list of AffineForOps and returns +/// them. +SmallVector +extractForInductionVars(MutableArrayRef> forInsts); + +/// AffineBound represents a lower or upper bound in the for instruction. +/// This class does not own the underlying operands. Instead, it refers +/// to the operands stored in the AffineForOp. Its life span should not exceed +/// that of the for instruction it refers to. +class AffineBound { +public: + ConstOpPointer getAffineForOp() const { return inst; } + AffineMap getMap() const { return map; } + + unsigned getNumOperands() const { return opEnd - opStart; } + const Value *getOperand(unsigned idx) const { + return inst->getInstruction()->getOperand(opStart + idx); + } + + using operand_iterator = AffineForOp::operand_iterator; + using operand_range = AffineForOp::operand_range; + + operand_iterator operand_begin() const { + return const_cast(inst->getInstruction()) + ->operand_begin() + + opStart; + } + operand_iterator operand_end() const { + return const_cast(inst->getInstruction()) + ->operand_begin() + + opEnd; + } + operand_range getOperands() const { return {operand_begin(), operand_end()}; } + +private: + // 'for' instruction that contains this bound. + ConstOpPointer inst; + // Start and end positions of this affine bound operands in the list of + // the containing 'for' instruction operands. + unsigned opStart, opEnd; + // Affine map for this bound. + AffineMap map; + + AffineBound(ConstOpPointer inst, unsigned opStart, + unsigned opEnd, AffineMap map) + : inst(inst), opStart(opStart), opEnd(opEnd), map(map) {} + + friend class AffineForOp; +}; + +/// The "if" operation represents an if-then-else construct for conditionally /// executing two regions of code. The operands to an if operation are an /// IntegerSet condition and a set of symbol/dimension operands to the /// condition set. The operation produces no results. For example: diff --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h index 30576b587a00..3ee35eea2ff5 100644 --- a/mlir/include/mlir/Analysis/AffineAnalysis.h +++ b/mlir/include/mlir/Analysis/AffineAnalysis.h @@ -32,10 +32,10 @@ namespace mlir { class AffineApplyOp; class AffineExpr; +class AffineForOp; class AffineMap; class AffineValueMap; class FlatAffineConstraints; -class ForInst; class FuncBuilder; class Instruction; class IntegerSet; @@ -108,12 +108,12 @@ bool getFlattenedAffineExprs( FlatAffineConstraints *cst = nullptr); /// Builds a system of constraints with dimensional identifiers corresponding to -/// the loop IVs of the forInsts appearing in that order. Bounds of the loop are +/// the loop IVs of the forOps appearing in that order. Bounds of the loop are /// used to add appropriate inequalities. 
Any symbols founds in the bound /// operands are added as symbols in the system. Returns false for the yet /// unimplemented cases. // TODO(bondhugula): handle non-unit strides. -bool getIndexSet(llvm::ArrayRef forInsts, +bool getIndexSet(llvm::MutableArrayRef> forOps, FlatAffineConstraints *domain); /// Encapsulates a memref load or store access information. diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index 49202eb7cc51..e8b4ee623c0f 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -28,9 +28,10 @@ namespace mlir { class AffineApplyOp; class AffineBound; +class AffineForOp; class AffineCondition; class AffineMap; -class ForInst; +template class ConstOpPointer; class IntegerSet; class MLIRContext; class Value; @@ -113,13 +114,12 @@ private: /// results, and its map can themselves change as a result of /// substitutions, simplifications, and other analysis. // An affine value map can readily be constructed from an AffineApplyOp, or an -// AffineBound of a ForInst. It can be further transformed, substituted into, -// or simplified. Unlike AffineMap's, AffineValueMap's are created and destroyed -// during analysis. Only the AffineMap expressions that are pointed by them are -// unique'd. -// An affine value map, and the operations on it, maintain the invariant that -// operands are always positionally aligned with the AffineDimExpr and -// AffineSymbolExpr in the underlying AffineMap. +// AffineBound of a AffineForOp. It can be further transformed, substituted +// into, or simplified. Unlike AffineMap's, AffineValueMap's are created and +// destroyed during analysis. Only the AffineMap expressions that are pointed by +// them are unique'd. An affine value map, and the operations on it, maintain +// the invariant that operands are always positionally aligned with the +// AffineDimExpr and AffineSymbolExpr in the underlying AffineMap. // TODO(bondhugula): Some of these classes could go into separate files. class AffineValueMap { public: @@ -173,9 +173,6 @@ private: // Both, the integer set being pointed to and the operands can change during // analysis, simplification, and transformation. class IntegerValueSet { - // Constructs an integer value set map from an IntegerSet and operands. - explicit IntegerValueSet(const AffineCondition &cond); - /// Constructs an integer value set from an affine value map. // This will lead to a single equality in 'set'. explicit IntegerValueSet(const AffineValueMap &avm); @@ -403,7 +400,7 @@ public: /// Adds constraints (lower and upper bounds) for the specified 'for' /// instruction's Value using IR information stored in its bound maps. The - /// right identifier is first looked up using forInst's Value. Returns + /// right identifier is first looked up using forOp's Value. Returns /// false for the yet unimplemented/unsupported cases, and true if the /// information is succesfully added. Asserts if the Value corresponding to /// the 'for' instruction isn't found in the constraint system. Any new @@ -411,7 +408,7 @@ public: /// are added as trailing identifiers (either dimensional or symbolic /// depending on whether the operand is a valid ML Function symbol). // TODO(bondhugula): add support for non-unit strides. - bool addForInstDomain(const ForInst &forInst); + bool addAffineForOpDomain(ConstOpPointer forOp); /// Adds a constant lower bound constraint for the specified expression. 
void addConstantLowerBound(ArrayRef expr, int64_t lb); diff --git a/mlir/include/mlir/Analysis/LoopAnalysis.h b/mlir/include/mlir/Analysis/LoopAnalysis.h index 1b3d0ce96752..16c1c9673852 100644 --- a/mlir/include/mlir/Analysis/LoopAnalysis.h +++ b/mlir/include/mlir/Analysis/LoopAnalysis.h @@ -29,8 +29,9 @@ namespace mlir { class AffineExpr; +class AffineForOp; class AffineMap; -class ForInst; +template class ConstOpPointer; class MemRefType; class OperationInst; class Value; @@ -38,19 +39,20 @@ class Value; /// Returns the trip count of the loop as an affine expression if the latter is /// expressible as an affine expression, and nullptr otherwise. The trip count /// expression is simplified before returning. -AffineExpr getTripCountExpr(const ForInst &forInst); +AffineExpr getTripCountExpr(ConstOpPointer forOp); /// Returns the trip count of the loop if it's a constant, None otherwise. This /// uses affine expression analysis and is able to determine constant trip count /// in non-trivial cases. -llvm::Optional getConstantTripCount(const ForInst &forInst); +llvm::Optional +getConstantTripCount(ConstOpPointer forOp); /// Returns the greatest known integral divisor of the trip count. Affine /// expression analysis is used (indirectly through getTripCount), and /// this method is thus able to determine non-trivial divisors. -uint64_t getLargestDivisorOfTripCount(const ForInst &forInst); +uint64_t getLargestDivisorOfTripCount(ConstOpPointer forOp); -/// Given an induction variable `iv` of type ForInst and an `index` of type +/// Given an induction variable `iv` of type AffineForOp and an `index` of type /// IndexType, returns `true` if `index` is independent of `iv` and false /// otherwise. /// The determination supports composition with at most one AffineApplyOp. @@ -67,7 +69,7 @@ uint64_t getLargestDivisorOfTripCount(const ForInst &forInst); /// conservative. bool isAccessInvariant(const Value &iv, const Value &index); -/// Given an induction variable `iv` of type ForInst and `indices` of type +/// Given an induction variable `iv` of type AffineForOp and `indices` of type /// IndexType, returns the set of `indices` that are independent of `iv`. /// /// Prerequisites (inherited from `isAccessInvariant` above): @@ -85,21 +87,21 @@ getInvariantAccesses(const Value &iv, llvm::ArrayRef indices); /// 3. all nested load/stores are to scalar MemRefs. /// TODO(ntv): implement dependence semantics /// TODO(ntv): relax the no-conditionals restriction -bool isVectorizableLoop(const ForInst &loop); +bool isVectorizableLoop(ConstOpPointer loop); /// Checks whether the loop is structurally vectorizable and that all the LoadOp /// and StoreOp matched have access indexing functions that are are either: /// 1. invariant along the loop induction variable created by 'loop'; /// 2. varying along the 'fastestVaryingDim' memory dimension. -bool isVectorizableLoopAlongFastestVaryingMemRefDim(const ForInst &loop, - unsigned fastestVaryingDim); +bool isVectorizableLoopAlongFastestVaryingMemRefDim( + ConstOpPointer loop, unsigned fastestVaryingDim); /// Checks where SSA dominance would be violated if a for inst's body /// instructions are shifted by the specified shifts. This method checks if a /// 'def' and all its uses have the same shift factor. // TODO(mlir-team): extend this to check for memory-based dependence // violation when we have the support. 
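A sketch (outside the patch itself) of what the LoopAnalysis signature changes mean for callers: loops are now reached with a dyn_cast on the operation rather than a cast to ForInst. The wrapper name is hypothetical, and the optional's element type is assumed to be uint64_t:

```cpp
// Report the constant trip count of a loop operation, if known.
llvm::Optional<uint64_t> tripCountOf(OperationInst *inst) {
  if (auto forOp = inst->dyn_cast<AffineForOp>()) // was: cast<ForInst>(inst)
    return getConstantTripCount(forOp);           // OpPointer converts to ConstOpPointer
  return llvm::None;
}
```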
-bool isInstwiseShiftValid(const ForInst &forInst, +bool isInstwiseShiftValid(ConstOpPointer forOp, llvm::ArrayRef shifts); } // end namespace mlir diff --git a/mlir/include/mlir/Analysis/NestedMatcher.h b/mlir/include/mlir/Analysis/NestedMatcher.h index 2a1c469348de..0e41058f7773 100644 --- a/mlir/include/mlir/Analysis/NestedMatcher.h +++ b/mlir/include/mlir/Analysis/NestedMatcher.h @@ -127,7 +127,6 @@ private: struct State : public InstWalker { State(NestedPattern &pattern, SmallVectorImpl *matches) : pattern(pattern), matches(matches) {} - void visitForInst(ForInst *forInst) { pattern.matchOne(forInst, matches); } void visitOperationInst(OperationInst *opInst) { pattern.matchOne(opInst, matches); } diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index e9de4a8d2590..bb81df604cfe 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -33,10 +33,12 @@ namespace mlir { +class AffineForOp; +template class ConstOpPointer; class FlatAffineConstraints; -class ForInst; class MemRefAccess; class OperationInst; +template class OpPointer; class Instruction; class Value; @@ -49,7 +51,8 @@ bool properlyDominates(const Instruction &a, const Instruction &b); /// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from /// the outermost 'for' instruction to the innermost one. // TODO(bondhugula): handle 'if' inst's. -void getLoopIVs(const Instruction &inst, SmallVectorImpl *loops); +void getLoopIVs(const Instruction &inst, + SmallVectorImpl> *loops); /// Returns the nesting depth of this instruction, i.e., the number of loops /// surrounding this instruction. @@ -191,12 +194,12 @@ bool getBackwardComputationSliceState(const MemRefAccess &srcAccess, // materialize the results of the backward slice - presenting a trade-off b/w // storage and redundant computation in several cases. // TODO(andydavis) Support computation slices with common surrounding loops. -ForInst *insertBackwardComputationSlice(OperationInst *srcOpInst, - OperationInst *dstOpInst, - unsigned dstLoopDepth, - ComputationSliceState *sliceState); +OpPointer +insertBackwardComputationSlice(OperationInst *srcOpInst, + OperationInst *dstOpInst, unsigned dstLoopDepth, + ComputationSliceState *sliceState); -Optional getMemoryFootprintBytes(const ForInst &forInst, +Optional getMemoryFootprintBytes(ConstOpPointer forOp, int memorySpace = -1); } // end namespace mlir diff --git a/mlir/include/mlir/Analysis/VectorAnalysis.h b/mlir/include/mlir/Analysis/VectorAnalysis.h index dfb1164750a3..89f49fdfe778 100644 --- a/mlir/include/mlir/Analysis/VectorAnalysis.h +++ b/mlir/include/mlir/Analysis/VectorAnalysis.h @@ -25,8 +25,8 @@ namespace mlir { class AffineApplyOp; +class AffineForOp; class AffineMap; -class ForInst; class FuncBuilder; class Instruction; class Location; @@ -71,7 +71,7 @@ shapeRatio(VectorType superVectorType, VectorType subVectorType); /// loop information is extracted. /// /// Prerequisites: `opInst` is a vectorizable load or store operation (i.e. at -/// most one invariant index along each ForInst of `loopToVectorDim`). +/// most one invariant index along each AffineForOp of `loopToVectorDim`). /// /// Example 1: /// The following MLIR snippet: @@ -122,9 +122,9 @@ shapeRatio(VectorType superVectorType, VectorType subVectorType); /// Meaning that vector_transfer_read will be responsible of reading the slice /// `%arg0[%c0, %c0]` into vector<128xf32> which needs a 1-D vector broadcast. 
/// -AffineMap -makePermutationMap(OperationInst *opInst, - const llvm::DenseMap &loopToVectorDim); +AffineMap makePermutationMap( + OperationInst *opInst, + const llvm::DenseMap &loopToVectorDim); namespace matcher { diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 3271c12afde4..29a9fb0281b0 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -239,11 +239,6 @@ public: /// current function. Block *createBlock(Block *insertBefore = nullptr); - /// Returns a builder for the body of a 'for' instruction. - static FuncBuilder getForInstBodyBuilder(ForInst *forInst) { - return FuncBuilder(forInst->getBody(), forInst->getBody()->end()); - } - /// Returns the current block of the builder. Block *getBlock() const { return block; } @@ -277,15 +272,6 @@ public: return cloneInst; } - // Creates a for instruction. When step is not specified, it is set to 1. - ForInst *createFor(Location location, ArrayRef lbOperands, - AffineMap lbMap, ArrayRef ubOperands, - AffineMap ubMap, int64_t step = 1); - - // Creates a for instruction with known (constant) lower and upper bounds. - // Default step is 1. - ForInst *createFor(Location loc, int64_t lb, int64_t ub, int64_t step = 1); - private: Function *function; Block *block = nullptr; diff --git a/mlir/include/mlir/IR/InstVisitor.h b/mlir/include/mlir/IR/InstVisitor.h index 78810da909d1..0ed8599ff334 100644 --- a/mlir/include/mlir/IR/InstVisitor.h +++ b/mlir/include/mlir/IR/InstVisitor.h @@ -83,8 +83,6 @@ public: "Must pass the derived type to this template!"); switch (s->getKind()) { - case Instruction::Kind::For: - return static_cast(this)->visitForInst(cast(s)); case Instruction::Kind::OperationInst: return static_cast(this)->visitOperationInst( cast(s)); @@ -101,7 +99,6 @@ public: // When visiting a for inst, if inst, or an operation inst directly, these // methods get called to indicate when transitioning into a new unit. - void visitForInst(ForInst *forInst) {} void visitOperationInst(OperationInst *opInst) {} }; @@ -147,22 +144,11 @@ public: void walkOpInstPostOrder(OperationInst *opInst) { for (auto &blockList : opInst->getBlockLists()) for (auto &block : blockList) - static_cast(this)->walk(block.begin(), block.end()); + static_cast(this)->walkPostOrder(block.begin(), + block.end()); static_cast(this)->visitOperationInst(opInst); } - void walkForInst(ForInst *forInst) { - static_cast(this)->visitForInst(forInst); - auto *body = forInst->getBody(); - static_cast(this)->walk(body->begin(), body->end()); - } - - void walkForInstPostOrder(ForInst *forInst) { - auto *body = forInst->getBody(); - static_cast(this)->walkPostOrder(body->begin(), body->end()); - static_cast(this)->visitForInst(forInst); - } - // Function to walk a instruction. RetTy walk(Instruction *s) { static_assert(std::is_base_of::value, @@ -171,8 +157,6 @@ public: static_cast(this)->visitInstruction(s); switch (s->getKind()) { - case Instruction::Kind::For: - return static_cast(this)->walkForInst(cast(s)); case Instruction::Kind::OperationInst: return static_cast(this)->walkOpInst(cast(s)); } @@ -185,9 +169,6 @@ public: static_cast(this)->visitInstruction(s); switch (s->getKind()) { - case Instruction::Kind::For: - return static_cast(this)->walkForInstPostOrder( - cast(s)); case Instruction::Kind::OperationInst: return static_cast(this)->walkOpInstPostOrder( cast(s)); @@ -205,7 +186,6 @@ public: // called. These are typically O(1) complexity and shouldn't be recursively // processing their descendants in some way. 
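A hedged sketch (outside the patch itself) of how a visitor that previously overrode visitForInst migrates: with ForInst removed, loops are recognized from visitOperationInst instead; the struct name is hypothetical:

```cpp
// Count the affine loops reachable from a walk; driven like the Walker
// structs used by AffineForOp::walkOps further down in this patch.
struct LoopCounter : public InstWalker<LoopCounter> {
  unsigned numLoops = 0;
  void visitOperationInst(OperationInst *opInst) {
    if (opInst->isa<AffineForOp>())
      ++numLoops;
  }
};
```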
When using RetTy, all of these // need to be overridden. - void visitForInst(ForInst *forInst) {} void visitOperationInst(OperationInst *opInst) {} void visitInstruction(Instruction *inst) {} }; diff --git a/mlir/include/mlir/IR/Instruction.h b/mlir/include/mlir/IR/Instruction.h index 3dc1e76dd20d..3789fefc639d 100644 --- a/mlir/include/mlir/IR/Instruction.h +++ b/mlir/include/mlir/IR/Instruction.h @@ -32,7 +32,6 @@ namespace mlir { class Block; class BlockAndValueMapping; class Location; -class ForInst; class MLIRContext; /// Terminator operations can have Block operands to represent successors. @@ -74,7 +73,6 @@ class Instruction : public IROperandOwner, public: enum class Kind { OperationInst = (int)IROperandOwner::Kind::OperationInst, - For = (int)IROperandOwner::Kind::ForInst, }; Kind getKind() const { return (Kind)IROperandOwner::getKind(); } diff --git a/mlir/include/mlir/IR/Instructions.h b/mlir/include/mlir/IR/Instructions.h index 724c4dd70390..c6fde0e0aee1 100644 --- a/mlir/include/mlir/IR/Instructions.h +++ b/mlir/include/mlir/IR/Instructions.h @@ -26,15 +26,11 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Block.h" #include "mlir/IR/Instruction.h" -#include "mlir/IR/IntegerSet.h" #include "mlir/IR/OperationSupport.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/TrailingObjects.h" namespace mlir { -class AffineBound; -class IntegerSet; -class AffineCondition; class AttributeListStorage; template class ConstOpPointer; template class OpPointer; @@ -219,6 +215,13 @@ public: return getOperandStorage().isResizable(); } + /// Replace the current operands of this operation with the ones provided in + /// 'operands'. If the operands list is not resizable, the size of 'operands' + /// must be less than or equal to the current number of operands. + void setOperands(ArrayRef operands) { + getOperandStorage().setOperands(this, operands); + } + unsigned getNumOperands() const { return getOperandStorage().size(); } Value *getOperand(unsigned idx) { return getInstOperand(idx).get(); } @@ -697,262 +700,6 @@ inline auto OperationInst::getResultTypes() const return {result_type_begin(), result_type_end()}; } -/// For instruction represents an affine loop nest. -class ForInst final - : public Instruction, - private llvm::TrailingObjects { -public: - static ForInst *create(Location location, ArrayRef lbOperands, - AffineMap lbMap, ArrayRef ubOperands, - AffineMap ubMap, int64_t step); - - /// Resolve base class ambiguity. - using Instruction::getFunction; - - /// Operand iterators. - using operand_iterator = OperandIterator; - using const_operand_iterator = OperandIterator; - - /// Operand iterator range. - using operand_range = llvm::iterator_range; - using const_operand_range = llvm::iterator_range; - - /// Get the body of the ForInst. - Block *getBody() { return &body.front(); } - - /// Get the body of the ForInst. - const Block *getBody() const { return &body.front(); } - - //===--------------------------------------------------------------------===// - // Bounds and step - //===--------------------------------------------------------------------===// - - /// Returns information about the lower bound as a single object. - const AffineBound getLowerBound() const; - - /// Returns information about the upper bound as a single object. - const AffineBound getUpperBound() const; - - /// Returns loop step. - int64_t getStep() const { return step; } - - /// Returns affine map for the lower bound. 
- AffineMap getLowerBoundMap() const { return lbMap; } - /// Returns affine map for the upper bound. The upper bound is exclusive. - AffineMap getUpperBoundMap() const { return ubMap; } - - /// Set lower bound. - void setLowerBound(ArrayRef operands, AffineMap map); - /// Set upper bound. - void setUpperBound(ArrayRef operands, AffineMap map); - - /// Set the lower bound map without changing operands. - void setLowerBoundMap(AffineMap map); - - /// Set the upper bound map without changing operands. - void setUpperBoundMap(AffineMap map); - - /// Set loop step. - void setStep(int64_t step) { - assert(step > 0 && "step has to be a positive integer constant"); - this->step = step; - } - - /// Returns true if the lower bound is constant. - bool hasConstantLowerBound() const; - /// Returns true if the upper bound is constant. - bool hasConstantUpperBound() const; - /// Returns true if both bounds are constant. - bool hasConstantBounds() const { - return hasConstantLowerBound() && hasConstantUpperBound(); - } - /// Returns the value of the constant lower bound. - /// Fails assertion if the bound is non-constant. - int64_t getConstantLowerBound() const; - /// Returns the value of the constant upper bound. The upper bound is - /// exclusive. Fails assertion if the bound is non-constant. - int64_t getConstantUpperBound() const; - /// Sets the lower bound to the given constant value. - void setConstantLowerBound(int64_t value); - /// Sets the upper bound to the given constant value. - void setConstantUpperBound(int64_t value); - - /// Returns true if both the lower and upper bound have the same operand lists - /// (same operands in the same order). - bool matchingBoundOperandList() const; - - /// Walk the operation instructions in the 'for' instruction in preorder, - /// calling the callback for each operation. - void walkOps(std::function callback); - - /// Walk the operation instructions in the 'for' instruction in postorder, - /// calling the callback for each operation. - void walkOpsPostOrder(std::function callback); - - //===--------------------------------------------------------------------===// - // Operands - //===--------------------------------------------------------------------===// - - unsigned getNumOperands() const { return getOperandStorage().size(); } - - Value *getOperand(unsigned idx) { return getInstOperand(idx).get(); } - const Value *getOperand(unsigned idx) const { - return getInstOperand(idx).get(); - } - void setOperand(unsigned idx, Value *value) { - getInstOperand(idx).set(value); - } - - operand_iterator operand_begin() { return operand_iterator(this, 0); } - operand_iterator operand_end() { - return operand_iterator(this, getNumOperands()); - } - - const_operand_iterator operand_begin() const { - return const_operand_iterator(this, 0); - } - const_operand_iterator operand_end() const { - return const_operand_iterator(this, getNumOperands()); - } - - ArrayRef getInstOperands() const { - return getOperandStorage().getInstOperands(); - } - MutableArrayRef getInstOperands() { - return getOperandStorage().getInstOperands(); - } - InstOperand &getInstOperand(unsigned idx) { return getInstOperands()[idx]; } - const InstOperand &getInstOperand(unsigned idx) const { - return getInstOperands()[idx]; - } - - // TODO: provide iterators for the lower and upper bound operands - // if the current access via getLowerBound(), getUpperBound() is too slow. - - /// Returns operands for the lower bound map. 
- operand_range getLowerBoundOperands(); - const_operand_range getLowerBoundOperands() const; - - /// Returns operands for the upper bound map. - operand_range getUpperBoundOperands(); - const_operand_range getUpperBoundOperands() const; - - //===--------------------------------------------------------------------===// - // Other - //===--------------------------------------------------------------------===// - - /// Return the context this operation is associated with. - MLIRContext *getContext() const { - return getInductionVar()->getType().getContext(); - } - - using Instruction::dump; - using Instruction::print; - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(const IROperandOwner *ptr) { - return ptr->getKind() == IROperandOwner::Kind::ForInst; - } - - /// Returns the induction variable for this loop. - Value *getInductionVar(); - const Value *getInductionVar() const { - return const_cast(this)->getInductionVar(); - } - - void destroy(); - -private: - // The Block for the body. By construction, this list always contains exactly - // one block. - BlockList body; - - // Affine map for the lower bound. - AffineMap lbMap; - // Affine map for the upper bound. The upper bound is exclusive. - AffineMap ubMap; - // Positive constant step. Since index is stored as an int64_t, we restrict - // step to the set of positive integers that int64_t can represent. - int64_t step; - - explicit ForInst(Location location, AffineMap lbMap, AffineMap ubMap, - int64_t step); - ~ForInst(); - - /// Returns the operand storage object. - detail::OperandStorage &getOperandStorage() { - return *getTrailingObjects(); - } - const detail::OperandStorage &getOperandStorage() const { - return *getTrailingObjects(); - } - - // This stuff is used by the TrailingObjects template. - friend llvm::TrailingObjects; -}; - -/// Returns if the provided value is the induction variable of a ForInst. -bool isForInductionVar(const Value *val); - -/// Returns the loop parent of an induction variable. If the provided value is -/// not an induction variable, then return nullptr. -ForInst *getForInductionVarOwner(Value *val); -const ForInst *getForInductionVarOwner(const Value *val); - -/// Extracts the induction variables from a list of ForInsts and returns them. -SmallVector extractForInductionVars(ArrayRef forInsts); - -/// AffineBound represents a lower or upper bound in the for instruction. -/// This class does not own the underlying operands. Instead, it refers -/// to the operands stored in the ForInst. Its life span should not exceed -/// that of the for instruction it refers to. -class AffineBound { -public: - const ForInst *getForInst() const { return &inst; } - AffineMap getMap() const { return map; } - - unsigned getNumOperands() const { return opEnd - opStart; } - const Value *getOperand(unsigned idx) const { - return inst.getOperand(opStart + idx); - } - const InstOperand &getInstOperand(unsigned idx) const { - return inst.getInstOperand(opStart + idx); - } - - using operand_iterator = ForInst::operand_iterator; - using operand_range = ForInst::operand_range; - - operand_iterator operand_begin() const { - // These are iterators over Value *. Not casting away const'ness would - // require the caller to use const Value *. - return operand_iterator(const_cast(&inst), opStart); - } - operand_iterator operand_end() const { - return operand_iterator(const_cast(&inst), opEnd); - } - - /// Returns an iterator on the underlying Value's (Value *). 
- operand_range getOperands() const { return {operand_begin(), operand_end()}; } - ArrayRef getInstOperands() const { - auto ops = inst.getInstOperands(); - return ArrayRef(ops.begin() + opStart, ops.begin() + opEnd); - } - -private: - // 'for' instruction that contains this bound. - const ForInst &inst; - // Start and end positions of this affine bound operands in the list of - // the containing 'for' instruction operands. - unsigned opStart, opEnd; - // Affine map for this bound. - AffineMap map; - - AffineBound(const ForInst &inst, unsigned opStart, unsigned opEnd, - AffineMap map) - : inst(inst), opStart(opStart), opEnd(opEnd), map(map) {} - - friend class ForInst; -}; } // end namespace mlir #endif // MLIR_IR_INSTRUCTIONS_H diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index c9ef9bf7cd61..2c62816b924c 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -68,6 +68,11 @@ public: operator bool() const { return value.getInstruction(); } + bool operator==(OpPointer rhs) const { + return value.getInstruction() == rhs.value.getInstruction(); + } + bool operator!=(OpPointer rhs) const { return !(*this == rhs); } + /// OpPointer can be implicitly converted to OpType*. /// Return `nullptr` if there is no associated OperationInst*. operator OpType *() { @@ -87,6 +92,9 @@ public: private: OpType value; + + // Allow access to value to enable constructing an empty ConstOpPointer. + friend class ConstOpPointer; }; /// This pointer represents a notional "const OperationInst*" but where the @@ -96,6 +104,7 @@ class ConstOpPointer { public: explicit ConstOpPointer() : value(OperationInst::getNull().value) {} explicit ConstOpPointer(OpType value) : value(value) {} + ConstOpPointer(OpPointer pointer) : value(pointer.value) {} const OpType &operator*() const { return value; } @@ -104,6 +113,11 @@ public: /// Return true if non-null. operator bool() const { return value.getInstruction(); } + bool operator==(ConstOpPointer rhs) const { + return value.getInstruction() == rhs.value.getInstruction(); + } + bool operator!=(ConstOpPointer rhs) const { return !(*this == rhs); } + /// ConstOpPointer can always be implicitly converted to const OpType*. /// Return `nullptr` if there is no associated OperationInst*. operator const OpType *() const { diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index d3a5d35427f5..4e7596498e78 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -90,7 +90,8 @@ public: virtual void printGenericOp(const OperationInst *op) = 0; /// Prints a block list. - virtual void printBlockList(const BlockList &blocks) = 0; + virtual void printBlockList(const BlockList &blocks, + bool printEntryBlockArgs = true) = 0; private: OpAsmPrinter(const OpAsmPrinter &) = delete; @@ -170,6 +171,9 @@ public: /// This parses... a comma! virtual bool parseComma() = 0; + /// This parses an equal(=) token! + virtual bool parseEqual() = 0; + /// Parse a type. virtual bool parseType(Type &result) = 0; @@ -203,9 +207,9 @@ public: } /// Parse a keyword. 
- bool parseKeyword(const char *keyword) { + bool parseKeyword(const char *keyword, const Twine &msg = "") { if (parseOptionalKeyword(keyword)) - return emitError(getNameLoc(), "expected '" + Twine(keyword) + "'"); + return emitError(getNameLoc(), "expected '" + Twine(keyword) + "'" + msg); return false; } @@ -315,6 +319,10 @@ public: /// operation's block lists after the operation is created. virtual bool parseBlockList() = 0; + /// Parses an argument for the entry block of the next block list to be + /// parsed. + virtual bool parseBlockListEntryBlockArgument(Type argType) = 0; + //===--------------------------------------------------------------------===// // Methods for interacting with the parser //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h index 80cd21362ceb..871e78bbb24f 100644 --- a/mlir/include/mlir/IR/UseDefLists.h +++ b/mlir/include/mlir/IR/UseDefLists.h @@ -75,15 +75,14 @@ private: }; /// Subclasses of IROperandOwner can be the owner of an IROperand. In practice -/// this is the common base between Instruction and Instruction. +/// this is the common base between Instructions. class IROperandOwner { public: enum class Kind { OperationInst, - ForInst, /// These enums define ranges used for classof implementations. - INST_LAST = ForInst, + INST_LAST = OperationInst, }; Kind getKind() const { return locationAndKind.getInt(); } diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h index e0cf3039f075..f3d9b9fe9fda 100644 --- a/mlir/include/mlir/Transforms/LoopUtils.h +++ b/mlir/include/mlir/Transforms/LoopUtils.h @@ -27,11 +27,12 @@ #include "mlir/Support/LLVM.h" namespace mlir { - class AffineMap; -class ForInst; +class AffineForOp; +template class ConstOpPointer; class Function; class FuncBuilder; +template class OpPointer; // Values that can be used to signal success/failure. This can be implicitly // converted to/from boolean values, with false representing success and true @@ -44,51 +45,54 @@ struct LLVM_NODISCARD UtilResult { /// Unrolls this for instruction completely if the trip count is known to be /// constant. Returns false otherwise. -bool loopUnrollFull(ForInst *forInst); +bool loopUnrollFull(OpPointer forOp); /// Unrolls this for instruction by the specified unroll factor. Returns false /// if the loop cannot be unrolled either due to restrictions or due to invalid /// unroll factors. -bool loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor); +bool loopUnrollByFactor(OpPointer forOp, uint64_t unrollFactor); /// Unrolls this loop by the specified unroll factor or its trip count, /// whichever is lower. -bool loopUnrollUpToFactor(ForInst *forInst, uint64_t unrollFactor); +bool loopUnrollUpToFactor(OpPointer forOp, uint64_t unrollFactor); /// Unrolls and jams this loop by the specified factor. Returns true if the loop /// is successfully unroll-jammed. -bool loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor); +bool loopUnrollJamByFactor(OpPointer forOp, + uint64_t unrollJamFactor); /// Unrolls and jams this loop by the specified factor or by the trip count (if /// constant), whichever is lower. -bool loopUnrollJamUpToFactor(ForInst *forInst, uint64_t unrollJamFactor); +bool loopUnrollJamUpToFactor(OpPointer forOp, + uint64_t unrollJamFactor); -/// Promotes the loop body of a ForInst to its containing block if the ForInst -/// was known to have a single iteration. Returns false otherwise. 
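A small usage sketch (outside the patch itself) for the updated LoopUtils entry points, which now take OpPointer<AffineForOp>; the wrapper name and fallback factor are hypothetical:

```cpp
// Fully unroll when the trip count is a known constant (loopUnrollFull
// returns false otherwise); fall back to unrolling by a fixed factor.
void unrollAggressively(OpPointer<AffineForOp> forOp) {
  if (!loopUnrollFull(forOp))
    loopUnrollByFactor(forOp, /*unrollFactor=*/4);
}
```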
-bool promoteIfSingleIteration(ForInst *forInst); +/// Promotes the loop body of a AffineForOp to its containing block if the +/// AffineForOp was known to have a single iteration. Returns false otherwise. +bool promoteIfSingleIteration(OpPointer forOp); -/// Promotes all single iteration ForInst's in the Function, i.e., moves +/// Promotes all single iteration AffineForOp's in the Function, i.e., moves /// their body into the containing Block. void promoteSingleIterationLoops(Function *f); /// Returns the lower bound of the cleanup loop when unrolling a loop /// with the specified unroll factor. -AffineMap getCleanupLoopLowerBound(const ForInst &forInst, +AffineMap getCleanupLoopLowerBound(ConstOpPointer forOp, unsigned unrollFactor, FuncBuilder *builder); /// Returns the upper bound of an unrolled loop when unrolling with /// the specified trip count, stride, and unroll factor. -AffineMap getUnrolledLoopUpperBound(const ForInst &forInst, +AffineMap getUnrolledLoopUpperBound(ConstOpPointer forOp, unsigned unrollFactor, FuncBuilder *builder); /// Skew the instructions in the body of a 'for' instruction with the specified /// instruction-wise shifts. The shifts are with respect to the original /// execution order, and are multiplied by the loop 'step' before being applied. -UtilResult instBodySkew(ForInst *forInst, ArrayRef shifts, +UtilResult instBodySkew(OpPointer forOp, ArrayRef shifts, bool unrollPrologueEpilogue = false); /// Tiles the specified band of perfectly nested loops creating tile-space loops /// and intra-tile loops. A band is a contiguous set of loops. -UtilResult tileCodeGen(ArrayRef band, ArrayRef tileSizes); +UtilResult tileCodeGen(MutableArrayRef> band, + ArrayRef tileSizes); } // end namespace mlir diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index 714086f22a7b..3269ac1fdc59 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -27,7 +27,8 @@ namespace mlir { -class ForInst; +class AffineForOp; +template class ConstOpPointer; class FunctionPass; class ModulePass; @@ -57,9 +58,10 @@ FunctionPass *createMaterializeVectorsPass(); /// factors supplied through other means. If -1 is passed as the unrollFactor /// and no callback is provided, anything passed from the command-line (if at /// all) or the default unroll factor is used (LoopUnroll:kDefaultUnrollFactor). -FunctionPass *createLoopUnrollPass( - int unrollFactor = -1, int unrollFull = -1, - const std::function &getUnrollFactor = nullptr); +FunctionPass * +createLoopUnrollPass(int unrollFactor = -1, int unrollFull = -1, + const std::function)> + &getUnrollFactor = nullptr); /// Creates a loop unroll jam pass to unroll jam by the specified factor. A /// factor of -1 lets the pass use the default factor or the one on the command diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h index 5c7260c9a583..169633cc106b 100644 --- a/mlir/include/mlir/Transforms/Utils.h +++ b/mlir/include/mlir/Transforms/Utils.h @@ -32,7 +32,7 @@ namespace mlir { -class ForInst; +class AffineForOp; class FuncBuilder; class Location; class Module; @@ -115,7 +115,7 @@ void createAffineComputationSlice( /// Folds the lower and upper bounds of a 'for' inst to constants if possible. /// Returns false if the folding happens for at least one bound, true otherwise. 
-bool constantFoldBounds(ForInst *forInst); +bool constantFoldBounds(OpPointer forInst); /// Replaces (potentially nested) function attributes in the operation "op" /// with those specified in "remappingTable". diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 5b29467fc443..f1693c8e449a 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -17,7 +17,10 @@ #include "mlir/AffineOps/AffineOps.h" #include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/InstVisitor.h" +#include "mlir/IR/IntegerSet.h" #include "mlir/IR/OpImplementation.h" using namespace mlir; @@ -27,7 +30,445 @@ using namespace mlir; AffineOpsDialect::AffineOpsDialect(MLIRContext *context) : Dialect(/*namePrefix=*/"", context) { - addOperations(); + addOperations(); +} + +//===----------------------------------------------------------------------===// +// AffineForOp +//===----------------------------------------------------------------------===// + +void AffineForOp::build(Builder *builder, OperationState *result, + ArrayRef lbOperands, AffineMap lbMap, + ArrayRef ubOperands, AffineMap ubMap, + int64_t step) { + assert((!lbMap && lbOperands.empty()) || + lbOperands.size() == lbMap.getNumInputs() && + "lower bound operand count does not match the affine map"); + assert((!ubMap && ubOperands.empty()) || + ubOperands.size() == ubMap.getNumInputs() && + "upper bound operand count does not match the affine map"); + assert(step > 0 && "step has to be a positive integer constant"); + + // Add an attribute for the step. + result->addAttribute(getStepAttrName(), + builder->getIntegerAttr(builder->getIndexType(), step)); + + // Add the lower bound. + result->addAttribute(getLowerBoundAttrName(), + builder->getAffineMapAttr(lbMap)); + result->addOperands(lbOperands); + + // Add the upper bound. + result->addAttribute(getUpperBoundAttrName(), + builder->getAffineMapAttr(ubMap)); + result->addOperands(ubOperands); + + // Reserve a block list for the body. + result->reserveBlockLists(/*numReserved=*/1); + + // Set the operands list as resizable so that we can freely modify the bounds. + result->setOperandListToResizable(); +} + +void AffineForOp::build(Builder *builder, OperationState *result, int64_t lb, + int64_t ub, int64_t step) { + auto lbMap = AffineMap::getConstantMap(lb, builder->getContext()); + auto ubMap = AffineMap::getConstantMap(ub, builder->getContext()); + return build(builder, result, {}, lbMap, {}, ubMap, step); +} + +bool AffineForOp::verify() const { + const auto &bodyBlockList = getInstruction()->getBlockList(0); + + // The body block list must contain a single basic block. + if (bodyBlockList.empty() || + std::next(bodyBlockList.begin()) != bodyBlockList.end()) + return emitOpError("expected body block list to have a single block"); + + // Check that the body defines as single block argument for the induction + // variable. + const auto *body = getBody(); + if (body->getNumArguments() != 1 || + !body->getArgument(0)->getType().isIndex()) + return emitOpError("expected body to have a single index argument for the " + "induction variable"); + + // TODO: check that loop bounds are properly formed. + return false; +} + +/// Parse a for operation loop bounds. +static bool parseBound(bool isLower, OperationState *result, OpAsmParser *p) { + // 'min' / 'max' prefixes are generally syntactic sugar, but are required if + // the map has multiple results. 
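For reference (outside the patch itself), a sketch of the full bound form this helper is written to accept, assuming the bound maps are defined at module scope and the SSA symbols %M and %N are hypothetical; both maps have two results, so the max/min prefixes are mandatory:

```mlir
#lb = ()[s0] -> (0, s0)
#ub = ()[s0] -> (s0, 4096)

for %i = max #lb()[%M] to min #ub()[%N] step 2 {
  ...
}
```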
+ bool failedToParsedMinMax = p->parseOptionalKeyword(isLower ? "max" : "min"); + + auto &builder = p->getBuilder(); + auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName() + : AffineForOp::getUpperBoundAttrName(); + + // Parse ssa-id as identity map. + SmallVector boundOpInfos; + if (p->parseOperandList(boundOpInfos)) + return true; + + if (!boundOpInfos.empty()) { + // Check that only one operand was parsed. + if (boundOpInfos.size() > 1) + return p->emitError(p->getNameLoc(), + "expected only one loop bound operand"); + + // TODO: improve error message when SSA value is not an affine integer. + // Currently it is 'use of value ... expects different type than prior uses' + if (p->resolveOperand(boundOpInfos.front(), builder.getIndexType(), + result->operands)) + return true; + + // Create an identity map using symbol id. This representation is optimized + // for storage. Analysis passes may expand it into a multi-dimensional map + // if desired. + AffineMap map = builder.getSymbolIdentityMap(); + result->addAttribute(boundAttrName, builder.getAffineMapAttr(map)); + return false; + } + + Attribute boundAttr; + if (p->parseAttribute(boundAttr, builder.getIndexType(), boundAttrName.data(), + result->attributes)) + return true; + + // Parse full form - affine map followed by dim and symbol list. + if (auto affineMapAttr = boundAttr.dyn_cast()) { + unsigned currentNumOperands = result->operands.size(); + unsigned numDims; + if (parseDimAndSymbolList(p, result->operands, numDims)) + return true; + + auto map = affineMapAttr.getValue(); + if (map.getNumDims() != numDims) + return p->emitError( + p->getNameLoc(), + "dim operand count and integer set dim count must match"); + + unsigned numDimAndSymbolOperands = + result->operands.size() - currentNumOperands; + if (numDims + map.getNumSymbols() != numDimAndSymbolOperands) + return p->emitError( + p->getNameLoc(), + "symbol operand count and integer set symbol count must match"); + + // If the map has multiple results, make sure that we parsed the min/max + // prefix. + if (map.getNumResults() > 1 && failedToParsedMinMax) { + if (isLower) { + return p->emitError(p->getNameLoc(), + "lower loop bound affine map with multiple results " + "requires 'max' prefix"); + } + return p->emitError(p->getNameLoc(), + "upper loop bound affine map with multiple results " + "requires 'min' prefix"); + } + return false; + } + + // Parse custom assembly form. + if (auto integerAttr = boundAttr.dyn_cast()) { + result->attributes.pop_back(); + result->addAttribute( + boundAttrName, builder.getAffineMapAttr( + builder.getConstantAffineMap(integerAttr.getInt()))); + return false; + } + + return p->emitError( + p->getNameLoc(), + "expected valid affine map representation for loop bounds"); +} + +bool AffineForOp::parse(OpAsmParser *parser, OperationState *result) { + auto &builder = parser->getBuilder(); + // Parse the induction variable followed by '='. + if (parser->parseBlockListEntryBlockArgument(builder.getIndexType()) || + parser->parseEqual()) + return true; + + // Parse loop bounds. + if (parseBound(/*isLower=*/true, result, parser) || + parser->parseKeyword("to", " between bounds") || + parseBound(/*isLower=*/false, result, parser)) + return true; + + // Parse the optional loop step, we default to 1 if one is not present. 
+ if (parser->parseOptionalKeyword("step")) { + result->addAttribute( + getStepAttrName(), + builder.getIntegerAttr(builder.getIndexType(), /*value=*/1)); + } else { + llvm::SMLoc stepLoc; + IntegerAttr stepAttr; + if (parser->getCurrentLocation(&stepLoc) || + parser->parseAttribute(stepAttr, builder.getIndexType(), + getStepAttrName().data(), result->attributes)) + return true; + + if (stepAttr.getValue().getSExtValue() < 0) + return parser->emitError( + stepLoc, + "expected step to be representable as a positive signed integer"); + } + + // Parse the body block list. + result->reserveBlockLists(/*numReserved=*/1); + if (parser->parseBlockList()) + return true; + + // Set the operands list as resizable so that we can freely modify the bounds. + result->setOperandListToResizable(); + return false; +} + +static void printBound(AffineBound bound, const char *prefix, OpAsmPrinter *p) { + AffineMap map = bound.getMap(); + + // Check if this bound should be printed using custom assembly form. + // The decision to restrict printing custom assembly form to trivial cases + // comes from the will to roundtrip MLIR binary -> text -> binary in a + // lossless way. + // Therefore, custom assembly form parsing and printing is only supported for + // zero-operand constant maps and single symbol operand identity maps. + if (map.getNumResults() == 1) { + AffineExpr expr = map.getResult(0); + + // Print constant bound. + if (map.getNumDims() == 0 && map.getNumSymbols() == 0) { + if (auto constExpr = expr.dyn_cast()) { + *p << constExpr.getValue(); + return; + } + } + + // Print bound that consists of a single SSA symbol if the map is over a + // single symbol. + if (map.getNumDims() == 0 && map.getNumSymbols() == 1) { + if (auto symExpr = expr.dyn_cast()) { + p->printOperand(bound.getOperand(0)); + return; + } + } + } else { + // Map has multiple results. Print 'min' or 'max' prefix. + *p << prefix << ' '; + } + + // Print the map and its operands. + p->printAffineMap(map); + printDimAndSymbolList(bound.operand_begin(), bound.operand_end(), + map.getNumDims(), p); +} + +void AffineForOp::print(OpAsmPrinter *p) const { + *p << "for "; + p->printOperand(getBody()->getArgument(0)); + *p << " = "; + printBound(getLowerBound(), "max", p); + *p << " to "; + printBound(getUpperBound(), "min", p); + + if (getStep() != 1) + *p << " step " << getStep(); + p->printBlockList(getInstruction()->getBlockList(0), + /*printEntryBlockArgs=*/false); +} + +Block *AffineForOp::createBody() { + auto &bodyBlockList = getBlockList(); + assert(bodyBlockList.empty() && "expected no existing body blocks"); + + // Create a new block for the body, and add an argument for the induction + // variable. 
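A builder-side sketch (outside the patch itself), assuming a FuncBuilder `builder` and Location `loc` are in scope and that FuncBuilder's usual create<OpTy> helper forwards to AffineForOp::build:

```cpp
// Create a loop with constant bounds, then materialize its body; the body's
// single block argument is the induction variable.
auto forOp = builder.create<AffineForOp>(loc, /*lb=*/0, /*ub=*/100, /*step=*/1);
Block *body = forOp->createBody();
FuncBuilder bodyBuilder(body, body->end());
// Emit the loop body here, e.g. using forOp->getInductionVar().
```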
+ Block *body = new Block(); + body->addArgument(IndexType::get(getInstruction()->getContext())); + bodyBlockList.push_back(body); + return body; +} + +const AffineBound AffineForOp::getLowerBound() const { + auto lbMap = getLowerBoundMap(); + return AffineBound(ConstOpPointer(*this), 0, + lbMap.getNumInputs(), lbMap); +} + +const AffineBound AffineForOp::getUpperBound() const { + auto lbMap = getLowerBoundMap(); + auto ubMap = getUpperBoundMap(); + return AffineBound(ConstOpPointer(*this), lbMap.getNumInputs(), + getNumOperands(), ubMap); +} + +void AffineForOp::setLowerBound(ArrayRef lbOperands, AffineMap map) { + assert(lbOperands.size() == map.getNumInputs()); + assert(map.getNumResults() >= 1 && "bound map has at least one result"); + + SmallVector newOperands(lbOperands.begin(), lbOperands.end()); + + auto ubOperands = getUpperBoundOperands(); + newOperands.append(ubOperands.begin(), ubOperands.end()); + getInstruction()->setOperands(newOperands); + + setAttr(Identifier::get(getLowerBoundAttrName(), map.getContext()), + AffineMapAttr::get(map)); +} + +void AffineForOp::setUpperBound(ArrayRef ubOperands, AffineMap map) { + assert(ubOperands.size() == map.getNumInputs()); + assert(map.getNumResults() >= 1 && "bound map has at least one result"); + + SmallVector newOperands(getLowerBoundOperands()); + newOperands.append(ubOperands.begin(), ubOperands.end()); + getInstruction()->setOperands(newOperands); + + setAttr(Identifier::get(getUpperBoundAttrName(), map.getContext()), + AffineMapAttr::get(map)); +} + +void AffineForOp::setLowerBoundMap(AffineMap map) { + auto lbMap = getLowerBoundMap(); + assert(lbMap.getNumDims() == map.getNumDims() && + lbMap.getNumSymbols() == map.getNumSymbols()); + assert(map.getNumResults() >= 1 && "bound map has at least one result"); + (void)lbMap; + setAttr(Identifier::get(getLowerBoundAttrName(), map.getContext()), + AffineMapAttr::get(map)); +} + +void AffineForOp::setUpperBoundMap(AffineMap map) { + auto ubMap = getUpperBoundMap(); + assert(ubMap.getNumDims() == map.getNumDims() && + ubMap.getNumSymbols() == map.getNumSymbols()); + assert(map.getNumResults() >= 1 && "bound map has at least one result"); + (void)ubMap; + setAttr(Identifier::get(getUpperBoundAttrName(), map.getContext()), + AffineMapAttr::get(map)); +} + +bool AffineForOp::hasConstantLowerBound() const { + return getLowerBoundMap().isSingleConstant(); +} + +bool AffineForOp::hasConstantUpperBound() const { + return getUpperBoundMap().isSingleConstant(); +} + +int64_t AffineForOp::getConstantLowerBound() const { + return getLowerBoundMap().getSingleConstantResult(); +} + +int64_t AffineForOp::getConstantUpperBound() const { + return getUpperBoundMap().getSingleConstantResult(); +} + +void AffineForOp::setConstantLowerBound(int64_t value) { + setLowerBound( + {}, AffineMap::getConstantMap(value, getInstruction()->getContext())); +} + +void AffineForOp::setConstantUpperBound(int64_t value) { + setUpperBound( + {}, AffineMap::getConstantMap(value, getInstruction()->getContext())); +} + +AffineForOp::operand_range AffineForOp::getLowerBoundOperands() { + return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()}; +} + +AffineForOp::const_operand_range AffineForOp::getLowerBoundOperands() const { + return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()}; +} + +AffineForOp::operand_range AffineForOp::getUpperBoundOperands() { + return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()}; +} + +AffineForOp::const_operand_range 
AffineForOp::getUpperBoundOperands() const { + return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()}; +} + +bool AffineForOp::matchingBoundOperandList() const { + auto lbMap = getLowerBoundMap(); + auto ubMap = getUpperBoundMap(); + if (lbMap.getNumDims() != ubMap.getNumDims() || + lbMap.getNumSymbols() != ubMap.getNumSymbols()) + return false; + + unsigned numOperands = lbMap.getNumInputs(); + for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) { + // Compare Value *'s. + if (getOperand(i) != getOperand(numOperands + i)) + return false; + } + return true; +} + +void AffineForOp::walkOps(std::function callback) { + struct Walker : public InstWalker { + std::function const &callback; + Walker(std::function const &callback) + : callback(callback) {} + + void visitOperationInst(OperationInst *opInst) { callback(opInst); } + }; + + Walker w(callback); + w.walk(getInstruction()); +} + +void AffineForOp::walkOpsPostOrder( + std::function callback) { + struct Walker : public InstWalker { + std::function const &callback; + Walker(std::function const &callback) + : callback(callback) {} + + void visitOperationInst(OperationInst *opInst) { callback(opInst); } + }; + + Walker v(callback); + v.walkPostOrder(getInstruction()); +} + +/// Returns the induction variable for this loop. +Value *AffineForOp::getInductionVar() { return getBody()->getArgument(0); } + +/// Returns if the provided value is the induction variable of a AffineForOp. +bool mlir::isForInductionVar(const Value *val) { + return getForInductionVarOwner(val) != nullptr; +} + +/// Returns the loop parent of an induction variable. If the provided value is +/// not an induction variable, then return nullptr. +OpPointer mlir::getForInductionVarOwner(Value *val) { + const BlockArgument *ivArg = dyn_cast(val); + if (!ivArg || !ivArg->getOwner()) + return OpPointer(); + auto *containingInst = ivArg->getOwner()->getParent()->getContainingInst(); + if (!containingInst) + return OpPointer(); + return cast(containingInst)->dyn_cast(); +} +ConstOpPointer mlir::getForInductionVarOwner(const Value *val) { + auto nonConstOwner = getForInductionVarOwner(const_cast(val)); + return ConstOpPointer(nonConstOwner); +} + +/// Extracts the induction variables from a list of AffineForOps and returns +/// them. 
+SmallVector mlir::extractForInductionVars( + MutableArrayRef> forInsts) { + SmallVector results; + for (auto forInst : forInsts) + results.push_back(forInst->getInductionVar()); + return results; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 0153546a4c6a..d2366f1ce81a 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -21,12 +21,14 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Utils.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Instructions.h" +#include "mlir/IR/IntegerSet.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Support/MathExtras.h" #include "mlir/Support/STLExtras.h" @@ -519,7 +521,7 @@ void mlir::getReachableAffineApplyOps( State &state = worklist.back(); auto *opInst = state.value->getDefiningInst(); // Note: getDefiningInst will return nullptr if the operand is not an - // OperationInst (i.e. ForInst), which is a terminator for the search. + // OperationInst (i.e. AffineForOp), which is a terminator for the search. if (opInst == nullptr || !opInst->isa()) { worklist.pop_back(); continue; @@ -546,21 +548,21 @@ void mlir::getReachableAffineApplyOps( } // Builds a system of constraints with dimensional identifiers corresponding to -// the loop IVs of the forInsts appearing in that order. Any symbols founds in +// the loop IVs of the forOps appearing in that order. Any symbols founds in // the bound operands are added as symbols in the system. Returns false for the // yet unimplemented cases. // TODO(andydavis,bondhugula) Handle non-unit steps through local variables or // stride information in FlatAffineConstraints. (For eg., by using iv - lb % // step = 0 and/or by introducing a method in FlatAffineConstraints // setExprStride(ArrayRef expr, int64_t stride) -bool mlir::getIndexSet(ArrayRef forInsts, +bool mlir::getIndexSet(MutableArrayRef> forOps, FlatAffineConstraints *domain) { - auto indices = extractForInductionVars(forInsts); + auto indices = extractForInductionVars(forOps); // Reset while associated Values in 'indices' to the domain. - domain->reset(forInsts.size(), /*numSymbols=*/0, /*numLocals=*/0, indices); - for (auto *forInst : forInsts) { - // Add constraints from forInst's bounds. - if (!domain->addForInstDomain(*forInst)) + domain->reset(forOps.size(), /*numSymbols=*/0, /*numLocals=*/0, indices); + for (auto forOp : forOps) { + // Add constraints from forOp's bounds. + if (!domain->addAffineForOpDomain(forOp)) return false; } return true; @@ -576,7 +578,7 @@ static bool getInstIndexSet(const Instruction *inst, FlatAffineConstraints *indexSet) { // TODO(andydavis) Extend this to gather enclosing IfInsts and consider // factoring it out into a utility function. 
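As a usage note, getLoopIVs and getIndexSet compose to build the iteration domain of the loops enclosing an instruction, which is exactly what getInstIndexSet below does. A sketch under the assumption it lives inside the MLIR tree of this patch; 'computeEnclosingDomain' is hypothetical.

#include "llvm/ADT/SmallVector.h"
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/Utils.h"

// Builds one dimension per enclosing induction variable and adds the loop
// bounds as constraints. Returns false for cases getIndexSet cannot handle.
static bool computeEnclosingDomain(const mlir::Instruction &inst,
                                   mlir::FlatAffineConstraints *domain) {
  llvm::SmallVector<mlir::OpPointer<mlir::AffineForOp>, 4> loops;
  mlir::getLoopIVs(inst, &loops); // outermost loop first
  return mlir::getIndexSet(loops, domain);
}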
- SmallVector loops; + SmallVector, 4> loops; getLoopIVs(*inst, &loops); return getIndexSet(loops, indexSet); } @@ -998,9 +1000,9 @@ static const Block *getCommonBlock(const MemRefAccess &srcAccess, return block; } auto *commonForValue = srcDomain.getIdValue(numCommonLoops - 1); - auto *forInst = getForInductionVarOwner(commonForValue); - assert(forInst && "commonForValue was not an induction variable"); - return forInst->getBody(); + auto forOp = getForInductionVarOwner(commonForValue); + assert(forOp && "commonForValue was not an induction variable"); + return forOp->getBody(); } // Returns true if the ancestor operation instruction of 'srcAccess' appears @@ -1195,7 +1197,7 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const { // until operands of the AffineValueMap are loop IVs or symbols. // *) Build iteration domain constraints for each access. Iteration domain // constraints are pairs of inequality contraints representing the -// upper/lower loop bounds for each ForInst in the loop nest associated +// upper/lower loop bounds for each AffineForOp in the loop nest associated // with each access. // *) Build dimension and symbol position maps for each access, which map // Values from access functions and iteration domains to their position diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 5e7f8e3243c2..c794899d3e17 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -20,6 +20,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/AffineStructures.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/AffineMap.h" @@ -1247,22 +1248,23 @@ void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) { numSymbols = newSymbolCount; } -bool FlatAffineConstraints::addForInstDomain(const ForInst &forInst) { +bool FlatAffineConstraints::addAffineForOpDomain( + ConstOpPointer forOp) { unsigned pos; // Pre-condition for this method. - if (!findId(*forInst.getInductionVar(), &pos)) { + if (!findId(*forOp->getInductionVar(), &pos)) { assert(0 && "Value not found"); return false; } - if (forInst.getStep() != 1) + if (forOp->getStep() != 1) LLVM_DEBUG(llvm::dbgs() << "Domain conservative: non-unit stride not handled\n"); // Adds a lower or upper bound when the bounds aren't constant. auto addLowerOrUpperBound = [&](bool lower) -> bool { - auto operands = lower ? forInst.getLowerBoundOperands() - : forInst.getUpperBoundOperands(); + auto operands = + lower ? forOp->getLowerBoundOperands() : forOp->getUpperBoundOperands(); for (const auto &operand : operands) { unsigned loc; if (!findId(*operand, &loc)) { @@ -1291,7 +1293,7 @@ bool FlatAffineConstraints::addForInstDomain(const ForInst &forInst) { } auto boundMap = - lower ? forInst.getLowerBoundMap() : forInst.getUpperBoundMap(); + lower ? forOp->getLowerBoundMap() : forOp->getUpperBoundMap(); FlatAffineConstraints localVarCst; std::vector> flatExprs; @@ -1321,16 +1323,16 @@ bool FlatAffineConstraints::addForInstDomain(const ForInst &forInst) { return true; }; - if (forInst.hasConstantLowerBound()) { - addConstantLowerBound(pos, forInst.getConstantLowerBound()); + if (forOp->hasConstantLowerBound()) { + addConstantLowerBound(pos, forOp->getConstantLowerBound()); } else { // Non-constant lower bound case. 
if (!addLowerOrUpperBound(/*lower=*/true)) return false; } - if (forInst.hasConstantUpperBound()) { - addConstantUpperBound(pos, forInst.getConstantUpperBound() - 1); + if (forOp->hasConstantUpperBound()) { + addConstantUpperBound(pos, forOp->getConstantUpperBound() - 1); return true; } // Non-constant upper bound case. diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 7d88a3d9b9fb..249776d42c98 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -43,27 +43,27 @@ using namespace mlir; /// Returns the trip count of the loop as an affine expression if the latter is /// expressible as an affine expression, and nullptr otherwise. The trip count /// expression is simplified before returning. -AffineExpr mlir::getTripCountExpr(const ForInst &forInst) { +AffineExpr mlir::getTripCountExpr(ConstOpPointer forOp) { // upper_bound - lower_bound int64_t loopSpan; - int64_t step = forInst.getStep(); - auto *context = forInst.getContext(); + int64_t step = forOp->getStep(); + auto *context = forOp->getInstruction()->getContext(); - if (forInst.hasConstantBounds()) { - int64_t lb = forInst.getConstantLowerBound(); - int64_t ub = forInst.getConstantUpperBound(); + if (forOp->hasConstantBounds()) { + int64_t lb = forOp->getConstantLowerBound(); + int64_t ub = forOp->getConstantUpperBound(); loopSpan = ub - lb; } else { - auto lbMap = forInst.getLowerBoundMap(); - auto ubMap = forInst.getUpperBoundMap(); + auto lbMap = forOp->getLowerBoundMap(); + auto ubMap = forOp->getUpperBoundMap(); // TODO(bondhugula): handle max/min of multiple expressions. if (lbMap.getNumResults() != 1 || ubMap.getNumResults() != 1) return nullptr; // TODO(bondhugula): handle bounds with different operands. // Bounds have different operands, unhandled for now. - if (!forInst.matchingBoundOperandList()) + if (!forOp->matchingBoundOperandList()) return nullptr; // ub_expr - lb_expr @@ -89,8 +89,9 @@ AffineExpr mlir::getTripCountExpr(const ForInst &forInst) { /// Returns the trip count of the loop if it's a constant, None otherwise. This /// method uses affine expression analysis (in turn using getTripCount) and is /// able to determine constant trip count in non-trivial cases. -llvm::Optional mlir::getConstantTripCount(const ForInst &forInst) { - auto tripCountExpr = getTripCountExpr(forInst); +llvm::Optional +mlir::getConstantTripCount(ConstOpPointer forOp) { + auto tripCountExpr = getTripCountExpr(forOp); if (!tripCountExpr) return None; @@ -104,8 +105,8 @@ llvm::Optional mlir::getConstantTripCount(const ForInst &forInst) { /// Returns the greatest known integral divisor of the trip count. Affine /// expression analysis is used (indirectly through getTripCount), and /// this method is thus able to determine non-trivial divisors. 
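A small usage sketch for the trip-count queries above; the unrolling helper and its threshold are hypothetical, not part of the patch.

#include <cstdint>
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/LoopAnalysis.h"

// Decide whether a loop is small enough to unroll completely. The trip count
// is None when it is not a compile-time constant (e.g. multi-result bound
// maps or bounds with different operand lists).
static bool isSmallEnoughToFullyUnroll(
    mlir::ConstOpPointer<mlir::AffineForOp> forOp, uint64_t maxTripCount) {
  auto tripCount = mlir::getConstantTripCount(forOp);
  return tripCount.hasValue() && tripCount.getValue() <= maxTripCount;
}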
-uint64_t mlir::getLargestDivisorOfTripCount(const ForInst &forInst) { - auto tripCountExpr = getTripCountExpr(forInst); +uint64_t mlir::getLargestDivisorOfTripCount(ConstOpPointer forOp) { + auto tripCountExpr = getTripCountExpr(forOp); if (!tripCountExpr) return 1; @@ -126,7 +127,7 @@ uint64_t mlir::getLargestDivisorOfTripCount(const ForInst &forInst) { } bool mlir::isAccessInvariant(const Value &iv, const Value &index) { - assert(isForInductionVar(&iv) && "iv must be a ForInst"); + assert(isForInductionVar(&iv) && "iv must be a AffineForOp"); assert(index.getType().isa() && "index must be of IndexType"); SmallVector affineApplyOps; getReachableAffineApplyOps({const_cast(&index)}, affineApplyOps); @@ -163,7 +164,7 @@ mlir::getInvariantAccesses(const Value &iv, } /// Given: -/// 1. an induction variable `iv` of type ForInst; +/// 1. an induction variable `iv` of type AffineForOp; /// 2. a `memoryOp` of type const LoadOp& or const StoreOp&; /// 3. the index of the `fastestVaryingDim` along which to check; /// determines whether `memoryOp`[`fastestVaryingDim`] is a contiguous access @@ -231,17 +232,18 @@ static bool isVectorTransferReadOrWrite(const Instruction &inst) { } using VectorizableInstFun = - std::function; + std::function, const OperationInst &)>; -static bool isVectorizableLoopWithCond(const ForInst &loop, +static bool isVectorizableLoopWithCond(ConstOpPointer loop, VectorizableInstFun isVectorizableInst) { - if (!matcher::isParallelLoop(loop) && !matcher::isReductionLoop(loop)) { + auto *forInst = const_cast(loop->getInstruction()); + if (!matcher::isParallelLoop(*forInst) && + !matcher::isReductionLoop(*forInst)) { return false; } // No vectorization across conditionals for now. auto conditionals = matcher::If(); - auto *forInst = const_cast(&loop); SmallVector conditionalsMatched; conditionals.match(forInst, &conditionalsMatched); if (!conditionalsMatched.empty()) { @@ -251,7 +253,8 @@ static bool isVectorizableLoopWithCond(const ForInst &loop, // No vectorization across unknown regions. auto regions = matcher::Op([](const Instruction &inst) -> bool { auto &opInst = cast(inst); - return opInst.getNumBlockLists() != 0 && !opInst.isa(); + return opInst.getNumBlockLists() != 0 && + !(opInst.isa() || opInst.isa()); }); SmallVector regionsMatched; regions.match(forInst, ®ionsMatched); @@ -288,23 +291,25 @@ static bool isVectorizableLoopWithCond(const ForInst &loop, } bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim( - const ForInst &loop, unsigned fastestVaryingDim) { - VectorizableInstFun fun( - [fastestVaryingDim](const ForInst &loop, const OperationInst &op) { - auto load = op.dyn_cast(); - auto store = op.dyn_cast(); - return load ? isContiguousAccess(*loop.getInductionVar(), *load, - fastestVaryingDim) - : isContiguousAccess(*loop.getInductionVar(), *store, - fastestVaryingDim); - }); + ConstOpPointer loop, unsigned fastestVaryingDim) { + VectorizableInstFun fun([fastestVaryingDim](ConstOpPointer loop, + const OperationInst &op) { + auto load = op.dyn_cast(); + auto store = op.dyn_cast(); + return load ? 
isContiguousAccess(*loop->getInductionVar(), *load, + fastestVaryingDim) + : isContiguousAccess(*loop->getInductionVar(), *store, + fastestVaryingDim); + }); return isVectorizableLoopWithCond(loop, fun); } -bool mlir::isVectorizableLoop(const ForInst &loop) { +bool mlir::isVectorizableLoop(ConstOpPointer loop) { VectorizableInstFun fun( // TODO: implement me - [](const ForInst &loop, const OperationInst &op) { return true; }); + [](ConstOpPointer loop, const OperationInst &op) { + return true; + }); return isVectorizableLoopWithCond(loop, fun); } @@ -313,9 +318,9 @@ bool mlir::isVectorizableLoop(const ForInst &loop) { /// 'def' and all its uses have the same shift factor. // TODO(mlir-team): extend this to check for memory-based dependence // violation when we have the support. -bool mlir::isInstwiseShiftValid(const ForInst &forInst, +bool mlir::isInstwiseShiftValid(ConstOpPointer forOp, ArrayRef shifts) { - auto *forBody = forInst.getBody(); + auto *forBody = forOp->getBody(); assert(shifts.size() == forBody->getInstructions().size()); unsigned s = 0; for (const auto &inst : *forBody) { @@ -325,7 +330,7 @@ bool mlir::isInstwiseShiftValid(const ForInst &forInst, for (unsigned i = 0, e = opInst->getNumResults(); i < e; ++i) { const Value *result = opInst->getResult(i); for (const InstOperand &use : result->getUses()) { - // If an ancestor instruction doesn't lie in the block of forInst, + // If an ancestor instruction doesn't lie in the block of forOp, // there is no shift to check. This is a naive way. If performance // becomes an issue, a map can be used to store 'shifts' - to look up // the shift for a instruction in constant time. diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp index 46bf5ad0b975..214b4ce403c4 100644 --- a/mlir/lib/Analysis/NestedMatcher.cpp +++ b/mlir/lib/Analysis/NestedMatcher.cpp @@ -115,6 +115,10 @@ void NestedPattern::matchOne(Instruction *inst, } } +static bool isAffineForOp(const Instruction &inst) { + return cast(inst).isa(); +} + static bool isAffineIfOp(const Instruction &inst) { return isa(inst) && cast(inst).isa(); @@ -147,28 +151,34 @@ NestedPattern If(FilterFunctionType filter, ArrayRef nested) { } NestedPattern For(NestedPattern child) { - return NestedPattern(Instruction::Kind::For, child, defaultFilterFunction); + return NestedPattern(Instruction::Kind::OperationInst, child, isAffineForOp); } NestedPattern For(FilterFunctionType filter, NestedPattern child) { - return NestedPattern(Instruction::Kind::For, child, filter); + return NestedPattern(Instruction::Kind::OperationInst, child, + [=](const Instruction &inst) { + return isAffineForOp(inst) && filter(inst); + }); } NestedPattern For(ArrayRef nested) { - return NestedPattern(Instruction::Kind::For, nested, defaultFilterFunction); + return NestedPattern(Instruction::Kind::OperationInst, nested, isAffineForOp); } NestedPattern For(FilterFunctionType filter, ArrayRef nested) { - return NestedPattern(Instruction::Kind::For, nested, filter); + return NestedPattern(Instruction::Kind::OperationInst, nested, + [=](const Instruction &inst) { + return isAffineForOp(inst) && filter(inst); + }); } // TODO(ntv): parallel annotation on loops. bool isParallelLoop(const Instruction &inst) { - const auto *loop = cast(&inst); - return (void *)loop || true; // loop->isParallel(); + auto loop = cast(inst).cast(); + return loop || true; // loop->isParallel(); }; // TODO(ntv): reduction annotation on loops. 
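The induction-variable helpers and the vectorizability predicates combine naturally; a sketch (the wrapper is hypothetical) starting from a Value that may or may not be a loop induction variable.

#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/LoopAnalysis.h"

// Returns true if 'iv' is the induction variable of an affine loop whose
// loads and stores are all contiguous along 'fastestVaryingDim'.
static bool ivLoopIsVectorizable(mlir::Value *iv, unsigned fastestVaryingDim) {
  auto loop = mlir::getForInductionVarOwner(iv);
  if (!loop) // 'iv' is not an AffineForOp induction variable
    return false;
  return mlir::isVectorizableLoopAlongFastestVaryingMemRefDim(
      mlir::ConstOpPointer<mlir::AffineForOp>(loop), fastestVaryingDim);
}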
bool isReductionLoop(const Instruction &inst) { - const auto *loop = cast(&inst); - return (void *)loop || true; // loop->isReduction(); + auto loop = cast(inst).cast(); + return loop || true; // loop->isReduction(); }; bool isLoadOrStore(const Instruction &inst) { diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index d16a7fcb1b31..4025af936f3f 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -20,6 +20,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/VectorAnalysis.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Instructions.h" @@ -52,7 +53,16 @@ void mlir::getForwardSlice(Instruction *inst, return; } - if (auto *opInst = dyn_cast(inst)) { + auto *opInst = cast(inst); + if (auto forOp = opInst->dyn_cast()) { + for (auto &u : forOp->getInductionVar()->getUses()) { + auto *ownerInst = u.getOwner(); + if (forwardSlice->count(ownerInst) == 0) { + getForwardSlice(ownerInst, forwardSlice, filter, + /*topLevel=*/false); + } + } + } else { assert(opInst->getNumResults() <= 1 && "NYI: multiple results"); if (opInst->getNumResults() > 0) { for (auto &u : opInst->getResult(0)->getUses()) { @@ -63,16 +73,6 @@ void mlir::getForwardSlice(Instruction *inst, } } } - } else if (auto *forInst = dyn_cast(inst)) { - for (auto &u : forInst->getInductionVar()->getUses()) { - auto *ownerInst = u.getOwner(); - if (forwardSlice->count(ownerInst) == 0) { - getForwardSlice(ownerInst, forwardSlice, filter, - /*topLevel=*/false); - } - } - } else { - assert(false && "NYI slicing case"); } // At the top level we reverse to get back the actual topological order. diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 0e77d4d9084b..4b8afd9a6205 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -38,15 +38,17 @@ using namespace mlir; /// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from /// the outermost 'for' instruction to the innermost one. void mlir::getLoopIVs(const Instruction &inst, - SmallVectorImpl *loops) { + SmallVectorImpl> *loops) { auto *currInst = inst.getParentInst(); - ForInst *currForInst; + OpPointer currAffineForOp; // Traverse up the hierarchy collecing all 'for' instruction while skipping // over 'if' instructions. - while (currInst && ((currForInst = dyn_cast(currInst)) || - cast(currInst)->isa())) { - if (currForInst) - loops->push_back(currForInst); + while (currInst && + ((currAffineForOp = + cast(currInst)->dyn_cast()) || + cast(currInst)->isa())) { + if (currAffineForOp) + loops->push_back(currAffineForOp); currInst = currInst->getParentInst(); } std::reverse(loops->begin(), loops->end()); @@ -148,7 +150,7 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth, if (rank == 0) { // A rank 0 memref has a 0-d region. - SmallVector ivs; + SmallVector, 4> ivs; getLoopIVs(*opInst, &ivs); SmallVector regionSymbols = extractForInductionVars(ivs); @@ -174,12 +176,12 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth, unsigned numSymbols = accessMap.getNumSymbols(); // Add inequalties for loop lower/upper bounds. 
for (unsigned i = 0; i < numDims + numSymbols; ++i) { - if (auto *loop = getForInductionVarOwner(accessValueMap.getOperand(i))) { + if (auto loop = getForInductionVarOwner(accessValueMap.getOperand(i))) { // Note that regionCst can now have more dimensions than accessMap if the // bounds expressions involve outer loops or other symbols. // TODO(bondhugula): rewrite this to use getInstIndexSet; this way // conditionals will be handled when the latter supports it. - if (!regionCst->addForInstDomain(*loop)) + if (!regionCst->addAffineForOpDomain(loop)) return false; } else { // Has to be a valid symbol. @@ -203,14 +205,14 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth, // Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which // this memref region is symbolic. - SmallVector outerIVs; + SmallVector, 4> outerIVs; getLoopIVs(*opInst, &outerIVs); assert(loopDepth <= outerIVs.size() && "invalid loop depth"); outerIVs.resize(loopDepth); for (auto *operand : accessValueMap.getOperands()) { - ForInst *iv; + OpPointer iv; if ((iv = getForInductionVarOwner(operand)) && - std::find(outerIVs.begin(), outerIVs.end(), iv) == outerIVs.end()) { + llvm::is_contained(outerIVs, iv) == false) { regionCst->projectOut(operand); } } @@ -357,8 +359,10 @@ static Instruction *getInstAtPosition(ArrayRef positions, } if (level == positions.size() - 1) return &inst; - if (auto *childForInst = dyn_cast(&inst)) - return getInstAtPosition(positions, level + 1, childForInst->getBody()); + if (auto childAffineForOp = + cast(inst).dyn_cast()) + return getInstAtPosition(positions, level + 1, + childAffineForOp->getBody()); for (auto &blockList : cast(&inst)->getBlockLists()) { for (auto &b : blockList) @@ -385,12 +389,12 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess, return false; } // Get loop nest surrounding src operation. - SmallVector srcLoopIVs; + SmallVector, 4> srcLoopIVs; getLoopIVs(*srcAccess.opInst, &srcLoopIVs); unsigned numSrcLoopIVs = srcLoopIVs.size(); // Get loop nest surrounding dst operation. - SmallVector dstLoopIVs; + SmallVector, 4> dstLoopIVs; getLoopIVs(*dstAccess.opInst, &dstLoopIVs); unsigned numDstLoopIVs = dstLoopIVs.size(); if (dstLoopDepth > numDstLoopIVs) { @@ -437,38 +441,41 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess, // solution. // TODO(andydavis) Remove dependence on 'srcLoopDepth' here. Instead project // out loop IVs we don't care about and produce smaller slice. -ForInst *mlir::insertBackwardComputationSlice( +OpPointer mlir::insertBackwardComputationSlice( OperationInst *srcOpInst, OperationInst *dstOpInst, unsigned dstLoopDepth, ComputationSliceState *sliceState) { // Get loop nest surrounding src operation. - SmallVector srcLoopIVs; + SmallVector, 4> srcLoopIVs; getLoopIVs(*srcOpInst, &srcLoopIVs); unsigned numSrcLoopIVs = srcLoopIVs.size(); // Get loop nest surrounding dst operation. - SmallVector dstLoopIVs; + SmallVector, 4> dstLoopIVs; getLoopIVs(*dstOpInst, &dstLoopIVs); unsigned dstLoopIVsSize = dstLoopIVs.size(); if (dstLoopDepth > dstLoopIVsSize) { dstOpInst->emitError("invalid destination loop depth"); - return nullptr; + return OpPointer(); } // Find the inst block positions of 'srcOpInst' within 'srcLoopIVs'. SmallVector positions; // TODO(andydavis): This code is incorrect since srcLoopIVs can be 0-d. 
- findInstPosition(srcOpInst, srcLoopIVs[0]->getBlock(), &positions); + findInstPosition(srcOpInst, srcLoopIVs[0]->getInstruction()->getBlock(), + &positions); // Clone src loop nest and insert it a the beginning of the instruction block // of the loop at 'dstLoopDepth' in 'dstLoopIVs'. - auto *dstForInst = dstLoopIVs[dstLoopDepth - 1]; - FuncBuilder b(dstForInst->getBody(), dstForInst->getBody()->begin()); - auto *sliceLoopNest = cast(b.clone(*srcLoopIVs[0])); + auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1]; + FuncBuilder b(dstAffineForOp->getBody(), dstAffineForOp->getBody()->begin()); + auto sliceLoopNest = + cast(b.clone(*srcLoopIVs[0]->getInstruction())) + ->cast(); Instruction *sliceInst = getInstAtPosition(positions, /*level=*/0, sliceLoopNest->getBody()); // Get loop nest surrounding 'sliceInst'. - SmallVector sliceSurroundingLoops; + SmallVector, 4> sliceSurroundingLoops; getLoopIVs(*sliceInst, &sliceSurroundingLoops); // Sanity check. @@ -481,11 +488,11 @@ ForInst *mlir::insertBackwardComputationSlice( // Update loop bounds for loops in 'sliceLoopNest'. for (unsigned i = 0; i < numSrcLoopIVs; ++i) { - auto *forInst = sliceSurroundingLoops[dstLoopDepth + i]; + auto forOp = sliceSurroundingLoops[dstLoopDepth + i]; if (AffineMap lbMap = sliceState->lbs[i]) - forInst->setLowerBound(sliceState->lbOperands[i], lbMap); + forOp->setLowerBound(sliceState->lbOperands[i], lbMap); if (AffineMap ubMap = sliceState->ubs[i]) - forInst->setUpperBound(sliceState->ubOperands[i], ubMap); + forOp->setUpperBound(sliceState->ubOperands[i], ubMap); } return sliceLoopNest; } @@ -520,7 +527,7 @@ unsigned mlir::getNestingDepth(const Instruction &stmt) { const Instruction *currInst = &stmt; unsigned depth = 0; while ((currInst = currInst->getParentInst())) { - if (isa(currInst)) + if (cast(currInst)->isa()) depth++; } return depth; @@ -530,14 +537,14 @@ unsigned mlir::getNestingDepth(const Instruction &stmt) { /// where each lists loops from outer-most to inner-most in loop nest. unsigned mlir::getNumCommonSurroundingLoops(const Instruction &A, const Instruction &B) { - SmallVector loopsA, loopsB; + SmallVector, 4> loopsA, loopsB; getLoopIVs(A, &loopsA); getLoopIVs(B, &loopsB); unsigned minNumLoops = std::min(loopsA.size(), loopsB.size()); unsigned numCommonLoops = 0; for (unsigned i = 0; i < minNumLoops; ++i) { - if (loopsA[i] != loopsB[i]) + if (loopsA[i]->getInstruction() != loopsB[i]->getInstruction()) break; ++numCommonLoops; } @@ -571,13 +578,14 @@ static Optional getRegionSize(const MemRefRegion ®ion) { return getMemRefEltSizeInBytes(memRefType) * numElements.getValue(); } -Optional mlir::getMemoryFootprintBytes(const ForInst &forInst, - int memorySpace) { +Optional +mlir::getMemoryFootprintBytes(ConstOpPointer forOp, + int memorySpace) { std::vector> regions; // Walk this 'for' instruction to gather all memory regions. bool error = false; - const_cast(&forInst)->walkOps([&](OperationInst *opInst) { + const_cast(*forOp).walkOps([&](OperationInst *opInst) { if (!opInst->isa() && !opInst->isa()) { // Neither load nor a store op. 
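The walkOps hook defined earlier in this patch is what getMemoryFootprintBytes uses to visit loads and stores; a minimal sketch of the same pattern (the counting helper is hypothetical).

#include "mlir/AffineOps/AffineOps.h"
#include "mlir/StandardOps/StandardOps.h"

// Counts the load/store instructions nested anywhere under 'forOp', visiting
// each OperationInst in preorder via AffineForOp::walkOps.
static unsigned countMemOps(mlir::OpPointer<mlir::AffineForOp> forOp) {
  unsigned count = 0;
  forOp->walkOps([&](mlir::OperationInst *opInst) {
    if (opInst->isa<mlir::LoadOp>() || opInst->isa<mlir::StoreOp>())
      ++count;
  });
  return count;
}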
return; diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index 125020e92a35..4865cb03bb4d 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -16,10 +16,12 @@ // ============================================================================= #include "mlir/Analysis/VectorAnalysis.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Instructions.h" +#include "mlir/IR/IntegerSet.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/SuperVectorOps/SuperVectorOps.h" #include "mlir/Support/Functional.h" @@ -105,7 +107,7 @@ Optional> mlir::shapeRatio(VectorType superVectorType, static AffineMap makePermutationMap( MLIRContext *context, llvm::iterator_range indices, - const DenseMap &enclosingLoopToVectorDim) { + const DenseMap &enclosingLoopToVectorDim) { using functional::makePtrDynCaster; using functional::map; auto unwrappedIndices = map(makePtrDynCaster(), indices); @@ -113,8 +115,9 @@ static AffineMap makePermutationMap( getAffineConstantExpr(0, context)); for (auto kvp : enclosingLoopToVectorDim) { assert(kvp.second < perm.size()); - auto invariants = - getInvariantAccesses(*kvp.first->getInductionVar(), unwrappedIndices); + auto invariants = getInvariantAccesses( + *cast(kvp.first)->cast()->getInductionVar(), + unwrappedIndices); unsigned numIndices = unwrappedIndices.size(); unsigned countInvariantIndices = 0; for (unsigned dim = 0; dim < numIndices; ++dim) { @@ -139,30 +142,30 @@ static AffineMap makePermutationMap( /// TODO(ntv): could also be implemented as a collect parents followed by a /// filter and made available outside this file. template -static SetVector getParentsOfType(Instruction *inst) { - SetVector res; +static SetVector getParentsOfType(Instruction *inst) { + SetVector res; auto *current = inst; while (auto *parent = current->getParentInst()) { - auto *typedParent = dyn_cast(parent); - if (typedParent) { - assert(res.count(typedParent) == 0 && "Already inserted"); - res.insert(typedParent); + if (auto typedParent = + cast(parent)->template dyn_cast()) { + assert(res.count(cast(parent)) == 0 && "Already inserted"); + res.insert(cast(parent)); } current = parent; } return res; } -/// Returns the enclosing ForInst, from closest to farthest. -static SetVector getEnclosingforInsts(Instruction *inst) { - return getParentsOfType(inst); +/// Returns the enclosing AffineForOp, from closest to farthest. 
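makePermutationMap now keys the loop-to-vector-dimension map by the loop's Instruction; a hypothetical caller-side sketch, mapping only the innermost enclosing loop purely for illustration.

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/Analysis/VectorAnalysis.h"

static mlir::AffineMap permutationMapFor(mlir::OperationInst *opInst) {
  llvm::SmallVector<mlir::OpPointer<mlir::AffineForOp>, 4> loops;
  mlir::getLoopIVs(*opInst, &loops);
  // Key the map by the generic Instruction carrying the AffineForOp.
  llvm::DenseMap<mlir::Instruction *, unsigned> loopToVectorDim;
  if (!loops.empty())
    loopToVectorDim[loops.back()->getInstruction()] = 0;
  return mlir::makePermutationMap(opInst, loopToVectorDim);
}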
+static SetVector getEnclosingforOps(Instruction *inst) { + return getParentsOfType(inst); } -AffineMap -mlir::makePermutationMap(OperationInst *opInst, - const DenseMap &loopToVectorDim) { - DenseMap enclosingLoopToVectorDim; - auto enclosingLoops = getEnclosingforInsts(opInst); +AffineMap mlir::makePermutationMap( + OperationInst *opInst, + const DenseMap &loopToVectorDim) { + DenseMap enclosingLoopToVectorDim; + auto enclosingLoops = getEnclosingforOps(opInst); for (auto *forInst : enclosingLoops) { auto it = loopToVectorDim.find(forInst); if (it != loopToVectorDim.end()) { diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index 474eeb2a28e3..a69831053adb 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -72,7 +72,6 @@ public: bool verify(); bool verifyBlock(const Block &block, bool isTopLevel); bool verifyOperation(const OperationInst &op); - bool verifyForInst(const ForInst &forInst); bool verifyDominance(const Block &block); bool verifyInstDominance(const Instruction &inst); @@ -175,10 +174,6 @@ bool FuncVerifier::verifyBlock(const Block &block, bool isTopLevel) { if (verifyOperation(cast(inst))) return true; break; - case Instruction::Kind::For: - if (verifyForInst(cast(inst))) - return true; - break; } } @@ -240,11 +235,6 @@ bool FuncVerifier::verifyOperation(const OperationInst &op) { return false; } -bool FuncVerifier::verifyForInst(const ForInst &forInst) { - // TODO: check that loop bounds are properly formed. - return verifyBlock(*forInst.getBody(), /*isTopLevel=*/false); -} - bool FuncVerifier::verifyDominance(const Block &block) { for (auto &inst : block) { // Check that all operands on the instruction are ok. @@ -262,10 +252,6 @@ bool FuncVerifier::verifyDominance(const Block &block) { return true; break; } - case Instruction::Kind::For: - if (verifyDominance(*cast(inst).getBody())) - return true; - break; } } return false; diff --git a/mlir/lib/EDSC/MLIREmitter.cpp b/mlir/lib/EDSC/MLIREmitter.cpp index dc85c5ed6824..f4d5d36d25b0 100644 --- a/mlir/lib/EDSC/MLIREmitter.cpp +++ b/mlir/lib/EDSC/MLIREmitter.cpp @@ -21,12 +21,14 @@ #include "llvm/Support/raw_ostream.h" #include "mlir-c/Core.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/EDSC/MLIREmitter.h" #include "mlir/EDSC/Types.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Instructions.h" +#include "mlir/IR/IntegerSet.h" #include "mlir/IR/Location.h" #include "mlir/IR/Value.h" #include "mlir/StandardOps/StandardOps.h" @@ -133,8 +135,8 @@ static void printDefininingStatement(llvm::raw_ostream &os, const Value &v) { inst->print(os); return; } - if (auto *forInst = getForInductionVarOwner(&v)) { - forInst->print(os); + if (auto forInst = getForInductionVarOwner(&v)) { + forInst->getInstruction()->print(os); } else { os << "unknown_ssa_value"; } @@ -300,7 +302,9 @@ Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) { exprs[1]->getDefiningInst()->cast()->getValue(); auto step = exprs[2]->getDefiningInst()->cast()->getValue(); - res = builder->createFor(location, lb, ub, step)->getInductionVar(); + auto forOp = builder->create(location, lb, ub, step); + forOp->createBody(); + res = forOp->getInductionVar(); } } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index cb4c1f0edcee..0fb18fa0004d 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -130,21 +130,8 @@ private: void recordTypeReference(Type ty) { usedTypes.insert(ty); } - // Return 
true if this map could be printed using the custom assembly form. - static bool hasCustomForm(AffineMap boundMap) { - if (boundMap.isSingleConstant()) - return true; - - // Check if the affine map is single dim id or single symbol identity - - // (i)->(i) or ()[s]->(i) - return boundMap.getNumInputs() == 1 && boundMap.getNumResults() == 1 && - (boundMap.getResult(0).isa() || - boundMap.getResult(0).isa()); - } - // Visit functions. void visitInstruction(const Instruction *inst); - void visitForInst(const ForInst *forInst); void visitOperationInst(const OperationInst *opInst); void visitType(Type type); void visitAttribute(Attribute attr); @@ -196,16 +183,6 @@ void ModuleState::visitAttribute(Attribute attr) { } } -void ModuleState::visitForInst(const ForInst *forInst) { - AffineMap lbMap = forInst->getLowerBoundMap(); - if (!hasCustomForm(lbMap)) - recordAffineMapReference(lbMap); - - AffineMap ubMap = forInst->getUpperBoundMap(); - if (!hasCustomForm(ubMap)) - recordAffineMapReference(ubMap); -} - void ModuleState::visitOperationInst(const OperationInst *op) { // Visit all the types used in the operation. for (auto *operand : op->getOperands()) @@ -220,8 +197,6 @@ void ModuleState::visitOperationInst(const OperationInst *op) { void ModuleState::visitInstruction(const Instruction *inst) { switch (inst->getKind()) { - case Instruction::Kind::For: - return visitForInst(cast(inst)); case Instruction::Kind::OperationInst: return visitOperationInst(cast(inst)); } @@ -1069,7 +1044,6 @@ public: // Methods to print instructions. void print(const Instruction *inst); void print(const OperationInst *inst); - void print(const ForInst *inst); void print(const Block *block, bool printBlockArgs = true); void printOperation(const OperationInst *op); @@ -1117,10 +1091,8 @@ public: unsigned index) override; /// Print a block list. - void printBlockList(const BlockList &blocks) override { - printBlockList(blocks, /*printEntryBlockArgs=*/true); - } - void printBlockList(const BlockList &blocks, bool printEntryBlockArgs) { + void printBlockList(const BlockList &blocks, + bool printEntryBlockArgs) override { os << " {\n"; if (!blocks.empty()) { auto *entryBlock = &blocks.front(); @@ -1132,10 +1104,6 @@ public: os.indent(currentIndent) << "}"; } - // Print if and loop bounds. - void printDimAndSymbolList(ArrayRef ops, unsigned numDims); - void printBound(AffineBound bound, const char *prefix); - // Number of spaces used for indenting nested instructions. const static unsigned indentWidth = 2; @@ -1205,10 +1173,6 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) { numberValuesInBlock(block); break; } - case Instruction::Kind::For: - // Recursively number the stuff in the body. 
- numberValuesInBlock(*cast(&inst)->getBody()); - break; } } } @@ -1404,8 +1368,6 @@ void FunctionPrinter::print(const Instruction *inst) { switch (inst->getKind()) { case Instruction::Kind::OperationInst: return print(cast(inst)); - case Instruction::Kind::For: - return print(cast(inst)); } } @@ -1415,24 +1377,6 @@ void FunctionPrinter::print(const OperationInst *inst) { printTrailingLocation(inst->getLoc()); } -void FunctionPrinter::print(const ForInst *inst) { - os.indent(currentIndent) << "for "; - printOperand(inst->getInductionVar()); - os << " = "; - printBound(inst->getLowerBound(), "max"); - os << " to "; - printBound(inst->getUpperBound(), "min"); - - if (inst->getStep() != 1) - os << " step " << inst->getStep(); - - printTrailingLocation(inst->getLoc()); - - os << " {\n"; - print(inst->getBody(), /*printBlockArgs=*/false); - os.indent(currentIndent) << "}"; -} - void FunctionPrinter::printValueID(const Value *value, bool printResultNo) const { int resultNo = -1; @@ -1560,62 +1504,6 @@ void FunctionPrinter::printSuccessorAndUseList(const OperationInst *term, os << ')'; } -void FunctionPrinter::printDimAndSymbolList(ArrayRef ops, - unsigned numDims) { - auto printComma = [&]() { os << ", "; }; - os << '('; - interleave( - ops.begin(), ops.begin() + numDims, - [&](const InstOperand &v) { printOperand(v.get()); }, printComma); - os << ')'; - - if (numDims < ops.size()) { - os << '['; - interleave( - ops.begin() + numDims, ops.end(), - [&](const InstOperand &v) { printOperand(v.get()); }, printComma); - os << ']'; - } -} - -void FunctionPrinter::printBound(AffineBound bound, const char *prefix) { - AffineMap map = bound.getMap(); - - // Check if this bound should be printed using custom assembly form. - // The decision to restrict printing custom assembly form to trivial cases - // comes from the will to roundtrip MLIR binary -> text -> binary in a - // lossless way. - // Therefore, custom assembly form parsing and printing is only supported for - // zero-operand constant maps and single symbol operand identity maps. - if (map.getNumResults() == 1) { - AffineExpr expr = map.getResult(0); - - // Print constant bound. - if (map.getNumDims() == 0 && map.getNumSymbols() == 0) { - if (auto constExpr = expr.dyn_cast()) { - os << constExpr.getValue(); - return; - } - } - - // Print bound that consists of a single SSA symbol if the map is over a - // single symbol. - if (map.getNumDims() == 0 && map.getNumSymbols() == 1) { - if (auto symExpr = expr.dyn_cast()) { - printOperand(bound.getOperand(0)); - return; - } - } - } else { - // Map has multiple results. Print 'min' or 'max' prefix. - os << prefix << ' '; - } - - // Print the map and its operands. - printAffineMapReference(map); - printDimAndSymbolList(bound.getInstOperands(), map.getNumDims()); -} - // Prints function with initialized module state. 
void ModulePrinter::print(const Function *fn) { FunctionPrinter(fn, *this).print(); diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index ffeb4e0317fe..68fbef2d27a7 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -312,19 +312,3 @@ OperationInst *FuncBuilder::createOperation(const OperationState &state) { block->getInstructions().insert(insertPoint, op); return op; } - -ForInst *FuncBuilder::createFor(Location location, ArrayRef lbOperands, - AffineMap lbMap, ArrayRef ubOperands, - AffineMap ubMap, int64_t step) { - auto *inst = - ForInst::create(location, lbOperands, lbMap, ubOperands, ubMap, step); - block->getInstructions().insert(insertPoint, inst); - return inst; -} - -ForInst *FuncBuilder::createFor(Location location, int64_t lb, int64_t ub, - int64_t step) { - auto lbMap = AffineMap::getConstantMap(lb, context); - auto ubMap = AffineMap::getConstantMap(ub, context); - return createFor(location, {}, lbMap, {}, ubMap, step); -} diff --git a/mlir/lib/IR/Instruction.cpp b/mlir/lib/IR/Instruction.cpp index 8d43e3a783d4..03f1a2702c9a 100644 --- a/mlir/lib/IR/Instruction.cpp +++ b/mlir/lib/IR/Instruction.cpp @@ -143,9 +143,6 @@ void Instruction::destroy() { case Kind::OperationInst: cast(this)->destroy(); break; - case Kind::For: - cast(this)->destroy(); - break; } } @@ -209,8 +206,6 @@ unsigned Instruction::getNumOperands() const { switch (getKind()) { case Kind::OperationInst: return cast(this)->getNumOperands(); - case Kind::For: - return cast(this)->getNumOperands(); } } @@ -218,8 +213,6 @@ MutableArrayRef Instruction::getInstOperands() { switch (getKind()) { case Kind::OperationInst: return cast(this)->getInstOperands(); - case Kind::For: - return cast(this)->getInstOperands(); } } @@ -349,10 +342,6 @@ void Instruction::dropAllReferences() { op.drop(); switch (getKind()) { - case Kind::For: - // Make sure to drop references held by instructions within the body. - cast(this)->getBody()->dropAllReferences(); - break; case Kind::OperationInst: { auto *opInst = cast(this); if (isTerminator()) @@ -655,217 +644,6 @@ bool OperationInst::emitOpError(const Twine &message) const { return emitError(Twine('\'') + getName().getStringRef() + "' op " + message); } -//===----------------------------------------------------------------------===// -// ForInst -//===----------------------------------------------------------------------===// - -ForInst *ForInst::create(Location location, ArrayRef lbOperands, - AffineMap lbMap, ArrayRef ubOperands, - AffineMap ubMap, int64_t step) { - assert((!lbMap && lbOperands.empty()) || - lbOperands.size() == lbMap.getNumInputs() && - "lower bound operand count does not match the affine map"); - assert((!ubMap && ubOperands.empty()) || - ubOperands.size() == ubMap.getNumInputs() && - "upper bound operand count does not match the affine map"); - assert(step > 0 && "step has to be a positive integer constant"); - - // Compute the byte size for the instruction and the operand storage. - unsigned numOperands = lbOperands.size() + ubOperands.size(); - auto byteSize = totalSizeToAlloc( - /*detail::OperandStorage*/ 1); - byteSize += llvm::alignTo(detail::OperandStorage::additionalAllocSize( - numOperands, /*resizable=*/true), - alignof(ForInst)); - void *rawMem = malloc(byteSize); - - // Initialize the OperationInst part of the instruction. 
- ForInst *inst = ::new (rawMem) ForInst(location, lbMap, ubMap, step); - new (&inst->getOperandStorage()) - detail::OperandStorage(numOperands, /*resizable=*/true); - - auto operands = inst->getInstOperands(); - unsigned i = 0; - for (unsigned e = lbOperands.size(); i != e; ++i) - new (&operands[i]) InstOperand(inst, lbOperands[i]); - - for (unsigned j = 0, e = ubOperands.size(); j != e; ++i, ++j) - new (&operands[i]) InstOperand(inst, ubOperands[j]); - - return inst; -} - -ForInst::ForInst(Location location, AffineMap lbMap, AffineMap ubMap, - int64_t step) - : Instruction(Instruction::Kind::For, location), body(this), lbMap(lbMap), - ubMap(ubMap), step(step) { - - // The body of a for inst always has one block. - auto *bodyEntry = new Block(); - body.push_back(bodyEntry); - - // Add an argument to the block for the induction variable. - bodyEntry->addArgument(Type::getIndex(lbMap.getResult(0).getContext())); -} - -ForInst::~ForInst() { getOperandStorage().~OperandStorage(); } - -const AffineBound ForInst::getLowerBound() const { - return AffineBound(*this, 0, lbMap.getNumInputs(), lbMap); -} - -const AffineBound ForInst::getUpperBound() const { - return AffineBound(*this, lbMap.getNumInputs(), getNumOperands(), ubMap); -} - -void ForInst::setLowerBound(ArrayRef lbOperands, AffineMap map) { - assert(lbOperands.size() == map.getNumInputs()); - assert(map.getNumResults() >= 1 && "bound map has at least one result"); - - SmallVector newOperands(lbOperands.begin(), lbOperands.end()); - - auto ubOperands = getUpperBoundOperands(); - newOperands.append(ubOperands.begin(), ubOperands.end()); - getOperandStorage().setOperands(this, newOperands); - - this->lbMap = map; -} - -void ForInst::setUpperBound(ArrayRef ubOperands, AffineMap map) { - assert(ubOperands.size() == map.getNumInputs()); - assert(map.getNumResults() >= 1 && "bound map has at least one result"); - - SmallVector newOperands(getLowerBoundOperands()); - newOperands.append(ubOperands.begin(), ubOperands.end()); - getOperandStorage().setOperands(this, newOperands); - - this->ubMap = map; -} - -void ForInst::setLowerBoundMap(AffineMap map) { - assert(lbMap.getNumDims() == map.getNumDims() && - lbMap.getNumSymbols() == map.getNumSymbols()); - assert(map.getNumResults() >= 1 && "bound map has at least one result"); - this->lbMap = map; -} - -void ForInst::setUpperBoundMap(AffineMap map) { - assert(ubMap.getNumDims() == map.getNumDims() && - ubMap.getNumSymbols() == map.getNumSymbols()); - assert(map.getNumResults() >= 1 && "bound map has at least one result"); - this->ubMap = map; -} - -bool ForInst::hasConstantLowerBound() const { return lbMap.isSingleConstant(); } - -bool ForInst::hasConstantUpperBound() const { return ubMap.isSingleConstant(); } - -int64_t ForInst::getConstantLowerBound() const { - return lbMap.getSingleConstantResult(); -} - -int64_t ForInst::getConstantUpperBound() const { - return ubMap.getSingleConstantResult(); -} - -void ForInst::setConstantLowerBound(int64_t value) { - setLowerBound({}, AffineMap::getConstantMap(value, getContext())); -} - -void ForInst::setConstantUpperBound(int64_t value) { - setUpperBound({}, AffineMap::getConstantMap(value, getContext())); -} - -ForInst::operand_range ForInst::getLowerBoundOperands() { - return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()}; -} - -ForInst::const_operand_range ForInst::getLowerBoundOperands() const { - return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()}; -} - -ForInst::operand_range 
ForInst::getUpperBoundOperands() { - return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()}; -} - -ForInst::const_operand_range ForInst::getUpperBoundOperands() const { - return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()}; -} - -bool ForInst::matchingBoundOperandList() const { - if (lbMap.getNumDims() != ubMap.getNumDims() || - lbMap.getNumSymbols() != ubMap.getNumSymbols()) - return false; - - unsigned numOperands = lbMap.getNumInputs(); - for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) { - // Compare Value *'s. - if (getOperand(i) != getOperand(numOperands + i)) - return false; - } - return true; -} - -void ForInst::walkOps(std::function callback) { - struct Walker : public InstWalker { - std::function const &callback; - Walker(std::function const &callback) - : callback(callback) {} - - void visitOperationInst(OperationInst *opInst) { callback(opInst); } - }; - - Walker w(callback); - w.walk(this); -} - -void ForInst::walkOpsPostOrder(std::function callback) { - struct Walker : public InstWalker { - std::function const &callback; - Walker(std::function const &callback) - : callback(callback) {} - - void visitOperationInst(OperationInst *opInst) { callback(opInst); } - }; - - Walker v(callback); - v.walkPostOrder(this); -} - -/// Returns the induction variable for this loop. -Value *ForInst::getInductionVar() { return getBody()->getArgument(0); } - -void ForInst::destroy() { - this->~ForInst(); - free(this); -} - -/// Returns if the provided value is the induction variable of a ForInst. -bool mlir::isForInductionVar(const Value *val) { - return getForInductionVarOwner(val) != nullptr; -} - -/// Returns the loop parent of an induction variable. If the provided value is -/// not an induction variable, then return nullptr. -ForInst *mlir::getForInductionVarOwner(Value *val) { - const BlockArgument *ivArg = dyn_cast(val); - if (!ivArg || !ivArg->getOwner()) - return nullptr; - return dyn_cast_or_null( - ivArg->getOwner()->getParent()->getContainingInst()); -} -const ForInst *mlir::getForInductionVarOwner(const Value *val) { - return getForInductionVarOwner(const_cast(val)); -} - -/// Extracts the induction variables from a list of ForInsts and returns them. -SmallVector -mlir::extractForInductionVars(ArrayRef forInsts) { - SmallVector results; - for (auto *forInst : forInsts) - results.push_back(forInst->getInductionVar()); - return results; -} //===----------------------------------------------------------------------===// // Instruction Cloning //===----------------------------------------------------------------------===// @@ -879,84 +657,59 @@ Instruction *Instruction::clone(BlockAndValueMapping &mapper, MLIRContext *context) const { SmallVector operands; SmallVector successors; - if (auto *opInst = dyn_cast(this)) { - operands.reserve(getNumOperands() + opInst->getNumSuccessors()); - if (!opInst->isTerminator()) { - // Non-terminators just add all the operands. - for (auto *opValue : getOperands()) + auto *opInst = cast(this); + operands.reserve(getNumOperands() + opInst->getNumSuccessors()); + + if (!opInst->isTerminator()) { + // Non-terminators just add all the operands. + for (auto *opValue : getOperands()) + operands.push_back(mapper.lookupOrDefault(const_cast(opValue))); + } else { + // We add the operands separated by nullptr's for each successor. + unsigned firstSuccOperand = opInst->getNumSuccessors() + ? 
opInst->getSuccessorOperandIndex(0) + : opInst->getNumOperands(); + auto InstOperands = opInst->getInstOperands(); + + unsigned i = 0; + for (; i != firstSuccOperand; ++i) + operands.push_back( + mapper.lookupOrDefault(const_cast(InstOperands[i].get()))); + + successors.reserve(opInst->getNumSuccessors()); + for (unsigned succ = 0, e = opInst->getNumSuccessors(); succ != e; ++succ) { + successors.push_back(mapper.lookupOrDefault( + const_cast(opInst->getSuccessor(succ)))); + + // Add sentinel to delineate successor operands. + operands.push_back(nullptr); + + // Remap the successors operands. + for (auto *operand : opInst->getSuccessorOperands(succ)) operands.push_back( - mapper.lookupOrDefault(const_cast(opValue))); - } else { - // We add the operands separated by nullptr's for each successor. - unsigned firstSuccOperand = opInst->getNumSuccessors() - ? opInst->getSuccessorOperandIndex(0) - : opInst->getNumOperands(); - auto InstOperands = opInst->getInstOperands(); - - unsigned i = 0; - for (; i != firstSuccOperand; ++i) - operands.push_back( - mapper.lookupOrDefault(const_cast(InstOperands[i].get()))); - - successors.reserve(opInst->getNumSuccessors()); - for (unsigned succ = 0, e = opInst->getNumSuccessors(); succ != e; - ++succ) { - successors.push_back(mapper.lookupOrDefault( - const_cast(opInst->getSuccessor(succ)))); - - // Add sentinel to delineate successor operands. - operands.push_back(nullptr); - - // Remap the successors operands. - for (auto *operand : opInst->getSuccessorOperands(succ)) - operands.push_back( - mapper.lookupOrDefault(const_cast(operand))); - } + mapper.lookupOrDefault(const_cast(operand))); } - - SmallVector resultTypes; - resultTypes.reserve(opInst->getNumResults()); - for (auto *result : opInst->getResults()) - resultTypes.push_back(result->getType()); - - unsigned numBlockLists = opInst->getNumBlockLists(); - auto *newOp = OperationInst::create( - getLoc(), opInst->getName(), operands, resultTypes, opInst->getAttrs(), - successors, numBlockLists, opInst->hasResizableOperandsList(), context); - - // Clone the block lists. - for (unsigned i = 0; i != numBlockLists; ++i) - opInst->getBlockList(i).cloneInto(&newOp->getBlockList(i), mapper, - context); - - // Remember the mapping of any results. - for (unsigned i = 0, e = opInst->getNumResults(); i != e; ++i) - mapper.map(opInst->getResult(i), newOp->getResult(i)); - return newOp; } - operands.reserve(getNumOperands()); - for (auto *opValue : getOperands()) - operands.push_back(mapper.lookupOrDefault(const_cast(opValue))); + SmallVector resultTypes; + resultTypes.reserve(opInst->getNumResults()); + for (auto *result : opInst->getResults()) + resultTypes.push_back(result->getType()); - // Otherwise, this must be a ForInst. - auto *forInst = cast(this); - auto lbMap = forInst->getLowerBoundMap(); - auto ubMap = forInst->getUpperBoundMap(); + unsigned numBlockLists = opInst->getNumBlockLists(); + auto *newOp = OperationInst::create( + getLoc(), opInst->getName(), operands, resultTypes, opInst->getAttrs(), + successors, numBlockLists, opInst->hasResizableOperandsList(), context); - auto *newFor = ForInst::create( - getLoc(), ArrayRef(operands).take_front(lbMap.getNumInputs()), - lbMap, ArrayRef(operands).take_back(ubMap.getNumInputs()), ubMap, - forInst->getStep()); + // Clone the block lists. + for (unsigned i = 0; i != numBlockLists; ++i) + opInst->getBlockList(i).cloneInto(&newOp->getBlockList(i), mapper, context); - // Remember the induction variable mapping. 
- mapper.map(forInst->getInductionVar(), newFor->getInductionVar()); - - // Recursively clone the body of the for loop. - for (auto &subInst : *forInst->getBody()) - newFor->getBody()->push_back(subInst.clone(mapper, context)); - return newFor; + // Remember the mapping of any results. + for (unsigned i = 0, e = opInst->getNumResults(); i != e; ++i) + mapper.map(opInst->getResult(i), newOp->getResult(i)); + return newOp; } Instruction *Instruction::clone(MLIRContext *context) const { diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp index 7103eeb7389e..a9c046dc7b1f 100644 --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -64,8 +64,6 @@ MLIRContext *IROperandOwner::getContext() const { switch (getKind()) { case Kind::OperationInst: return cast(this)->getContext(); - case Kind::ForInst: - return cast(this)->getContext(); } } diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index f0c140166ed5..a9c62767734b 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -2128,23 +2128,6 @@ public: parseSuccessors(SmallVectorImpl &destinations, SmallVectorImpl> &operands); - ParseResult - parseOptionalBlockArgList(SmallVectorImpl &results, - Block *owner); - - ParseResult parseOperationBlockList(SmallVectorImpl &results); - ParseResult parseBlockListBody(SmallVectorImpl &results); - ParseResult parseBlock(Block *&block); - ParseResult parseBlockBody(Block *block); - - /// Cleans up the memory for allocated blocks when a parser error occurs. - void cleanupInvalidBlocks(ArrayRef invalidBlocks) { - // Add the referenced blocks to the function so that they can be properly - // cleaned up when the function is destroyed. - for (auto *block : invalidBlocks) - function->push_back(block); - } - /// After the function is finished parsing, this function checks to see if /// there are any remaining issues. ParseResult finalizeFunction(SMLoc loc); @@ -2187,6 +2170,25 @@ public: // Block references. + ParseResult + parseOperationBlockList(SmallVectorImpl &results, + ArrayRef> entryArguments); + ParseResult parseBlockListBody(SmallVectorImpl &results); + ParseResult parseBlock(Block *&block); + ParseResult parseBlockBody(Block *block); + + ParseResult + parseOptionalBlockArgList(SmallVectorImpl &results, + Block *owner); + + /// Cleans up the memory for allocated blocks when a parser error occurs. + void cleanupInvalidBlocks(ArrayRef invalidBlocks) { + // Add the referenced blocks to the function so that they can be properly + // cleaned up when the function is destroyed. + for (auto *block : invalidBlocks) + function->push_back(block); + } + /// Get the block with the specified name, creating it if it doesn't /// already exist. The location specified is the point of use, which allows /// us to diagnose references to blocks that are not defined precisely. 
@@ -2201,13 +2203,6 @@ public: OperationInst *parseGenericOperation(); OperationInst *parseCustomOperation(); - ParseResult parseForInst(); - ParseResult parseIntConstant(int64_t &val); - ParseResult parseDimAndSymbolList(SmallVectorImpl &operands, - unsigned numDims, unsigned numOperands, - const char *affineStructName); - ParseResult parseBound(SmallVectorImpl &operands, AffineMap &map, - bool isLower); ParseResult parseInstructions(Block *block); private: @@ -2287,25 +2282,43 @@ ParseResult FunctionParser::parseFunctionBody(bool hadNamedArguments) { /// /// block-list ::= '{' block-list-body /// -ParseResult -FunctionParser::parseOperationBlockList(SmallVectorImpl &results) { +ParseResult FunctionParser::parseOperationBlockList( + SmallVectorImpl &results, + ArrayRef> entryArguments) { // Parse the '{'. if (parseToken(Token::l_brace, "expected '{' to begin block list")) return ParseFailure; + // Check for an empty block list. - if (consumeIf(Token::r_brace)) + if (entryArguments.empty() && consumeIf(Token::r_brace)) return ParseSuccess; Block *currentBlock = builder.getInsertionBlock(); // Parse the first block directly to allow for it to be unnamed. Block *block = new Block(); + + // Add arguments to the entry block. + for (auto &placeholderArgPair : entryArguments) + if (addDefinition(placeholderArgPair.first, + block->addArgument(placeholderArgPair.second))) { + delete block; + return ParseFailure; + } + if (parseBlock(block)) { - cleanupInvalidBlocks(block); + delete block; return ParseFailure; } - results.push_back(block); + + // Verify that no other arguments were parsed. + if (!entryArguments.empty() && + block->getNumArguments() > entryArguments.size()) { + delete block; + return emitError("entry block arguments were already defined"); + } // Parse the rest of the block list. + results.push_back(block); if (parseBlockListBody(results)) return ParseFailure; @@ -2385,10 +2398,6 @@ ParseResult FunctionParser::parseBlockBody(Block *block) { if (parseOperation()) return ParseFailure; break; - case Token::kw_for: - if (parseForInst()) - return ParseFailure; - break; } } @@ -2859,7 +2868,7 @@ OperationInst *FunctionParser::parseGenericOperation() { std::vector> blocks; while (getToken().is(Token::l_brace)) { SmallVector newBlocks; - if (parseOperationBlockList(newBlocks)) { + if (parseOperationBlockList(newBlocks, /*entryArguments=*/llvm::None)) { for (auto &blockList : blocks) cleanupInvalidBlocks(blockList); return nullptr; @@ -2884,6 +2893,27 @@ public: CustomOpAsmParser(SMLoc nameLoc, StringRef opName, FunctionParser &parser) : nameLoc(nameLoc), opName(opName), parser(parser) {} + bool parseOperation(const AbstractOperation *opDefinition, + OperationState *opState) { + if (opDefinition->parseAssembly(this, opState)) + return true; + + // Check that enough block lists were reserved for those that were parsed. + if (parsedBlockLists.size() > opState->numBlockLists) { + return emitError( + nameLoc, + "parsed more block lists than those reserved in the operation state"); + } + + // Check there were no dangling entry block arguments. + if (!parsedBlockListEntryArguments.empty()) { + return emitError( + nameLoc, + "no block list was attached to parsed entry block arguments"); + } + return false; + } + //===--------------------------------------------------------------------===// // High level parsing methods. 
//===--------------------------------------------------------------------===// @@ -2895,6 +2925,9 @@ public: bool parseComma() override { return parser.parseToken(Token::comma, "expected ','"); } + bool parseEqual() override { + return parser.parseToken(Token::equal, "expected '='"); + } bool parseType(Type &result) override { return !(result = parser.parseType()); @@ -3083,13 +3116,35 @@ public: /// Parses a list of blocks. bool parseBlockList() override { + // Parse the block list. SmallVector results; - if (parser.parseOperationBlockList(results)) + if (parser.parseOperationBlockList(results, parsedBlockListEntryArguments)) return true; + + parsedBlockListEntryArguments.clear(); parsedBlockLists.emplace_back(results); return false; } + /// Parses an argument for the entry block of the next block list to be + /// parsed. + bool parseBlockListEntryBlockArgument(Type argType) override { + SmallVector argValues; + OperandType operand; + if (parseOperand(operand)) + return true; + + // Create a place holder for this argument. + FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number, + operand.location}; + if (auto *value = parser.resolveSSAUse(operandInfo, argType)) { + parsedBlockListEntryArguments.emplace_back(operandInfo, argType); + return false; + } + + return true; + } + //===--------------------------------------------------------------------===// // Methods for interacting with the parser //===--------------------------------------------------------------------===// @@ -3130,6 +3185,8 @@ public: private: std::vector> parsedBlockLists; + SmallVector, 2> + parsedBlockListEntryArguments; SMLoc nameLoc; StringRef opName; FunctionParser &parser; @@ -3161,26 +3218,18 @@ OperationInst *FunctionParser::parseCustomOperation() { // Have the op implementation take a crack and parsing this. OperationState opState(builder.getContext(), srcLocation, opName); - if (opDefinition->parseAssembly(&opAsmParser, &opState)) + if (opAsmParser.parseOperation(opDefinition, &opState)) return nullptr; // If it emitted an error, we failed. if (opAsmParser.didEmitError()) return nullptr; - // Check that enough block lists were reserved for those that were parsed. - auto parsedBlockLists = opAsmParser.getParsedBlockLists(); - if (parsedBlockLists.size() > opState.numBlockLists) { - opAsmParser.emitError( - opLoc, - "parsed more block lists than those reserved in the operation state"); - return nullptr; - } - // Otherwise, we succeeded. Use the state it parsed as our op information. auto *opInst = builder.createOperation(opState); // Resolve any parsed block lists. + auto parsedBlockLists = opAsmParser.getParsedBlockLists(); for (unsigned i = 0, e = parsedBlockLists.size(); i != e; ++i) { auto &opBlockList = opInst->getBlockList(i).getBlocks(); opBlockList.insert(opBlockList.end(), parsedBlockLists[i].begin(), @@ -3189,213 +3238,6 @@ OperationInst *FunctionParser::parseCustomOperation() { return opInst; } -/// For instruction. -/// -/// ml-for-inst ::= `for` ssa-id `=` lower-bound `to` upper-bound -/// (`step` integer-literal)? trailing-location? `{` inst* `}` -/// -ParseResult FunctionParser::parseForInst() { - consumeToken(Token::kw_for); - - // Parse induction variable. 
- if (getToken().isNot(Token::percent_identifier)) - return emitError("expected SSA identifier for the loop variable"); - - auto loc = getToken().getLoc(); - StringRef inductionVariableName = getTokenSpelling(); - consumeToken(Token::percent_identifier); - - if (parseToken(Token::equal, "expected '='")) - return ParseFailure; - - // Parse lower bound. - SmallVector lbOperands; - AffineMap lbMap; - if (parseBound(lbOperands, lbMap, /*isLower*/ true)) - return ParseFailure; - - if (parseToken(Token::kw_to, "expected 'to' between bounds")) - return ParseFailure; - - // Parse upper bound. - SmallVector ubOperands; - AffineMap ubMap; - if (parseBound(ubOperands, ubMap, /*isLower*/ false)) - return ParseFailure; - - // Parse step. - int64_t step = 1; - if (consumeIf(Token::kw_step) && parseIntConstant(step)) - return ParseFailure; - - // The loop step is a positive integer constant. Since index is stored as an - // int64_t type, we restrict step to be in the set of positive integers that - // int64_t can represent. - if (step < 1) { - return emitError("step has to be a positive integer"); - } - - // Create for instruction. - ForInst *forInst = - builder.createFor(getEncodedSourceLocation(loc), lbOperands, lbMap, - ubOperands, ubMap, step); - - // Create SSA value definition for the induction variable. - if (addDefinition({inductionVariableName, 0, loc}, - forInst->getInductionVar())) - return ParseFailure; - - // Try to parse the optional trailing location. - if (parseOptionalTrailingLocation(forInst)) - return ParseFailure; - - // If parsing of the for instruction body fails, - // MLIR contains for instruction with those nested instructions that have been - // successfully parsed. - auto *forBody = forInst->getBody(); - if (parseToken(Token::l_brace, "expected '{' before instruction list") || - parseBlock(forBody) || - parseToken(Token::r_brace, "expected '}' after instruction list")) - return ParseFailure; - - // Reset insertion point to the current block. - builder.setInsertionPointToEnd(forInst->getBlock()); - - return ParseSuccess; -} - -/// Parse integer constant as affine constant expression. -ParseResult FunctionParser::parseIntConstant(int64_t &val) { - bool negate = consumeIf(Token::minus); - - if (getToken().isNot(Token::integer)) - return emitError("expected integer"); - - auto uval = getToken().getUInt64IntegerValue(); - - if (!uval.hasValue() || (int64_t)uval.getValue() < 0) { - return emitError("bound or step is too large for index"); - } - - val = (int64_t)uval.getValue(); - if (negate) - val = -val; - consumeToken(); - - return ParseSuccess; -} - -/// Dimensions and symbol use list. -/// -/// dim-use-list ::= `(` ssa-use-list? `)` -/// symbol-use-list ::= `[` ssa-use-list? `]` -/// dim-and-symbol-use-list ::= dim-use-list symbol-use-list? 
-/// -ParseResult -FunctionParser::parseDimAndSymbolList(SmallVectorImpl &operands, - unsigned numDims, unsigned numOperands, - const char *affineStructName) { - if (parseToken(Token::l_paren, "expected '('")) - return ParseFailure; - - SmallVector opInfo; - parseOptionalSSAUseList(opInfo); - - if (parseToken(Token::r_paren, "expected ')'")) - return ParseFailure; - - if (numDims != opInfo.size()) - return emitError("dim operand count and " + Twine(affineStructName) + - " dim count must match"); - - if (consumeIf(Token::l_square)) { - parseOptionalSSAUseList(opInfo); - if (parseToken(Token::r_square, "expected ']'")) - return ParseFailure; - } - - if (numOperands != opInfo.size()) - return emitError("symbol operand count and " + Twine(affineStructName) + - " symbol count must match"); - - // Resolve SSA uses. - Type indexType = builder.getIndexType(); - for (unsigned i = 0, e = opInfo.size(); i != e; ++i) { - Value *sval = resolveSSAUse(opInfo[i], indexType); - if (!sval) - return ParseFailure; - - if (i < numDims && !sval->isValidDim()) - return emitError(opInfo[i].loc, "value '" + opInfo[i].name.str() + - "' cannot be used as a dimension id"); - if (i >= numDims && !sval->isValidSymbol()) - return emitError(opInfo[i].loc, "value '" + opInfo[i].name.str() + - "' cannot be used as a symbol"); - operands.push_back(sval); - } - - return ParseSuccess; -} - -// Loop bound. -/// -/// lower-bound ::= `max`? affine-map dim-and-symbol-use-list | -/// shorthand-bound upper-bound ::= `min`? affine-map dim-and-symbol-use-list -/// | shorthand-bound shorthand-bound ::= ssa-id | `-`? integer-literal -/// -ParseResult FunctionParser::parseBound(SmallVectorImpl &operands, - AffineMap &map, bool isLower) { - // 'min' / 'max' prefixes are syntactic sugar. Ignore them. - if (isLower) - consumeIf(Token::kw_max); - else - consumeIf(Token::kw_min); - - // Parse full form - affine map followed by dim and symbol list. - if (getToken().isAny(Token::hash_identifier, Token::l_paren)) { - map = parseAffineMapReference(); - if (!map) - return ParseFailure; - - if (parseDimAndSymbolList(operands, map.getNumDims(), map.getNumInputs(), - "affine map")) - return ParseFailure; - return ParseSuccess; - } - - // Parse custom assembly form. - if (getToken().isAny(Token::minus, Token::integer)) { - int64_t val; - if (!parseIntConstant(val)) { - map = builder.getConstantAffineMap(val); - return ParseSuccess; - } - return ParseFailure; - } - - // Parse ssa-id as identity map. - SSAUseInfo opInfo; - if (parseSSAUse(opInfo)) - return ParseFailure; - - // TODO: improve error message when SSA value is not an affine integer. - // Currently it is 'use of value ... expects different type than prior uses' - if (auto *value = resolveSSAUse(opInfo, builder.getIndexType())) - operands.push_back(value); - else - return ParseFailure; - - // Create an identity map using dim id for an induction variable and - // symbol otherwise. This representation is optimized for storage. - // Analysis passes may expand it into a multi-dimensional map if desired. - if (isForInductionVar(operands[0])) - map = builder.getDimIdentityMap(); - else - map = builder.getSymbolIdentityMap(); - - return ParseSuccess; -} - /// Parse an affine constraint. 
/// affine-constraint ::= affine-expr `>=` `0` /// | affine-expr `==` `0` diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index afd18a49b793..e471b6792c59 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -183,11 +183,6 @@ void CSE::simplifyBlock(Block *bb) { } break; } - case Instruction::Kind::For: { - ScopedMapTy::ScopeTy scope(knownValues); - simplifyBlock(cast(i).getBody()); - break; - } } } } diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index f9d02f7a47aa..9c20e79180a6 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -15,6 +15,7 @@ // limitations under the License. // ============================================================================= +#include "mlir/AffineOps/AffineOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" #include "mlir/IR/InstVisitor.h" @@ -37,7 +38,6 @@ struct ConstantFold : public FunctionPass, InstWalker { bool foldOperation(OperationInst *op, SmallVectorImpl &existingConstants); void visitOperationInst(OperationInst *inst); - void visitForInst(ForInst *inst); PassResult runOnFunction(Function *f) override; static char passID; @@ -50,6 +50,12 @@ char ConstantFold::passID = 0; /// constants are found, we keep track of them in the existingConstants list. /// void ConstantFold::visitOperationInst(OperationInst *op) { + // If this operation is an AffineForOp, then fold the bounds. + if (auto forOp = op->dyn_cast()) { + constantFoldBounds(forOp); + return; + } + // If this operation is already a constant, just remember it for cleanup // later, and don't try to fold it. if (auto constant = op->dyn_cast()) { @@ -98,11 +104,6 @@ void ConstantFold::visitOperationInst(OperationInst *op) { opInstsToErase.push_back(op); } -// Override the walker's 'for' instruction visit for constant folding. -void ConstantFold::visitForInst(ForInst *forInst) { - constantFoldBounds(forInst); -} - // For now, we do a simple top-down pass over a function folding constants. We // don't handle conditional control flow, block arguments, folding // conditional branches, or anything else fancy. diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 5c3a66208ecb..83ec726ec2a3 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -21,6 +21,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" @@ -71,9 +72,9 @@ struct DmaGeneration : public FunctionPass { } PassResult runOnFunction(Function *f) override; - void runOnForInst(ForInst *forInst); + void runOnAffineForOp(OpPointer forOp); - bool generateDma(const MemRefRegion ®ion, ForInst *forInst, + bool generateDma(const MemRefRegion ®ion, OpPointer forOp, uint64_t *sizeInBytes); // List of memory regions to DMA for. We need a map vector to have a @@ -174,7 +175,7 @@ static bool getFullMemRefAsRegion(OperationInst *opInst, // Just get the first numSymbols IVs, which the memref region is parametric // on. 
- SmallVector ivs; + SmallVector, 4> ivs; getLoopIVs(*opInst, &ivs); ivs.resize(numParamLoopIVs); SmallVector symbols = extractForInductionVars(ivs); @@ -195,8 +196,10 @@ static bool getFullMemRefAsRegion(OperationInst *opInst, // generates a DMA from the lower memory space to this one, and replaces all // loads to load from that buffer. Returns false if DMAs could not be generated // due to yet unimplemented cases. -bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForInst *forInst, +bool DmaGeneration::generateDma(const MemRefRegion ®ion, + OpPointer forOp, uint64_t *sizeInBytes) { + auto *forInst = forOp->getInstruction(); // DMAs for read regions are going to be inserted just before the for loop. FuncBuilder prologue(forInst); @@ -386,39 +389,43 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForInst *forInst, remapExprs.push_back(dimExpr - offsets[i]); } auto indexRemap = b->getAffineMap(outerIVs.size() + rank, 0, remapExprs, {}); - // *Only* those uses within the body of 'forInst' are replaced. + // *Only* those uses within the body of 'forOp' are replaced. replaceAllMemRefUsesWith(memref, fastMemRef, /*extraIndices=*/{}, indexRemap, /*extraOperands=*/outerIVs, - /*domInstFilter=*/&*forInst->getBody()->begin()); + /*domInstFilter=*/&*forOp->getBody()->begin()); return true; } // TODO(bondhugula): make this run on a Block instead of a 'for' inst. -void DmaGeneration::runOnForInst(ForInst *forInst) { +void DmaGeneration::runOnAffineForOp(OpPointer forOp) { // For now (for testing purposes), we'll run this on the outermost among 'for' // inst's with unit stride, i.e., right at the top of the tile if tiling has // been done. In the future, the DMA generation has to be done at a level // where the generated data fits in a higher level of the memory hierarchy; so // the pass has to be instantiated with additional information that we aren't // provided with at the moment. - if (forInst->getStep() != 1) { - if (auto *innerFor = dyn_cast(&*forInst->getBody()->begin())) { - runOnForInst(innerFor); + if (forOp->getStep() != 1) { + auto *forBody = forOp->getBody(); + if (forBody->empty()) + return; + if (auto innerFor = + cast(forBody->front()).dyn_cast()) { + runOnAffineForOp(innerFor); } return; } // DMAs will be generated for this depth, i.e., for all data accessed by this // loop. - unsigned dmaDepth = getNestingDepth(*forInst); + unsigned dmaDepth = getNestingDepth(*forOp->getInstruction()); readRegions.clear(); writeRegions.clear(); fastBufferMap.clear(); // Walk this 'for' instruction to gather all memory regions. - forInst->walkOps([&](OperationInst *opInst) { + forOp->walkOps([&](OperationInst *opInst) { // Gather regions to promote to buffers in faster memory space. // TODO(bondhugula): handle store op's; only load's handled for now. if (auto loadOp = opInst->dyn_cast()) { @@ -443,7 +450,7 @@ void DmaGeneration::runOnForInst(ForInst *forInst) { LLVM_DEBUG(llvm::dbgs() << "over-approximating to the entire memref\n"); if (!getFullMemRefAsRegion(opInst, dmaDepth, region.get())) { LLVM_DEBUG( - forInst->emitError("Non-constant memref sizes not yet supported")); + forOp->emitError("Non-constant memref sizes not yet supported")); return; } } @@ -472,10 +479,10 @@ void DmaGeneration::runOnForInst(ForInst *forInst) { // Perform a union with the existing region. 
if (!(*it).second->unionBoundingBox(*region)) { LLVM_DEBUG(llvm::dbgs() - << "Memory region bounding box failed; " + << "Memory region bounding box failed" "over-approximating to the entire memref\n"); if (!getFullMemRefAsRegion(opInst, dmaDepth, region.get())) { - LLVM_DEBUG(forInst->emitError( + LLVM_DEBUG(forOp->emitError( "Non-constant memref sizes not yet supported")); } } @@ -501,7 +508,7 @@ void DmaGeneration::runOnForInst(ForInst *forInst) { ®ions) { for (const auto ®ionEntry : regions) { uint64_t sizeInBytes; - bool iRet = generateDma(*regionEntry.second, forInst, &sizeInBytes); + bool iRet = generateDma(*regionEntry.second, forOp, &sizeInBytes); if (iRet) totalSizeInBytes += sizeInBytes; ret = ret & iRet; @@ -510,7 +517,7 @@ void DmaGeneration::runOnForInst(ForInst *forInst) { processRegions(readRegions); processRegions(writeRegions); if (!ret) { - forInst->emitError("DMA generation failed for one or more memref's\n"); + forOp->emitError("DMA generation failed for one or more memref's\n"); return; } LLVM_DEBUG(llvm::dbgs() << Twine(llvm::divideCeil(totalSizeInBytes, 1024)) @@ -519,7 +526,7 @@ void DmaGeneration::runOnForInst(ForInst *forInst) { if (clFastMemoryCapacity && totalSizeInBytes > clFastMemoryCapacity) { // TODO(bondhugula): selecting the DMA depth so that the result DMA buffers // fit in fast memory is a TODO - not complex. - forInst->emitError( + forOp->emitError( "Total size of all DMA buffers' exceeds memory capacity\n"); } } @@ -531,8 +538,8 @@ PassResult DmaGeneration::runOnFunction(Function *f) { for (auto &block : *f) { for (auto &inst : block) { - if (auto *forInst = dyn_cast(&inst)) { - runOnForInst(forInst); + if (auto forOp = cast(inst).dyn_cast()) { + runOnAffineForOp(forOp); } } } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index fa0e3b51de32..7d4ff03e3069 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -97,15 +97,15 @@ namespace { // operations, and whether or not an IfInst was encountered in the loop nest. class LoopNestStateCollector : public InstWalker { public: - SmallVector forInsts; + SmallVector, 4> forOps; SmallVector loadOpInsts; SmallVector storeOpInsts; bool hasNonForRegion = false; - void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); } - void visitOperationInst(OperationInst *opInst) { - if (opInst->getNumBlockLists() != 0) + if (opInst->isa()) + forOps.push_back(opInst->cast()); + else if (opInst->getNumBlockLists() != 0) hasNonForRegion = true; else if (opInst->isa()) loadOpInsts.push_back(opInst); @@ -491,14 +491,14 @@ bool MemRefDependenceGraph::init(Function *f) { if (f->getBlocks().size() != 1) return false; - DenseMap forToNodeMap; + DenseMap forToNodeMap; for (auto &inst : f->front()) { - if (auto *forInst = dyn_cast(&inst)) { - // Create graph node 'id' to represent top-level 'forInst' and record + if (auto forOp = cast(&inst)->dyn_cast()) { + // Create graph node 'id' to represent top-level 'forOp' and record // all loads and store accesses it contains. LoopNestStateCollector collector; - collector.walkForInst(forInst); - // Return false if IfInsts are found (not currently supported). + collector.walk(&inst); + // Return false if a non 'for' region was found (not currently supported). 
if (collector.hasNonForRegion) return false; Node node(nextNodeId++, &inst); @@ -512,10 +512,9 @@ bool MemRefDependenceGraph::init(Function *f) { auto *memref = opInst->cast()->getMemRef(); memrefAccesses[memref].insert(node.id); } - forToNodeMap[forInst] = node.id; + forToNodeMap[&inst] = node.id; nodes.insert({node.id, node}); - } - if (auto *opInst = dyn_cast(&inst)) { + } else if (auto *opInst = dyn_cast(&inst)) { if (auto loadOp = opInst->dyn_cast()) { // Create graph node for top-level load op. Node node(nextNodeId++, &inst); @@ -552,12 +551,12 @@ bool MemRefDependenceGraph::init(Function *f) { for (auto *value : opInst->getResults()) { for (auto &use : value->getUses()) { auto *userOpInst = cast(use.getOwner()); - SmallVector loops; + SmallVector, 4> loops; getLoopIVs(*userOpInst, &loops); if (loops.empty()) continue; - assert(forToNodeMap.count(loops[0]) > 0); - unsigned userLoopNestId = forToNodeMap[loops[0]]; + assert(forToNodeMap.count(loops[0]->getInstruction()) > 0); + unsigned userLoopNestId = forToNodeMap[loops[0]->getInstruction()]; addEdge(node.id, userLoopNestId, value); } } @@ -587,12 +586,12 @@ namespace { // LoopNestStats aggregates various per-loop statistics (eg. loop trip count // and operation count) for a loop nest up until the innermost loop body. struct LoopNestStats { - // Map from ForInst to immediate child ForInsts in its loop body. - DenseMap> loopMap; - // Map from ForInst to count of operations in its loop body. - DenseMap opCountMap; - // Map from ForInst to its constant trip count. - DenseMap tripCountMap; + // Map from AffineForOp to immediate child AffineForOps in its loop body. + DenseMap, 2>> loopMap; + // Map from AffineForOp to count of operations in its loop body. + DenseMap opCountMap; + // Map from AffineForOp to its constant trip count. + DenseMap tripCountMap; }; // LoopNestStatsCollector walks a single loop nest and gathers per-loop @@ -604,23 +603,31 @@ public: LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {} - void visitForInst(ForInst *forInst) { - auto *parentInst = forInst->getParentInst(); + void visitOperationInst(OperationInst *opInst) { + auto forOp = opInst->dyn_cast(); + if (!forOp) + return; + + auto *forInst = forOp->getInstruction(); + auto *parentInst = forOp->getInstruction()->getParentInst(); if (parentInst != nullptr) { - assert(isa(parentInst) && "Expected parent ForInst"); - // Add mapping to 'forInst' from its parent ForInst. - stats->loopMap[cast(parentInst)].push_back(forInst); + assert(cast(parentInst)->isa() && + "Expected parent AffineForOp"); + // Add mapping to 'forOp' from its parent AffineForOp. + stats->loopMap[parentInst].push_back(forOp); } - // Record the number of op instructions in the body of 'forInst'. + + // Record the number of op instructions in the body of 'forOp'. unsigned count = 0; stats->opCountMap[forInst] = 0; - for (auto &inst : *forInst->getBody()) { - if (isa(&inst)) + for (auto &inst : *forOp->getBody()) { + if (!(cast(inst).isa() || + cast(inst).isa())) ++count; } stats->opCountMap[forInst] = count; - // Record trip count for 'forInst'. Set flag if trip count is not constant. - Optional maybeConstTripCount = getConstantTripCount(*forInst); + // Record trip count for 'forOp'. Set flag if trip count is not constant. + Optional maybeConstTripCount = getConstantTripCount(forOp); if (!maybeConstTripCount.hasValue()) { hasLoopWithNonConstTripCount = true; return; @@ -629,7 +636,7 @@ public: } }; -// Computes the total cost of the loop nest rooted at 'forInst'. 
+// Computes the total cost of the loop nest rooted at 'forOp'.
 // Currently, the total cost is computed by counting the total operation
 // instance count (i.e. the number of operations in the loop body * the loop
 // trip count) for the entire loop nest.
@@ -637,7 +644,7 @@ public:
 // specified in the map when computing the total op instance count.
 // NOTE: this is used to compute the cost of computation slices, which are
 // sliced along the iteration dimension, and thus reduce the trip count.
-// If 'computeCostMap' is non-null, the total op count for forInsts specified
+// If 'computeCostMap' is non-null, the total op count for forOps specified
 // in the map is increased (not overridden) by adding the op count from the
 // map to the existing op count for the for loop. This is done before
 // multiplying by the loop's trip count, and is used to model the cost of
@@ -645,15 +652,15 @@ public:
 // NOTE: this is used to compute the cost of fusing a slice of some loop nest
 // within another loop.
 static int64_t getComputeCost(
-    ForInst *forInst, LoopNestStats *stats,
-    llvm::SmallDenseMap *tripCountOverrideMap,
-    DenseMap *computeCostMap) {
-  // 'opCount' is the total number operations in one iteration of 'forInst' body
+    Instruction *forInst, LoopNestStats *stats,
+    llvm::SmallDenseMap *tripCountOverrideMap,
+    DenseMap *computeCostMap) {
+  // 'opCount' is the total number of operations in one iteration of 'forOp' body
   int64_t opCount = stats->opCountMap[forInst];
   if (stats->loopMap.count(forInst) > 0) {
-    for (auto *childForInst : stats->loopMap[forInst]) {
-      opCount += getComputeCost(childForInst, stats, tripCountOverrideMap,
-                                computeCostMap);
+    for (auto childForOp : stats->loopMap[forInst]) {
+      opCount += getComputeCost(childForOp->getInstruction(), stats,
+                                tripCountOverrideMap, computeCostMap);
     }
   }
   // Add in additional op instances from slice (if specified in map).
@@ -694,18 +701,18 @@ static Optional getConstDifference(AffineMap lbMap, AffineMap ubMap) {
   return cExpr.getValue();
 }
-// Builds a map 'tripCountMap' from ForInst to constant trip count for loop
+// Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop
 // nest surrounding 'srcAccess' utilizing slice loop bounds in 'sliceState'.
 // Returns true on success, false otherwise (if a non-constant trip count
 // was encountered).
 // TODO(andydavis) Make this work with non-unit step loops.
 static bool buildSliceTripCountMap(
     OperationInst *srcOpInst, ComputationSliceState *sliceState,
-    llvm::SmallDenseMap *tripCountMap) {
-  SmallVector srcLoopIVs;
+    llvm::SmallDenseMap *tripCountMap) {
+  SmallVector, 4> srcLoopIVs;
   getLoopIVs(*srcOpInst, &srcLoopIVs);
   unsigned numSrcLoopIVs = srcLoopIVs.size();
-  // Populate map from ForInst -> trip count
+  // Populate map from AffineForOp -> trip count
   for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
     AffineMap lbMap = sliceState->lbs[i];
     AffineMap ubMap = sliceState->ubs[i];
       // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
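The cost model in getComputeCost above can be read as a simple recursion. The following is a self-contained sketch in plain C++, independent of the MLIR types and ignoring the trip-count-override and compute-cost maps:

```c++
#include <cstdint>
#include <vector>

// Stand-alone restatement of the recursion: the cost of a loop is the number
// of operations in one iteration of its body (including the full cost of any
// nested loops) multiplied by its trip count.
struct LoopModel {
  int64_t ownOpCount;              // non-loop operations directly in the body
  uint64_t tripCount;              // constant trip count
  std::vector<LoopModel> children; // immediately nested loops
};

static int64_t computeCost(const LoopModel &loop) {
  int64_t opCount = loop.ownOpCount;
  for (const LoopModel &child : loop.children)
    opCount += computeCost(child);
  return opCount * static_cast<int64_t>(loop.tripCount);
}

// Example: a 100 x 50 nest with 4 ops in the inner body and 1 op between the
// two loops costs (1 + 4 * 50) * 100 = 20100 operation instances.
```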
if (srcLoopIVs[i]->hasConstantLowerBound() && srcLoopIVs[i]->hasConstantUpperBound()) { - (*tripCountMap)[srcLoopIVs[i]] = + (*tripCountMap)[srcLoopIVs[i]->getInstruction()] = srcLoopIVs[i]->getConstantUpperBound() - srcLoopIVs[i]->getConstantLowerBound(); continue; @@ -723,7 +730,7 @@ static bool buildSliceTripCountMap( Optional tripCount = getConstDifference(lbMap, ubMap); if (!tripCount.hasValue()) return false; - (*tripCountMap)[srcLoopIVs[i]] = tripCount.getValue(); + (*tripCountMap)[srcLoopIVs[i]->getInstruction()] = tripCount.getValue(); } return true; } @@ -750,7 +757,7 @@ static unsigned getInnermostCommonLoopDepth(ArrayRef ops) { unsigned numOps = ops.size(); assert(numOps > 0); - std::vector> loops(numOps); + std::vector, 4>> loops(numOps); unsigned loopDepthLimit = std::numeric_limits::max(); for (unsigned i = 0; i < numOps; ++i) { getLoopIVs(*ops[i], &loops[i]); @@ -762,9 +769,8 @@ static unsigned getInnermostCommonLoopDepth(ArrayRef ops) { for (unsigned d = 0; d < loopDepthLimit; ++d) { unsigned i; for (i = 1; i < numOps; ++i) { - if (loops[i - 1][d] != loops[i][d]) { + if (loops[i - 1][d] != loops[i][d]) break; - } } if (i != numOps) break; @@ -871,14 +877,16 @@ static bool getSliceUnion(const ComputationSliceState &sliceStateA, } // Creates and returns a private (single-user) memref for fused loop rooted -// at 'forInst', with (potentially reduced) memref size based on the +// at 'forOp', with (potentially reduced) memref size based on the // MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'. // TODO(bondhugula): consider refactoring the common code from generateDma and // this one. -static Value *createPrivateMemRef(ForInst *forInst, +static Value *createPrivateMemRef(OpPointer forOp, OperationInst *srcStoreOpInst, unsigned dstLoopDepth) { - // Create builder to insert alloc op just before 'forInst'. + auto *forInst = forOp->getInstruction(); + + // Create builder to insert alloc op just before 'forOp'. FuncBuilder b(forInst); // Builder to create constants at the top level. FuncBuilder top(forInst->getFunction()); @@ -934,16 +942,16 @@ static Value *createPrivateMemRef(ForInst *forInst, for (auto dimSize : oldMemRefType.getShape()) { if (dimSize == -1) allocOperands.push_back( - top.create(forInst->getLoc(), oldMemRef, dynamicDimCount++)); + top.create(forOp->getLoc(), oldMemRef, dynamicDimCount++)); } - // Create new private memref for fused loop 'forInst'. + // Create new private memref for fused loop 'forOp'. // TODO(andydavis) Create/move alloc ops for private memrefs closer to their // consumer loop nests to reduce their live range. Currently they are added // at the beginning of the function, because loop nests can be reordered // during the fusion pass. Value *newMemRef = - top.create(forInst->getLoc(), newMemRefType, allocOperands); + top.create(forOp->getLoc(), newMemRefType, allocOperands); // Build an AffineMap to remap access functions based on lower bound offsets. SmallVector remapExprs; @@ -967,7 +975,7 @@ static Value *createPrivateMemRef(ForInst *forInst, bool ret = replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap, /*extraOperands=*/outerIVs, - /*domInstFilter=*/&*forInst->getBody()->begin()); + /*domInstFilter=*/&*forOp->getBody()->begin()); assert(ret && "replaceAllMemrefUsesWith should always succeed here"); (void)ret; return newMemRef; @@ -975,7 +983,7 @@ static Value *createPrivateMemRef(ForInst *forInst, // Does the slice have a single iteration? 
static uint64_t getSliceIterationCount( - const llvm::SmallDenseMap &sliceTripCountMap) { + const llvm::SmallDenseMap &sliceTripCountMap) { uint64_t iterCount = 1; for (const auto &count : sliceTripCountMap) { iterCount *= count.second; @@ -1030,25 +1038,25 @@ static bool isFusionProfitable(OperationInst *srcOpInst, }); // Compute cost of sliced and unsliced src loop nest. - SmallVector srcLoopIVs; + SmallVector, 4> srcLoopIVs; getLoopIVs(*srcOpInst, &srcLoopIVs); unsigned numSrcLoopIVs = srcLoopIVs.size(); // Walk src loop nest and collect stats. LoopNestStats srcLoopNestStats; LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats); - srcStatsCollector.walk(srcLoopIVs[0]); + srcStatsCollector.walk(srcLoopIVs[0]->getInstruction()); // Currently only constant trip count loop nests are supported. if (srcStatsCollector.hasLoopWithNonConstTripCount) return false; // Compute cost of dst loop nest. - SmallVector dstLoopIVs; + SmallVector, 4> dstLoopIVs; getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs); LoopNestStats dstLoopNestStats; LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats); - dstStatsCollector.walk(dstLoopIVs[0]); + dstStatsCollector.walk(dstLoopIVs[0]->getInstruction()); // Currently only constant trip count loop nests are supported. if (dstStatsCollector.hasLoopWithNonConstTripCount) return false; @@ -1075,17 +1083,19 @@ static bool isFusionProfitable(OperationInst *srcOpInst, Optional bestDstLoopDepth = None; // Compute op instance count for the src loop nest without iteration slicing. - uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], &srcLoopNestStats, - /*tripCountOverrideMap=*/nullptr, - /*computeCostMap=*/nullptr); + uint64_t srcLoopNestCost = + getComputeCost(srcLoopIVs[0]->getInstruction(), &srcLoopNestStats, + /*tripCountOverrideMap=*/nullptr, + /*computeCostMap=*/nullptr); // Compute op instance count for the src loop nest. - uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], &dstLoopNestStats, - /*tripCountOverrideMap=*/nullptr, - /*computeCostMap=*/nullptr); + uint64_t dstLoopNestCost = + getComputeCost(dstLoopIVs[0]->getInstruction(), &dstLoopNestStats, + /*tripCountOverrideMap=*/nullptr, + /*computeCostMap=*/nullptr); - llvm::SmallDenseMap sliceTripCountMap; - DenseMap computeCostMap; + llvm::SmallDenseMap sliceTripCountMap; + DenseMap computeCostMap; for (unsigned i = maxDstLoopDepth; i >= 1; --i) { MemRefAccess srcAccess(srcOpInst); // Handle the common case of one dst load without a copy. @@ -1121,24 +1131,25 @@ static bool isFusionProfitable(OperationInst *srcOpInst, // The store and loads to this memref will disappear. if (storeLoadFwdGuaranteed) { // A single store disappears: -1 for that. - computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]] = -1; + computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]->getInstruction()] = -1; for (auto *loadOp : dstLoadOpInsts) { - if (auto *loadLoop = dyn_cast_or_null(loadOp->getParentInst())) - computeCostMap[loadLoop] = -1; + auto *parentInst = loadOp->getParentInst(); + if (parentInst && cast(parentInst)->isa()) + computeCostMap[parentInst] = -1; } } // Compute op instance count for the src loop nest with iteration slicing. int64_t sliceComputeCost = - getComputeCost(srcLoopIVs[0], &srcLoopNestStats, + getComputeCost(srcLoopIVs[0]->getInstruction(), &srcLoopNestStats, /*tripCountOverrideMap=*/&sliceTripCountMap, /*computeCostMap=*/&computeCostMap); // Compute cost of fusion for this depth. 
-    computeCostMap[dstLoopIVs[i - 1]] = sliceComputeCost;
+    computeCostMap[dstLoopIVs[i - 1]->getInstruction()] = sliceComputeCost;
     int64_t fusedLoopNestComputeCost =
-        getComputeCost(dstLoopIVs[0], &dstLoopNestStats,
+        getComputeCost(dstLoopIVs[0]->getInstruction(), &dstLoopNestStats,
                        /*tripCountOverrideMap=*/nullptr, &computeCostMap);
     double additionalComputeFraction =
@@ -1211,8 +1222,8 @@ static bool isFusionProfitable(OperationInst *srcOpInst,
                << "\n fused loop nest compute cost: "
                << minFusedLoopNestComputeCost << "\n");
-  auto dstMemSize = getMemoryFootprintBytes(*dstLoopIVs[0]);
-  auto srcMemSize = getMemoryFootprintBytes(*srcLoopIVs[0]);
+  auto dstMemSize = getMemoryFootprintBytes(dstLoopIVs[0]);
+  auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);
   Optional storageReduction = None;
@@ -1292,9 +1303,9 @@ static bool isFusionProfitable(OperationInst *srcOpInst,
 //
 // *) A worklist is initialized with node ids from the dependence graph.
 // *) For each node id in the worklist:
-//   *) Pop a ForInst of the worklist. This 'dstForInst' will be a candidate
-//      destination ForInst into which fusion will be attempted.
-//   *) Add each LoadOp currently in 'dstForInst' into list 'dstLoadOps'.
+//   *) Pop an AffineForOp off the worklist. This 'dstAffineForOp' will be a
+//      candidate destination AffineForOp into which fusion will be attempted.
+//   *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
 //   *) For each LoadOp in 'dstLoadOps' do:
 //      *) Lookup dependent loop nests at earlier positions in the Function
 //         which have a single store op to the same memref.
@@ -1342,7 +1353,7 @@ public:
       // Get 'dstNode' into which to attempt fusion.
      auto *dstNode = mdg->getNode(dstId);
      // Skip if 'dstNode' is not a loop nest.
-      if (!isa(dstNode->inst))
+      if (!cast(dstNode->inst)->isa())
        continue;
      SmallVector loads = dstNode->loads;
@@ -1375,7 +1386,7 @@ public:
        // Get 'srcNode' from which to attempt fusion into 'dstNode'.
        auto *srcNode = mdg->getNode(srcId);
        // Skip if 'srcNode' is not a loop nest.
-        if (!isa(srcNode->inst))
+        if (!cast(srcNode->inst)->isa())
          continue;
        // Skip if 'srcNode' has more than one store to any memref.
        // TODO(andydavis) Support fusing multi-output src loop nests.
@@ -1417,25 +1428,26 @@ public:
          continue;
        // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
-        auto *sliceLoopNest = mlir::insertBackwardComputationSlice(
+        auto sliceLoopNest = mlir::insertBackwardComputationSlice(
            srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
        if (sliceLoopNest != nullptr) {
-          // Move 'dstForInst' before 'insertPointInst' if needed.
-          auto *dstForInst = cast(dstNode->inst);
-          if (insertPointInst != dstForInst) {
-            dstForInst->moveBefore(insertPointInst);
+          // Move 'dstAffineForOp' before 'insertPointInst' if needed.
+          auto dstAffineForOp =
+              cast(dstNode->inst)->cast();
+          if (insertPointInst != dstAffineForOp->getInstruction()) {
+            dstAffineForOp->getInstruction()->moveBefore(insertPointInst);
          }
          // Update edges between 'srcNode' and 'dstNode'.
          mdg->updateEdges(srcNode->id, dstNode->id, memref);
          // Collect slice loop stats.
          LoopNestStateCollector sliceCollector;
-          sliceCollector.walkForInst(sliceLoopNest);
+          sliceCollector.walk(sliceLoopNest->getInstruction());
          // Promote single iteration slice loops to single IV value.
-          for (auto *forInst : sliceCollector.forInsts) {
-            promoteIfSingleIteration(forInst);
+          for (auto forOp : sliceCollector.forOps) {
+            promoteIfSingleIteration(forOp);
          }
-          // Create private memref for 'memref' in 'dstForInst'.
+ // Create private memref for 'memref' in 'dstAffineForOp'. SmallVector storesForMemref; for (auto *storeOpInst : sliceCollector.storeOpInsts) { if (storeOpInst->cast()->getMemRef() == memref) @@ -1443,7 +1455,7 @@ public: } assert(storesForMemref.size() == 1); auto *newMemRef = createPrivateMemRef( - dstForInst, storesForMemref[0], bestDstLoopDepth); + dstAffineForOp, storesForMemref[0], bestDstLoopDepth); visitedMemrefs.insert(newMemRef); // Create new node in dependence graph for 'newMemRef' alloc op. unsigned newMemRefNodeId = @@ -1453,7 +1465,7 @@ public: // Collect dst loop stats after memref privatizaton transformation. LoopNestStateCollector dstLoopCollector; - dstLoopCollector.walkForInst(dstForInst); + dstLoopCollector.walk(dstAffineForOp->getInstruction()); // Add new load ops to current Node load op list 'loads' to // continue fusing based on new operands. @@ -1472,7 +1484,7 @@ public: // function. if (mdg->canRemoveNode(srcNode->id)) { mdg->removeNode(srcNode->id); - cast(srcNode->inst)->erase(); + srcNode->inst->erase(); } } } diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 396fc8eb658c..f1ee7fd18533 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -19,6 +19,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/LoopAnalysis.h" @@ -60,16 +61,17 @@ char LoopTiling::passID = 0; /// Function. FunctionPass *mlir::createLoopTilingPass() { return new LoopTiling(); } -// Move the loop body of ForInst 'src' from 'src' into the specified location in -// destination's body. -static inline void moveLoopBody(ForInst *src, ForInst *dest, +// Move the loop body of AffineForOp 'src' from 'src' into the specified +// location in destination's body. +static inline void moveLoopBody(AffineForOp *src, AffineForOp *dest, Block::iterator loc) { dest->getBody()->getInstructions().splice(loc, src->getBody()->getInstructions()); } -// Move the loop body of ForInst 'src' from 'src' to the start of dest's body. -static inline void moveLoopBody(ForInst *src, ForInst *dest) { +// Move the loop body of AffineForOp 'src' from 'src' to the start of dest's +// body. +static inline void moveLoopBody(AffineForOp *src, AffineForOp *dest) { moveLoopBody(src, dest, dest->getBody()->begin()); } @@ -78,13 +80,14 @@ static inline void moveLoopBody(ForInst *src, ForInst *dest) { /// depend on other dimensions. Bounds of each dimension can thus be treated /// independently, and deriving the new bounds is much simpler and faster /// than for the case of tiling arbitrary polyhedral shapes. -static void constructTiledIndexSetHyperRect(ArrayRef origLoops, - ArrayRef newLoops, - ArrayRef tileSizes) { +static void constructTiledIndexSetHyperRect( + MutableArrayRef> origLoops, + MutableArrayRef> newLoops, + ArrayRef tileSizes) { assert(!origLoops.empty()); assert(origLoops.size() == tileSizes.size()); - FuncBuilder b(origLoops[0]); + FuncBuilder b(origLoops[0]->getInstruction()); unsigned width = origLoops.size(); // Bounds for tile space loops. @@ -99,8 +102,8 @@ static void constructTiledIndexSetHyperRect(ArrayRef origLoops, } // Bounds for intra-tile loops. 
for (unsigned i = 0; i < width; i++) { - int64_t largestDiv = getLargestDivisorOfTripCount(*origLoops[i]); - auto mayBeConstantCount = getConstantTripCount(*origLoops[i]); + int64_t largestDiv = getLargestDivisorOfTripCount(origLoops[i]); + auto mayBeConstantCount = getConstantTripCount(origLoops[i]); // The lower bound is just the tile-space loop. AffineMap lbMap = b.getDimIdentityMap(); newLoops[width + i]->setLowerBound( @@ -144,38 +147,40 @@ static void constructTiledIndexSetHyperRect(ArrayRef origLoops, /// Tiles the specified band of perfectly nested loops creating tile-space loops /// and intra-tile loops. A band is a contiguous set of loops. // TODO(bondhugula): handle non hyper-rectangular spaces. -UtilResult mlir::tileCodeGen(ArrayRef band, +UtilResult mlir::tileCodeGen(MutableArrayRef> band, ArrayRef tileSizes) { assert(!band.empty()); assert(band.size() == tileSizes.size()); // Check if the supplied for inst's are all successively nested. for (unsigned i = 1, e = band.size(); i < e; i++) { - assert(band[i]->getParentInst() == band[i - 1]); + assert(band[i]->getInstruction()->getParentInst() == + band[i - 1]->getInstruction()); } auto origLoops = band; - ForInst *rootForInst = origLoops[0]; - auto loc = rootForInst->getLoc(); + OpPointer rootAffineForOp = origLoops[0]; + auto loc = rootAffineForOp->getLoc(); // Note that width is at least one since band isn't empty. unsigned width = band.size(); - SmallVector newLoops(2 * width); - ForInst *innermostPointLoop; + SmallVector, 12> newLoops(2 * width); + OpPointer innermostPointLoop; // The outermost among the loops as we add more.. - auto *topLoop = rootForInst; + auto *topLoop = rootAffineForOp->getInstruction(); // Add intra-tile (or point) loops. for (unsigned i = 0; i < width; i++) { FuncBuilder b(topLoop); // Loop bounds will be set later. - auto *pointLoop = b.createFor(loc, 0, 0); + auto pointLoop = b.create(loc, 0, 0); + pointLoop->createBody(); pointLoop->getBody()->getInstructions().splice( pointLoop->getBody()->begin(), topLoop->getBlock()->getInstructions(), topLoop); newLoops[2 * width - 1 - i] = pointLoop; - topLoop = pointLoop; + topLoop = pointLoop->getInstruction(); if (i == 0) innermostPointLoop = pointLoop; } @@ -184,12 +189,13 @@ UtilResult mlir::tileCodeGen(ArrayRef band, for (unsigned i = width; i < 2 * width; i++) { FuncBuilder b(topLoop); // Loop bounds will be set later. - auto *tileSpaceLoop = b.createFor(loc, 0, 0); + auto tileSpaceLoop = b.create(loc, 0, 0); + tileSpaceLoop->createBody(); tileSpaceLoop->getBody()->getInstructions().splice( tileSpaceLoop->getBody()->begin(), topLoop->getBlock()->getInstructions(), topLoop); newLoops[2 * width - i - 1] = tileSpaceLoop; - topLoop = tileSpaceLoop; + topLoop = tileSpaceLoop->getInstruction(); } // Move the loop body of the original nest to the new one. @@ -201,8 +207,8 @@ UtilResult mlir::tileCodeGen(ArrayRef band, getIndexSet(band, &cst); if (!cst.isHyperRectangular(0, width)) { - rootForInst->emitError("tiled code generation unimplemented for the" - "non-hyperrectangular case"); + rootAffineForOp->emitError("tiled code generation unimplemented for the" + "non-hyperrectangular case"); return UtilResult::Failure; } @@ -213,7 +219,7 @@ UtilResult mlir::tileCodeGen(ArrayRef band, } // Erase the old loop nest. - rootForInst->erase(); + rootAffineForOp->erase(); return UtilResult::Success; } @@ -221,38 +227,36 @@ UtilResult mlir::tileCodeGen(ArrayRef band, // Identify valid and profitable bands of loops to tile. 
This is currently just // a temporary placeholder to test the mechanics of tiled code generation. // Returns all maximal outermost perfect loop nests to tile. -static void getTileableBands(Function *f, - std::vector> *bands) { +static void +getTileableBands(Function *f, + std::vector, 6>> *bands) { // Get maximal perfect nest of 'for' insts starting from root (inclusive). - auto getMaximalPerfectLoopNest = [&](ForInst *root) { - SmallVector band; - ForInst *currInst = root; + auto getMaximalPerfectLoopNest = [&](OpPointer root) { + SmallVector, 6> band; + OpPointer currInst = root; do { band.push_back(currInst); } while (currInst->getBody()->getInstructions().size() == 1 && - (currInst = dyn_cast(&currInst->getBody()->front()))); + (currInst = cast(currInst->getBody()->front()) + .dyn_cast())); bands->push_back(band); }; - for (auto &block : *f) { - for (auto &inst : block) { - auto *forInst = dyn_cast(&inst); - if (!forInst) - continue; - getMaximalPerfectLoopNest(forInst); - } - } + for (auto &block : *f) + for (auto &inst : block) + if (auto forOp = cast(inst).dyn_cast()) + getMaximalPerfectLoopNest(forOp); } PassResult LoopTiling::runOnFunction(Function *f) { - std::vector> bands; + std::vector, 6>> bands; getTileableBands(f, &bands); // Temporary tile sizes. unsigned tileSize = clTileSize.getNumOccurrences() > 0 ? clTileSize : kDefaultTileSize; - for (const auto &band : bands) { + for (auto &band : bands) { SmallVector tileSizes(band.size(), tileSize); if (tileCodeGen(band, tileSizes)) { return failure(); diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 6d63e4afd2d4..86e913bd71f0 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -21,6 +21,7 @@ #include "mlir/Transforms/Passes.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -70,18 +71,19 @@ struct LoopUnroll : public FunctionPass { const Optional unrollFull; // Callback to obtain unroll factors; if this has a callable target, takes // precedence over command-line argument or passed argument. - const std::function getUnrollFactor; + const std::function)> getUnrollFactor; - explicit LoopUnroll( - Optional unrollFactor = None, Optional unrollFull = None, - const std::function &getUnrollFactor = nullptr) + explicit LoopUnroll(Optional unrollFactor = None, + Optional unrollFull = None, + const std::function)> + &getUnrollFactor = nullptr) : FunctionPass(&LoopUnroll::passID), unrollFactor(unrollFactor), unrollFull(unrollFull), getUnrollFactor(getUnrollFactor) {} PassResult runOnFunction(Function *f) override; /// Unroll this for inst. Returns false if nothing was done. - bool runOnForInst(ForInst *forInst); + bool runOnAffineForOp(OpPointer forOp); static const unsigned kDefaultUnrollFactor = 4; @@ -96,7 +98,7 @@ PassResult LoopUnroll::runOnFunction(Function *f) { class InnermostLoopGatherer : public InstWalker { public: // Store innermost loops as we walk. - std::vector loops; + std::vector> loops; // This method specialized to encode custom return logic. 
using InstListType = llvm::iplist; @@ -111,20 +113,17 @@ PassResult LoopUnroll::runOnFunction(Function *f) { return hasInnerLoops; } - bool walkForInstPostOrder(ForInst *forInst) { - bool hasInnerLoops = - walkPostOrder(forInst->getBody()->begin(), forInst->getBody()->end()); - if (!hasInnerLoops) - loops.push_back(forInst); - return true; - } - bool walkOpInstPostOrder(OperationInst *opInst) { + bool hasInnerLoops = false; for (auto &blockList : opInst->getBlockLists()) for (auto &block : blockList) - if (walkPostOrder(block.begin(), block.end())) - return true; - return false; + hasInnerLoops |= walkPostOrder(block.begin(), block.end()); + if (opInst->isa()) { + if (!hasInnerLoops) + loops.push_back(opInst->cast()); + return true; + } + return hasInnerLoops; } // FIXME: can't use base class method for this because that in turn would @@ -137,14 +136,17 @@ PassResult LoopUnroll::runOnFunction(Function *f) { class ShortLoopGatherer : public InstWalker { public: // Store short loops as we walk. - std::vector loops; + std::vector> loops; const unsigned minTripCount; ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {} - void visitForInst(ForInst *forInst) { - Optional tripCount = getConstantTripCount(*forInst); + void visitOperationInst(OperationInst *opInst) { + auto forOp = opInst->dyn_cast(); + if (!forOp) + return; + Optional tripCount = getConstantTripCount(forOp); if (tripCount.hasValue() && tripCount.getValue() <= minTripCount) - loops.push_back(forInst); + loops.push_back(forOp); } }; @@ -156,8 +158,8 @@ PassResult LoopUnroll::runOnFunction(Function *f) { // ones). slg.walkPostOrder(f); auto &loops = slg.loops; - for (auto *forInst : loops) - loopUnrollFull(forInst); + for (auto forOp : loops) + loopUnrollFull(forOp); return success(); } @@ -172,8 +174,8 @@ PassResult LoopUnroll::runOnFunction(Function *f) { if (loops.empty()) break; bool unrolled = false; - for (auto *forInst : loops) - unrolled |= runOnForInst(forInst); + for (auto forOp : loops) + unrolled |= runOnAffineForOp(forOp); if (!unrolled) // Break out if nothing was unrolled. break; @@ -183,29 +185,30 @@ PassResult LoopUnroll::runOnFunction(Function *f) { /// Unrolls a 'for' inst. Returns true if the loop was unrolled, false /// otherwise. The default unroll factor is 4. -bool LoopUnroll::runOnForInst(ForInst *forInst) { +bool LoopUnroll::runOnAffineForOp(OpPointer forOp) { // Use the function callback if one was provided. if (getUnrollFactor) { - return loopUnrollByFactor(forInst, getUnrollFactor(*forInst)); + return loopUnrollByFactor(forOp, getUnrollFactor(forOp)); } // Unroll by the factor passed, if any. if (unrollFactor.hasValue()) - return loopUnrollByFactor(forInst, unrollFactor.getValue()); + return loopUnrollByFactor(forOp, unrollFactor.getValue()); // Unroll by the command line factor if one was specified. if (clUnrollFactor.getNumOccurrences() > 0) - return loopUnrollByFactor(forInst, clUnrollFactor); + return loopUnrollByFactor(forOp, clUnrollFactor); // Unroll completely if full loop unroll was specified. if (clUnrollFull.getNumOccurrences() > 0 || (unrollFull.hasValue() && unrollFull.getValue())) - return loopUnrollFull(forInst); + return loopUnrollFull(forOp); // Unroll by four otherwise. 
- return loopUnrollByFactor(forInst, kDefaultUnrollFactor); + return loopUnrollByFactor(forOp, kDefaultUnrollFactor); } FunctionPass *mlir::createLoopUnrollPass( int unrollFactor, int unrollFull, - const std::function &getUnrollFactor) { + const std::function)> + &getUnrollFactor) { return new LoopUnroll( unrollFactor == -1 ? None : Optional(unrollFactor), unrollFull == -1 ? None : Optional(unrollFull), getUnrollFactor); diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 7deaf850362e..7327a37ee3ad 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -43,6 +43,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Transforms/Passes.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -80,7 +81,7 @@ struct LoopUnrollAndJam : public FunctionPass { unrollJamFactor(unrollJamFactor) {} PassResult runOnFunction(Function *f) override; - bool runOnForInst(ForInst *forInst); + bool runOnAffineForOp(OpPointer forOp); static char passID; }; @@ -95,47 +96,51 @@ FunctionPass *mlir::createLoopUnrollAndJamPass(int unrollJamFactor) { PassResult LoopUnrollAndJam::runOnFunction(Function *f) { // Currently, just the outermost loop from the first loop nest is - // unroll-and-jammed by this pass. However, runOnForInst can be called on any - // for Inst. + // unroll-and-jammed by this pass. However, runOnAffineForOp can be called on + // any for Inst. auto &entryBlock = f->front(); if (!entryBlock.empty()) - if (auto *forInst = dyn_cast(&entryBlock.front())) - runOnForInst(forInst); + if (auto forOp = + cast(entryBlock.front()).dyn_cast()) + runOnAffineForOp(forOp); return success(); } /// Unroll and jam a 'for' inst. Default unroll jam factor is /// kDefaultUnrollJamFactor. Return false if nothing was done. -bool LoopUnrollAndJam::runOnForInst(ForInst *forInst) { +bool LoopUnrollAndJam::runOnAffineForOp(OpPointer forOp) { // Unroll and jam by the factor that was passed if any. if (unrollJamFactor.hasValue()) - return loopUnrollJamByFactor(forInst, unrollJamFactor.getValue()); + return loopUnrollJamByFactor(forOp, unrollJamFactor.getValue()); // Otherwise, unroll jam by the command-line factor if one was specified. if (clUnrollJamFactor.getNumOccurrences() > 0) - return loopUnrollJamByFactor(forInst, clUnrollJamFactor); + return loopUnrollJamByFactor(forOp, clUnrollJamFactor); // Unroll and jam by four otherwise. - return loopUnrollJamByFactor(forInst, kDefaultUnrollJamFactor); + return loopUnrollJamByFactor(forOp, kDefaultUnrollJamFactor); } -bool mlir::loopUnrollJamUpToFactor(ForInst *forInst, uint64_t unrollJamFactor) { - Optional mayBeConstantTripCount = getConstantTripCount(*forInst); +bool mlir::loopUnrollJamUpToFactor(OpPointer forOp, + uint64_t unrollJamFactor) { + Optional mayBeConstantTripCount = getConstantTripCount(forOp); if (mayBeConstantTripCount.hasValue() && mayBeConstantTripCount.getValue() < unrollJamFactor) - return loopUnrollJamByFactor(forInst, mayBeConstantTripCount.getValue()); - return loopUnrollJamByFactor(forInst, unrollJamFactor); + return loopUnrollJamByFactor(forOp, mayBeConstantTripCount.getValue()); + return loopUnrollJamByFactor(forOp, unrollJamFactor); } /// Unrolls and jams this loop by the specified factor. 
-bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) { +bool mlir::loopUnrollJamByFactor(OpPointer forOp, + uint64_t unrollJamFactor) { // Gathers all maximal sub-blocks of instructions that do not themselves // include a for inst (a instruction could have a descendant for inst though // in its tree). class JamBlockGatherer : public InstWalker { public: using InstListType = llvm::iplist; + using InstWalker::walk; // Store iterators to the first and last inst of each sub-block found. std::vector> subBlocks; @@ -144,30 +149,30 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) { void walk(InstListType::iterator Start, InstListType::iterator End) { for (auto it = Start; it != End;) { auto subBlockStart = it; - while (it != End && !isa(it)) + while (it != End && !cast(it)->isa()) ++it; if (it != subBlockStart) subBlocks.push_back({subBlockStart, std::prev(it)}); // Process all for insts that appear next. - while (it != End && isa(it)) - walkForInst(cast(it++)); + while (it != End && cast(it)->isa()) + walk(&*it++); } } }; assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1"); - if (unrollJamFactor == 1 || forInst->getBody()->empty()) + if (unrollJamFactor == 1 || forOp->getBody()->empty()) return false; - Optional mayBeConstantTripCount = getConstantTripCount(*forInst); + Optional mayBeConstantTripCount = getConstantTripCount(forOp); if (!mayBeConstantTripCount.hasValue() && - getLargestDivisorOfTripCount(*forInst) % unrollJamFactor != 0) + getLargestDivisorOfTripCount(forOp) % unrollJamFactor != 0) return false; - auto lbMap = forInst->getLowerBoundMap(); - auto ubMap = forInst->getUpperBoundMap(); + auto lbMap = forOp->getLowerBoundMap(); + auto ubMap = forOp->getUpperBoundMap(); // Loops with max/min expressions won't be unrolled here (the output can't be // expressed as a Function in the general case). However, the right way to @@ -178,7 +183,7 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) { // Same operand list for lower and upper bound for now. // TODO(bondhugula): handle bounds with different sets of operands. - if (!forInst->matchingBoundOperandList()) + if (!forOp->matchingBoundOperandList()) return false; // If the trip count is lower than the unroll jam factor, no unroll jam. @@ -187,35 +192,38 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) { mayBeConstantTripCount.getValue() < unrollJamFactor) return false; + auto *forInst = forOp->getInstruction(); + // Gather all sub-blocks to jam upon the loop being unrolled. JamBlockGatherer jbg; - jbg.walkForInst(forInst); + jbg.walkOpInst(forInst); auto &subBlocks = jbg.subBlocks; // Generate the cleanup loop if trip count isn't a multiple of // unrollJamFactor. if (mayBeConstantTripCount.hasValue() && mayBeConstantTripCount.getValue() % unrollJamFactor != 0) { - // Insert the cleanup loop right after 'forInst'. + // Insert the cleanup loop right after 'forOp'. FuncBuilder builder(forInst->getBlock(), std::next(Block::iterator(forInst))); - auto *cleanupForInst = cast(builder.clone(*forInst)); - cleanupForInst->setLowerBoundMap( - getCleanupLoopLowerBound(*forInst, unrollJamFactor, &builder)); + auto cleanupAffineForOp = + cast(builder.clone(*forInst))->cast(); + cleanupAffineForOp->setLowerBoundMap( + getCleanupLoopLowerBound(forOp, unrollJamFactor, &builder)); // The upper bound needs to be adjusted. 
- forInst->setUpperBoundMap( - getUnrolledLoopUpperBound(*forInst, unrollJamFactor, &builder)); + forOp->setUpperBoundMap( + getUnrolledLoopUpperBound(forOp, unrollJamFactor, &builder)); // Promote the loop body up if this has turned into a single iteration loop. - promoteIfSingleIteration(cleanupForInst); + promoteIfSingleIteration(cleanupAffineForOp); } // Scale the step of loop being unroll-jammed by the unroll-jam factor. - int64_t step = forInst->getStep(); - forInst->setStep(step * unrollJamFactor); + int64_t step = forOp->getStep(); + forOp->setStep(step * unrollJamFactor); - auto *forInstIV = forInst->getInductionVar(); + auto *forOpIV = forOp->getInductionVar(); for (auto &subBlock : subBlocks) { // Builder to insert unroll-jammed bodies. Insert right at the end of // sub-block. @@ -227,13 +235,13 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) { // If the induction variable is used, create a remapping to the value for // this unrolled instance. - if (!forInstIV->use_empty()) { + if (!forOpIV->use_empty()) { // iv' = iv + i, i = 1 to unrollJamFactor-1. auto d0 = builder.getAffineDimExpr(0); auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {}); - auto ivUnroll = builder.create(forInst->getLoc(), - bumpMap, forInstIV); - operandMapping.map(forInstIV, ivUnroll); + auto ivUnroll = + builder.create(forInst->getLoc(), bumpMap, forOpIV); + operandMapping.map(forOpIV, ivUnroll); } // Clone the sub-block being unroll-jammed. for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) { @@ -243,7 +251,7 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) { } // Promote the loop body up if this has turned into a single iteration loop. - promoteIfSingleIteration(forInst); + promoteIfSingleIteration(forOp); return true; } diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index f770684f5198..24ca4e950822 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -24,6 +24,7 @@ #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" @@ -246,7 +247,7 @@ public: LowerAffinePass() : FunctionPass(&passID) {} PassResult runOnFunction(Function *function) override; - bool lowerForInst(ForInst *forInst); + bool lowerAffineFor(OpPointer forOp); bool lowerAffineIf(AffineIfOp *ifOp); bool lowerAffineApply(AffineApplyOp *op); @@ -295,11 +296,11 @@ static Value *buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate, // a nested loop). Induction variable modification is appended to the body SESE // region that always loops back to the condition block. 
// -// +--------------------------------+ -// | | -// | | -// | br cond(%iv) | -// +--------------------------------+ +// +---------------------------------+ +// | | +// | | +// | br cond(%iv) | +// +---------------------------------+ // | // -------| | // | v v @@ -322,11 +323,12 @@ static Value *buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate, // v // +--------------------------------+ // | end: | -// | | +// | | // +--------------------------------+ // -bool LowerAffinePass::lowerForInst(ForInst *forInst) { - auto loc = forInst->getLoc(); +bool LowerAffinePass::lowerAffineFor(OpPointer forOp) { + auto loc = forOp->getLoc(); + auto *forInst = forOp->getInstruction(); // Start by splitting the block containing the 'for' into two parts. The part // before will get the init code, the part after will be the end point. @@ -339,23 +341,23 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) { conditionBlock->insertBefore(endBlock); auto *iv = conditionBlock->addArgument(IndexType::get(forInst->getContext())); - // Create the body block, moving the body of the forInst over to it. + // Create the body block, moving the body of the forOp over to it. auto *bodyBlock = new Block(); bodyBlock->insertBefore(endBlock); - auto *oldBody = forInst->getBody(); + auto *oldBody = forOp->getBody(); bodyBlock->getInstructions().splice(bodyBlock->begin(), oldBody->getInstructions(), oldBody->begin(), oldBody->end()); - // The code in the body of the forInst now uses 'iv' as its indvar. - forInst->getInductionVar()->replaceAllUsesWith(iv); + // The code in the body of the forOp now uses 'iv' as its indvar. + forOp->getInductionVar()->replaceAllUsesWith(iv); // Append the induction variable stepping logic and branch back to the exit // condition block. Construct an affine expression f : (x -> x+step) and // apply this expression to the induction variable. FuncBuilder builder(bodyBlock); - auto affStep = builder.getAffineConstantExpr(forInst->getStep()); + auto affStep = builder.getAffineConstantExpr(forOp->getStep()); auto affDim = builder.getAffineDimExpr(0); auto stepped = expandAffineExpr(&builder, loc, affDim + affStep, iv, {}); if (!stepped) @@ -368,18 +370,18 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) { builder.setInsertionPointToEnd(initBlock); // Compute loop bounds. - SmallVector operands(forInst->getLowerBoundOperands()); + SmallVector operands(forOp->getLowerBoundOperands()); auto lbValues = expandAffineMap(&builder, forInst->getLoc(), - forInst->getLowerBoundMap(), operands); + forOp->getLowerBoundMap(), operands); if (!lbValues) return true; Value *lowerBound = buildMinMaxReductionSeq(loc, CmpIPredicate::SGT, *lbValues, builder); - operands.assign(forInst->getUpperBoundOperands().begin(), - forInst->getUpperBoundOperands().end()); + operands.assign(forOp->getUpperBoundOperands().begin(), + forOp->getUpperBoundOperands().end()); auto ubValues = expandAffineMap(&builder, forInst->getLoc(), - forInst->getUpperBoundMap(), operands); + forOp->getUpperBoundMap(), operands); if (!ubValues) return true; Value *upperBound = @@ -394,7 +396,7 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) { endBlock, ArrayRef()); // Ok, we're done! - forInst->erase(); + forOp->erase(); return false; } @@ -614,28 +616,26 @@ PassResult LowerAffinePass::runOnFunction(Function *function) { // Collect all the For instructions as well as AffineIfOps and AffineApplyOps. // We do this as a prepass to avoid invalidating the walker with our rewrite. 
function->walkInsts([&](Instruction *inst) { - if (isa(inst)) - instsToRewrite.push_back(inst); - auto op = dyn_cast(inst); - if (op && (op->isa() || op->isa())) + auto op = cast(inst); + if (op->isa() || op->isa() || + op->isa()) instsToRewrite.push_back(inst); }); // Rewrite all of the ifs and fors. We walked the instructions in preorder, // so we know that we will rewrite them in the same order. - for (auto *inst : instsToRewrite) - if (auto *forInst = dyn_cast(inst)) { - if (lowerForInst(forInst)) + for (auto *inst : instsToRewrite) { + auto op = cast(inst); + if (auto ifOp = op->dyn_cast()) { + if (lowerAffineIf(ifOp)) return failure(); - } else { - auto op = cast(inst); - if (auto ifOp = op->dyn_cast()) { - if (lowerAffineIf(ifOp)) - return failure(); - } else if (lowerAffineApply(op->cast())) { + } else if (auto forOp = op->dyn_cast()) { + if (lowerAffineFor(forOp)) return failure(); - } + } else if (lowerAffineApply(op->cast())) { + return failure(); } + } return success(); } diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 432ad1f39b8d..f2dae11112b5 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -75,7 +75,7 @@ /// Implementation details /// ====================== /// The current decisions made by the super-vectorization pass guarantee that -/// use-def chains do not escape an enclosing vectorized ForInst. In other +/// use-def chains do not escape an enclosing vectorized AffineForOp. In other /// words, this pass operates on a scoped program slice. Furthermore, since we /// do not vectorize in the presence of conditionals for now, sliced chains are /// guaranteed not to escape the innermost scope, which has to be either the top @@ -285,13 +285,12 @@ static Value *substitute(Value *v, VectorType hwVectorType, /// /// The general problem this function solves is as follows: /// Assume a vector_transfer operation at the super-vector granularity that has -/// `l` enclosing loops (ForInst). Assume the vector transfer operation operates -/// on a MemRef of rank `r`, a super-vector of rank `s` and a hardware vector of -/// rank `h`. -/// For the purpose of illustration assume l==4, r==3, s==2, h==1 and that the -/// super-vector is vector<3x32xf32> and the hardware vector is vector<8xf32>. -/// Assume the following MLIR snippet after super-vectorization has been -/// applied: +/// `l` enclosing loops (AffineForOp). Assume the vector transfer operation +/// operates on a MemRef of rank `r`, a super-vector of rank `s` and a hardware +/// vector of rank `h`. For the purpose of illustration assume l==4, r==3, s==2, +/// h==1 and that the super-vector is vector<3x32xf32> and the hardware vector +/// is vector<8xf32>. Assume the following MLIR snippet after +/// super-vectorization has been applied: /// /// ```mlir /// for %i0 = 0 to %M { @@ -351,7 +350,7 @@ reindexAffineIndices(FuncBuilder *b, VectorType hwVectorType, SmallVector affineExprs; // TODO(ntv): support a concrete map and composition. unsigned i = 0; - // The first numMemRefIndices correspond to ForInst that have not been + // The first numMemRefIndices correspond to AffineForOp that have not been // vectorized, the transformation is the identity on those. 
for (i = 0; i < numMemRefIndices; ++i) { auto d_i = b->getAffineDimExpr(i); @@ -554,9 +553,6 @@ static bool instantiateMaterialization(Instruction *inst, MaterializationState *state) { LLVM_DEBUG(dbgs() << "\ninstantiate: " << *inst); - if (isa(inst)) - return inst->emitError("NYI path ForInst"); - // Create a builder here for unroll-and-jam effects. FuncBuilder b(inst); auto *opInst = cast(inst); diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 811741d08d1b..2e083bbfd79f 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -21,11 +21,11 @@ #include "mlir/Transforms/Passes.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Transforms/LoopUtils.h" @@ -38,15 +38,12 @@ using namespace mlir; namespace { -struct PipelineDataTransfer : public FunctionPass, - InstWalker { +struct PipelineDataTransfer : public FunctionPass { PipelineDataTransfer() : FunctionPass(&PipelineDataTransfer::passID) {} PassResult runOnFunction(Function *f) override; - PassResult runOnForInst(ForInst *forInst); + PassResult runOnAffineForOp(OpPointer forOp); - // Collect all 'for' instructions. - void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); } - std::vector forInsts; + std::vector> forOps; static char passID; }; @@ -79,8 +76,8 @@ static unsigned getTagMemRefPos(const OperationInst &dmaInst) { /// of the old memref by the new one while indexing the newly added dimension by /// the loop IV of the specified 'for' instruction modulo 2. Returns false if /// such a replacement cannot be performed. -static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) { - auto *forBody = forInst->getBody(); +static bool doubleBuffer(Value *oldMemRef, OpPointer forOp) { + auto *forBody = forOp->getBody(); FuncBuilder bInner(forBody, forBody->begin()); bInner.setInsertionPoint(forBody, forBody->begin()); @@ -101,6 +98,7 @@ static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) { auto newMemRefType = doubleShape(oldMemRefType); // Put together alloc operands for the dynamic dimensions of the memref. + auto *forInst = forOp->getInstruction(); FuncBuilder bOuter(forInst); SmallVector allocOperands; unsigned dynamicDimCount = 0; @@ -118,16 +116,16 @@ static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) { // Create 'iv mod 2' value to index the leading dimension. auto d0 = bInner.getAffineDimExpr(0); - int64_t step = forInst->getStep(); + int64_t step = forOp->getStep(); auto modTwoMap = bInner.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0, {d0.floorDiv(step) % 2}, {}); - auto ivModTwoOp = bInner.create(forInst->getLoc(), modTwoMap, - forInst->getInductionVar()); + auto ivModTwoOp = bInner.create(forOp->getLoc(), modTwoMap, + forOp->getInductionVar()); - // replaceAllMemRefUsesWith will always succeed unless the forInst body has + // replaceAllMemRefUsesWith will always succeed unless the forOp body has // non-deferencing uses of the memref. 
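The leading-dimension index that doubleBuffer computes is just (iv floordiv step) mod 2 (the modTwoMap above), so consecutive iterations ping-pong between the two halves of the enlarged buffer. A standalone sketch of that indexing, in plain C++ with a fixed-size array standing in for the memref (nothing from the MLIR API):

```cpp
#include <array>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t lb = 0, ub = 8, step = 2;
  // The original buffer gets a leading dimension of size 2; here the payload
  // is a toy 4-element array instead of a memref.
  std::array<std::array<int, 4>, 2> doubled{};
  for (int64_t iv = lb; iv < ub; iv += step) {
    int64_t slot = (iv / step) % 2; // the "iv mod 2" affine_apply in doubleBuffer
    // Iteration iv fills `slot` while the previous iteration's data in the
    // other slot can still be consumed, overlapping DMA with compute.
    doubled[slot][0] = static_cast<int>(iv);
    std::printf("iv=%lld writes slot %lld\n", (long long)iv, (long long)slot);
  }
}
```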
if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, {ivModTwoOp}, AffineMap(), - {}, &*forInst->getBody()->begin())) { + {}, &*forOp->getBody()->begin())) { LLVM_DEBUG(llvm::dbgs() << "memref replacement for double buffering failed\n";); ivModTwoOp->getInstruction()->erase(); @@ -143,11 +141,14 @@ PassResult PipelineDataTransfer::runOnFunction(Function *f) { // invalid (erased) when the outer loop is pipelined (the pipelined one gets // deleted and replaced by a prologue, a new steady-state loop and an // epilogue). - forInsts.clear(); - walkPostOrder(f); + forOps.clear(); + f->walkOpsPostOrder([&](OperationInst *opInst) { + if (auto forOp = opInst->dyn_cast()) + forOps.push_back(forOp); + }); bool ret = false; - for (auto *forInst : forInsts) { - ret = ret | runOnForInst(forInst); + for (auto forOp : forOps) { + ret = ret | runOnAffineForOp(forOp); } return ret ? failure() : success(); } @@ -178,13 +179,13 @@ static bool checkTagMatch(OpPointer startOp, // Identify matching DMA start/finish instructions to overlap computation with. static void findMatchingStartFinishInsts( - ForInst *forInst, + OpPointer forOp, SmallVectorImpl> &startWaitPairs) { // Collect outgoing DMA instructions - needed to check for dependences below. SmallVector, 4> outgoingDmaOps; - for (auto &inst : *forInst->getBody()) { + for (auto &inst : *forOp->getBody()) { auto *opInst = dyn_cast(&inst); if (!opInst) continue; @@ -195,7 +196,7 @@ static void findMatchingStartFinishInsts( } SmallVector dmaStartInsts, dmaFinishInsts; - for (auto &inst : *forInst->getBody()) { + for (auto &inst : *forOp->getBody()) { auto *opInst = dyn_cast(&inst); if (!opInst) continue; @@ -227,7 +228,7 @@ static void findMatchingStartFinishInsts( auto *memref = dmaStartOp->getOperand(dmaStartOp->getFasterMemPos()); bool escapingUses = false; for (const auto &use : memref->getUses()) { - if (!forInst->getBody()->findAncestorInstInBlock(*use.getOwner())) { + if (!forOp->getBody()->findAncestorInstInBlock(*use.getOwner())) { LLVM_DEBUG(llvm::dbgs() << "can't pipeline: buffer is live out of loop\n";); escapingUses = true; @@ -251,17 +252,18 @@ static void findMatchingStartFinishInsts( } /// Overlap DMA transfers with computation in this loop. If successful, -/// 'forInst' is deleted, and a prologue, a new pipelined loop, and epilogue are +/// 'forOp' is deleted, and a prologue, a new pipelined loop, and epilogue are /// inserted right before where it was. -PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) { - auto mayBeConstTripCount = getConstantTripCount(*forInst); +PassResult +PipelineDataTransfer::runOnAffineForOp(OpPointer forOp) { + auto mayBeConstTripCount = getConstantTripCount(forOp); if (!mayBeConstTripCount.hasValue()) { LLVM_DEBUG(llvm::dbgs() << "unknown trip count loop\n"); return success(); } SmallVector, 4> startWaitPairs; - findMatchingStartFinishInsts(forInst, startWaitPairs); + findMatchingStartFinishInsts(forOp, startWaitPairs); if (startWaitPairs.empty()) { LLVM_DEBUG(llvm::dbgs() << "No dma start/finish pairs\n";); @@ -280,7 +282,7 @@ PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) { auto *dmaStartInst = pair.first; Value *oldMemRef = dmaStartInst->getOperand( dmaStartInst->cast()->getFasterMemPos()); - if (!doubleBuffer(oldMemRef, forInst)) { + if (!doubleBuffer(oldMemRef, forOp)) { // Normally, double buffering should not fail because we already checked // that there are no uses outside. 
LLVM_DEBUG(llvm::dbgs() << "double buffering failed for: \n";); @@ -302,7 +304,7 @@ PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) { auto *dmaFinishInst = pair.second; Value *oldTagMemRef = dmaFinishInst->getOperand(getTagMemRefPos(*dmaFinishInst)); - if (!doubleBuffer(oldTagMemRef, forInst)) { + if (!doubleBuffer(oldTagMemRef, forOp)) { LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";); return success(); } @@ -315,7 +317,7 @@ PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) { // Double buffering would have invalidated all the old DMA start/wait insts. startWaitPairs.clear(); - findMatchingStartFinishInsts(forInst, startWaitPairs); + findMatchingStartFinishInsts(forOp, startWaitPairs); // Store shift for instruction for later lookup for AffineApplyOp's. DenseMap instShiftMap; @@ -342,16 +344,16 @@ PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) { } } // Everything else (including compute ops and dma finish) are shifted by one. - for (const auto &inst : *forInst->getBody()) { + for (const auto &inst : *forOp->getBody()) { if (instShiftMap.find(&inst) == instShiftMap.end()) { instShiftMap[&inst] = 1; } } // Get shifts stored in map. - std::vector shifts(forInst->getBody()->getInstructions().size()); + std::vector shifts(forOp->getBody()->getInstructions().size()); unsigned s = 0; - for (auto &inst : *forInst->getBody()) { + for (auto &inst : *forOp->getBody()) { assert(instShiftMap.find(&inst) != instShiftMap.end()); shifts[s++] = instShiftMap[&inst]; LLVM_DEBUG( @@ -363,13 +365,13 @@ PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) { }); } - if (!isInstwiseShiftValid(*forInst, shifts)) { + if (!isInstwiseShiftValid(forOp, shifts)) { // Violates dependences. LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";); return success(); } - if (instBodySkew(forInst, shifts)) { + if (instBodySkew(forOp, shifts)) { LLVM_DEBUG(llvm::dbgs() << "inst body skewing failed - unexpected\n";); return success(); } diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp index ba59123c7004..ae003b3e4953 100644 --- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -22,6 +22,7 @@ #include "mlir/Analysis/AffineStructures.h" #include "mlir/IR/Function.h" #include "mlir/IR/Instructions.h" +#include "mlir/IR/IntegerSet.h" #include "mlir/Pass.h" #include "mlir/Transforms/Passes.h" diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 59da2b0a56e9..ce16656243da 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -21,6 +21,7 @@ #include "mlir/Transforms/LoopUtils.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -39,22 +40,22 @@ using namespace mlir; /// Returns the upper bound of an unrolled loop with lower bound 'lb' and with /// the specified trip count, stride, and unroll factor. Returns nullptr when /// the trip count can't be expressed as an affine expression. -AffineMap mlir::getUnrolledLoopUpperBound(const ForInst &forInst, +AffineMap mlir::getUnrolledLoopUpperBound(ConstOpPointer forOp, unsigned unrollFactor, FuncBuilder *builder) { - auto lbMap = forInst.getLowerBoundMap(); + auto lbMap = forOp->getLowerBoundMap(); // Single result lower bound map only. 
if (lbMap.getNumResults() != 1) return AffineMap(); // Sometimes, the trip count cannot be expressed as an affine expression. - auto tripCount = getTripCountExpr(forInst); + auto tripCount = getTripCountExpr(forOp); if (!tripCount) return AffineMap(); AffineExpr lb(lbMap.getResult(0)); - unsigned step = forInst.getStep(); + unsigned step = forOp->getStep(); auto newUb = lb + (tripCount - tripCount % unrollFactor - 1) * step; return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(), @@ -65,50 +66,51 @@ AffineMap mlir::getUnrolledLoopUpperBound(const ForInst &forInst, /// bound 'lb' and with the specified trip count, stride, and unroll factor. /// Returns an AffinMap with nullptr storage (that evaluates to false) /// when the trip count can't be expressed as an affine expression. -AffineMap mlir::getCleanupLoopLowerBound(const ForInst &forInst, +AffineMap mlir::getCleanupLoopLowerBound(ConstOpPointer forOp, unsigned unrollFactor, FuncBuilder *builder) { - auto lbMap = forInst.getLowerBoundMap(); + auto lbMap = forOp->getLowerBoundMap(); // Single result lower bound map only. if (lbMap.getNumResults() != 1) return AffineMap(); // Sometimes the trip count cannot be expressed as an affine expression. - AffineExpr tripCount(getTripCountExpr(forInst)); + AffineExpr tripCount(getTripCountExpr(forOp)); if (!tripCount) return AffineMap(); AffineExpr lb(lbMap.getResult(0)); - unsigned step = forInst.getStep(); + unsigned step = forOp->getStep(); auto newLb = lb + (tripCount - tripCount % unrollFactor) * step; return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(), {newLb}, {}); } -/// Promotes the loop body of a forInst to its containing block if the forInst +/// Promotes the loop body of a forOp to its containing block if the forOp /// was known to have a single iteration. Returns false otherwise. // TODO(bondhugula): extend this for arbitrary affine bounds. -bool mlir::promoteIfSingleIteration(ForInst *forInst) { - Optional tripCount = getConstantTripCount(*forInst); +bool mlir::promoteIfSingleIteration(OpPointer forOp) { + Optional tripCount = getConstantTripCount(forOp); if (!tripCount.hasValue() || tripCount.getValue() != 1) return false; // TODO(mlir-team): there is no builder for a max. - if (forInst->getLowerBoundMap().getNumResults() != 1) + if (forOp->getLowerBoundMap().getNumResults() != 1) return false; // Replaces all IV uses to its single iteration value. - auto *iv = forInst->getInductionVar(); + auto *iv = forOp->getInductionVar(); + OperationInst *forInst = forOp->getInstruction(); if (!iv->use_empty()) { - if (forInst->hasConstantLowerBound()) { + if (forOp->hasConstantLowerBound()) { auto *mlFunc = forInst->getFunction(); FuncBuilder topBuilder(mlFunc); auto constOp = topBuilder.create( - forInst->getLoc(), forInst->getConstantLowerBound()); + forOp->getLoc(), forOp->getConstantLowerBound()); iv->replaceAllUsesWith(constOp); } else { - const AffineBound lb = forInst->getLowerBound(); + const AffineBound lb = forOp->getLowerBound(); SmallVector lbOperands(lb.operand_begin(), lb.operand_end()); FuncBuilder builder(forInst->getBlock(), Block::iterator(forInst)); if (lb.getMap() == builder.getDimIdentityMap()) { @@ -124,8 +126,8 @@ bool mlir::promoteIfSingleIteration(ForInst *forInst) { // Move the loop body instructions to the loop's containing block. 
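In source terms the promotion is simply "for %i = c to c+1 { body(%i) }" becoming "body(c)": the induction variable is replaced by its single value and the body is spliced into the enclosing block, which is what the splice just below does. A toy illustration of the observable effect (plain C++, hypothetical body function, not MLIR):

```cpp
#include <cassert>

static int bodyAt(int iv) { return iv * iv; } // stand-in for the loop body

int main() {
  int lb = 5, ub = 6, step = 1;            // trip count == 1
  int viaLoop = 0;
  for (int iv = lb; iv < ub; iv += step)   // the loop form
    viaLoop = bodyAt(iv);
  int promoted = bodyAt(lb);               // the promoted form: iv := lb
  assert(viaLoop == promoted);
}
```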
auto *block = forInst->getBlock(); block->getInstructions().splice(Block::iterator(forInst), - forInst->getBody()->getInstructions()); - forInst->erase(); + forOp->getBody()->getInstructions()); + forOp->erase(); return true; } @@ -133,13 +135,10 @@ bool mlir::promoteIfSingleIteration(ForInst *forInst) { /// their body into the containing Block. void mlir::promoteSingleIterationLoops(Function *f) { // Gathers all innermost loops through a post order pruned walk. - class LoopBodyPromoter : public InstWalker { - public: - void visitForInst(ForInst *forInst) { promoteIfSingleIteration(forInst); } - }; - - LoopBodyPromoter fsw; - fsw.walkPostOrder(f); + f->walkOpsPostOrder([](OperationInst *inst) { + if (auto forOp = inst->dyn_cast()) + promoteIfSingleIteration(forOp); + }); } /// Generates a 'for' inst with the specified lower and upper bounds while @@ -149,19 +148,22 @@ void mlir::promoteSingleIterationLoops(Function *f) { /// the pair specifies the shift applied to that group of instructions; note /// that the shift is multiplied by the loop step before being applied. Returns /// nullptr if the generated loop simplifies to a single iteration one. -static ForInst * +static OpPointer generateLoop(AffineMap lbMap, AffineMap ubMap, const std::vector>> &instGroupQueue, - unsigned offset, ForInst *srcForInst, FuncBuilder *b) { + unsigned offset, OpPointer srcForInst, + FuncBuilder *b) { SmallVector lbOperands(srcForInst->getLowerBoundOperands()); SmallVector ubOperands(srcForInst->getUpperBoundOperands()); assert(lbMap.getNumInputs() == lbOperands.size()); assert(ubMap.getNumInputs() == ubOperands.size()); - auto *loopChunk = b->createFor(srcForInst->getLoc(), lbOperands, lbMap, - ubOperands, ubMap, srcForInst->getStep()); + auto loopChunk = + b->create(srcForInst->getLoc(), lbOperands, lbMap, + ubOperands, ubMap, srcForInst->getStep()); + loopChunk->createBody(); auto *loopChunkIV = loopChunk->getInductionVar(); auto *srcIV = srcForInst->getInductionVar(); @@ -176,7 +178,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, // Generate the remapping if the shift is not zero: remappedIV = newIV - // shift. if (!srcIV->use_empty() && shift != 0) { - auto b = FuncBuilder::getForInstBodyBuilder(loopChunk); + FuncBuilder b(loopChunk->getBody()); auto ivRemap = b.create( srcForInst->getLoc(), b.getSingleDimShiftAffineMap( @@ -191,7 +193,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, } } if (promoteIfSingleIteration(loopChunk)) - return nullptr; + return OpPointer(); return loopChunk; } @@ -210,28 +212,29 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, // asserts preservation of SSA dominance. A check for that as well as that for // memory-based depedence preservation check rests with the users of this // method. -UtilResult mlir::instBodySkew(ForInst *forInst, ArrayRef shifts, +UtilResult mlir::instBodySkew(OpPointer forOp, + ArrayRef shifts, bool unrollPrologueEpilogue) { - if (forInst->getBody()->empty()) + if (forOp->getBody()->empty()) return UtilResult::Success; // If the trip counts aren't constant, we would need versioning and // conditional guards (or context information to prevent such versioning). The // better way to pipeline for such loops is to first tile them and extract // constant trip count "full tiles" before applying this. 
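The shift semantics that instBodySkew implements can be pictured on a scalar example: a group of body instructions with shift d executes at iteration i + d of the skewed schedule, which is where the prologue, steady-state loop, and epilogue come from. A sketch with a hypothetical two-statement body and shifts {0, 1} (for instance a DMA start overlapped with the compute consuming the previous transfer):

```cpp
#include <cstdio>

// Statement 1 is shifted by one iteration relative to statement 0.
static void stmt0(int i) { std::printf("  start(%d)\n", i); }
static void stmt1(int i) { std::printf("  compute(%d)\n", i); }

int main() {
  const int tripCount = 4;
  // Original order: start(i); compute(i); for every i.
  // Skewed order, i.e. what prologue / steady state / epilogue amount to:
  stmt0(0);                                 // prologue
  for (int i = 1; i < tripCount; ++i) {     // steady state
    stmt0(i);
    stmt1(i - 1);
  }
  stmt1(tripCount - 1);                     // epilogue
}
```

Per-iteration ordering between the two statements is preserved, which is why the validity check isInstwiseShiftValid is about loop-carried dependences only.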
- auto mayBeConstTripCount = getConstantTripCount(*forInst); + auto mayBeConstTripCount = getConstantTripCount(forOp); if (!mayBeConstTripCount.hasValue()) { LLVM_DEBUG(llvm::dbgs() << "non-constant trip count loop\n";); return UtilResult::Success; } uint64_t tripCount = mayBeConstTripCount.getValue(); - assert(isInstwiseShiftValid(*forInst, shifts) && + assert(isInstwiseShiftValid(forOp, shifts) && "shifts will lead to an invalid transformation\n"); - int64_t step = forInst->getStep(); + int64_t step = forOp->getStep(); - unsigned numChildInsts = forInst->getBody()->getInstructions().size(); + unsigned numChildInsts = forOp->getBody()->getInstructions().size(); // Do a linear time (counting) sort for the shifts. uint64_t maxShift = 0; @@ -249,7 +252,7 @@ UtilResult mlir::instBodySkew(ForInst *forInst, ArrayRef shifts, // body of the 'for' inst. std::vector> sortedInstGroups(maxShift + 1); unsigned pos = 0; - for (auto &inst : *forInst->getBody()) { + for (auto &inst : *forOp->getBody()) { auto shift = shifts[pos++]; sortedInstGroups[shift].push_back(&inst); } @@ -259,17 +262,17 @@ UtilResult mlir::instBodySkew(ForInst *forInst, ArrayRef shifts, // Nevertheless, if 'unrollPrologueEpilogue' is set, we will treat the first // loop generated as the prologue and the last as epilogue and unroll these // fully. - ForInst *prologue = nullptr; - ForInst *epilogue = nullptr; + OpPointer prologue; + OpPointer epilogue; // Do a sweep over the sorted shifts while storing open groups in a // vector, and generating loop portions as necessary during the sweep. A block // of instructions is paired with its shift. std::vector>> instGroupQueue; - auto origLbMap = forInst->getLowerBoundMap(); + auto origLbMap = forOp->getLowerBoundMap(); uint64_t lbShift = 0; - FuncBuilder b(forInst); + FuncBuilder b(forOp->getInstruction()); for (uint64_t d = 0, e = sortedInstGroups.size(); d < e; ++d) { // If nothing is shifted by d, continue. if (sortedInstGroups[d].empty()) @@ -280,19 +283,19 @@ UtilResult mlir::instBodySkew(ForInst *forInst, ArrayRef shifts, // The interval for which the loop needs to be generated here is: // [lbShift, min(lbShift + tripCount, d)) and the body of the // loop needs to have all instructions in instQueue in that order. - ForInst *res; + OpPointer res; if (lbShift + tripCount * step < d * step) { res = generateLoop( b.getShiftedAffineMap(origLbMap, lbShift), b.getShiftedAffineMap(origLbMap, lbShift + tripCount * step), - instGroupQueue, 0, forInst, &b); + instGroupQueue, 0, forOp, &b); // Entire loop for the queued inst groups generated, empty it. instGroupQueue.clear(); lbShift += tripCount * step; } else { res = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift), b.getShiftedAffineMap(origLbMap, d), instGroupQueue, - 0, forInst, &b); + 0, forOp, &b); lbShift = d * step; } if (!prologue && res) @@ -312,60 +315,63 @@ UtilResult mlir::instBodySkew(ForInst *forInst, ArrayRef shifts, uint64_t ubShift = (instGroupQueue[i].first + tripCount) * step; epilogue = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift), b.getShiftedAffineMap(origLbMap, ubShift), - instGroupQueue, i, forInst, &b); + instGroupQueue, i, forOp, &b); lbShift = ubShift; if (!prologue) prologue = epilogue; } // Erase the original for inst. 
- forInst->erase(); + forOp->erase(); if (unrollPrologueEpilogue && prologue) loopUnrollFull(prologue); - if (unrollPrologueEpilogue && !epilogue && epilogue != prologue) + if (unrollPrologueEpilogue && !epilogue && + epilogue->getInstruction() != prologue->getInstruction()) loopUnrollFull(epilogue); return UtilResult::Success; } /// Unrolls this loop completely. -bool mlir::loopUnrollFull(ForInst *forInst) { - Optional mayBeConstantTripCount = getConstantTripCount(*forInst); +bool mlir::loopUnrollFull(OpPointer forOp) { + Optional mayBeConstantTripCount = getConstantTripCount(forOp); if (mayBeConstantTripCount.hasValue()) { uint64_t tripCount = mayBeConstantTripCount.getValue(); if (tripCount == 1) { - return promoteIfSingleIteration(forInst); + return promoteIfSingleIteration(forOp); } - return loopUnrollByFactor(forInst, tripCount); + return loopUnrollByFactor(forOp, tripCount); } return false; } /// Unrolls and jams this loop by the specified factor or by the trip count (if /// constant) whichever is lower. -bool mlir::loopUnrollUpToFactor(ForInst *forInst, uint64_t unrollFactor) { - Optional mayBeConstantTripCount = getConstantTripCount(*forInst); +bool mlir::loopUnrollUpToFactor(OpPointer forOp, + uint64_t unrollFactor) { + Optional mayBeConstantTripCount = getConstantTripCount(forOp); if (mayBeConstantTripCount.hasValue() && mayBeConstantTripCount.getValue() < unrollFactor) - return loopUnrollByFactor(forInst, mayBeConstantTripCount.getValue()); - return loopUnrollByFactor(forInst, unrollFactor); + return loopUnrollByFactor(forOp, mayBeConstantTripCount.getValue()); + return loopUnrollByFactor(forOp, unrollFactor); } /// Unrolls this loop by the specified factor. Returns true if the loop /// is successfully unrolled. -bool mlir::loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor) { +bool mlir::loopUnrollByFactor(OpPointer forOp, + uint64_t unrollFactor) { assert(unrollFactor >= 1 && "unroll factor should be >= 1"); if (unrollFactor == 1) - return promoteIfSingleIteration(forInst); + return promoteIfSingleIteration(forOp); - if (forInst->getBody()->empty()) + if (forOp->getBody()->empty()) return false; - auto lbMap = forInst->getLowerBoundMap(); - auto ubMap = forInst->getUpperBoundMap(); + auto lbMap = forOp->getLowerBoundMap(); + auto ubMap = forOp->getUpperBoundMap(); // Loops with max/min expressions won't be unrolled here (the output can't be // expressed as a Function in the general case). However, the right way to @@ -376,10 +382,10 @@ bool mlir::loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor) { // Same operand list for lower and upper bound for now. // TODO(bondhugula): handle bounds with different operand lists. - if (!forInst->matchingBoundOperandList()) + if (!forOp->matchingBoundOperandList()) return false; - Optional mayBeConstantTripCount = getConstantTripCount(*forInst); + Optional mayBeConstantTripCount = getConstantTripCount(forOp); // If the trip count is lower than the unroll factor, no unrolled body. // TODO(bondhugula): option to specify cleanup loop unrolling. @@ -388,10 +394,12 @@ bool mlir::loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor) { return false; // Generate the cleanup loop if trip count isn't a multiple of unrollFactor. 
- if (getLargestDivisorOfTripCount(*forInst) % unrollFactor != 0) { + OperationInst *forInst = forOp->getInstruction(); + if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0) { FuncBuilder builder(forInst->getBlock(), ++Block::iterator(forInst)); - auto *cleanupForInst = cast(builder.clone(*forInst)); - auto clLbMap = getCleanupLoopLowerBound(*forInst, unrollFactor, &builder); + auto cleanupForInst = + cast(builder.clone(*forInst))->cast(); + auto clLbMap = getCleanupLoopLowerBound(forOp, unrollFactor, &builder); assert(clLbMap && "cleanup loop lower bound map for single result bound maps can " "always be determined"); @@ -401,50 +409,50 @@ bool mlir::loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor) { // Adjust upper bound. auto unrolledUbMap = - getUnrolledLoopUpperBound(*forInst, unrollFactor, &builder); + getUnrolledLoopUpperBound(forOp, unrollFactor, &builder); assert(unrolledUbMap && "upper bound map can alwayys be determined for an unrolled loop " "with single result bounds"); - forInst->setUpperBoundMap(unrolledUbMap); + forOp->setUpperBoundMap(unrolledUbMap); } // Scale the step of loop being unrolled by unroll factor. - int64_t step = forInst->getStep(); - forInst->setStep(step * unrollFactor); + int64_t step = forOp->getStep(); + forOp->setStep(step * unrollFactor); // Builder to insert unrolled bodies right after the last instruction in the - // body of 'forInst'. - FuncBuilder builder(forInst->getBody(), forInst->getBody()->end()); + // body of 'forOp'. + FuncBuilder builder(forOp->getBody(), forOp->getBody()->end()); // Keep a pointer to the last instruction in the original block so that we // know what to clone (since we are doing this in-place). - Block::iterator srcBlockEnd = std::prev(forInst->getBody()->end()); + Block::iterator srcBlockEnd = std::prev(forOp->getBody()->end()); - // Unroll the contents of 'forInst' (append unrollFactor-1 additional copies). - auto *forInstIV = forInst->getInductionVar(); + // Unroll the contents of 'forOp' (append unrollFactor-1 additional copies). + auto *forOpIV = forOp->getInductionVar(); for (unsigned i = 1; i < unrollFactor; i++) { BlockAndValueMapping operandMap; // If the induction variable is used, create a remapping to the value for // this unrolled instance. - if (!forInstIV->use_empty()) { + if (!forOpIV->use_empty()) { // iv' = iv + 1/2/3...unrollFactor-1; auto d0 = builder.getAffineDimExpr(0); auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {}); auto ivUnroll = - builder.create(forInst->getLoc(), bumpMap, forInstIV); - operandMap.map(forInstIV, ivUnroll); + builder.create(forOp->getLoc(), bumpMap, forOpIV); + operandMap.map(forOpIV, ivUnroll); } - // Clone the original body of 'forInst'. - for (auto it = forInst->getBody()->begin(); it != std::next(srcBlockEnd); + // Clone the original body of 'forOp'. + for (auto it = forOp->getBody()->begin(); it != std::next(srcBlockEnd); it++) { builder.clone(*it, operandMap); } } // Promote the loop body up if this has turned into a single iteration loop. 
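Before the final promotion call below, note what the bumpMap built above, (d0) -> (d0 + i * step), achieves: once the step has been scaled by the unroll factor, each cloned copy k still addresses iteration iv + k * step. A minimal illustration (plain C++, hypothetical body, unroll factor 4):

```cpp
#include <cstdio>

static void body(int iv) { std::printf("body(%d)\n", iv); }

int main() {
  const int lb = 0, ub = 16, step = 1, factor = 4;
  // The unrolled loop steps by step * factor; copy k remaps the induction
  // variable to iv + k * step, exactly what (d0) -> (d0 + i * step) encodes.
  for (int iv = lb; iv < ub; iv += step * factor)
    for (int k = 0; k < factor; ++k)
      body(iv + k * step);
}
```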
- promoteIfSingleIteration(forInst); + promoteIfSingleIteration(forOp); return true; } diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index d3689d056d66..819f1a59b6fc 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -22,6 +22,7 @@ #include "mlir/Transforms/Utils.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Dominance.h" @@ -278,8 +279,8 @@ void mlir::createAffineComputationSlice( /// Folds the specified (lower or upper) bound to a constant if possible /// considering its operands. Returns false if the folding happens for any of /// the bounds, true otherwise. -bool mlir::constantFoldBounds(ForInst *forInst) { - auto foldLowerOrUpperBound = [forInst](bool lower) { +bool mlir::constantFoldBounds(OpPointer forInst) { + auto foldLowerOrUpperBound = [&forInst](bool lower) { // Check if the bound is already a constant. if (lower && forInst->hasConstantLowerBound()) return true; diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index ac551d7c20c2..7f26161e5201 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -20,6 +20,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Analysis/NestedMatcher.h" #include "mlir/Analysis/VectorAnalysis.h" @@ -252,9 +253,9 @@ using namespace mlir; /// ========== /// The algorithm proceeds in a few steps: /// 1. defining super-vectorization patterns and matching them on the tree of -/// ForInst. A super-vectorization pattern is defined as a recursive data -/// structures that matches and captures nested, imperfectly-nested loops -/// that have a. comformable loop annotations attached (e.g. parallel, +/// AffineForOp. A super-vectorization pattern is defined as a recursive +/// data structures that matches and captures nested, imperfectly-nested +/// loops that have a. comformable loop annotations attached (e.g. parallel, /// reduction, vectoriable, ...) as well as b. all contiguous load/store /// operations along a specified minor dimension (not necessarily the /// fastest varying) ; @@ -279,11 +280,11 @@ using namespace mlir; /// it by its vector form. Otherwise, if the scalar value is a constant, /// it is vectorized into a splat. In all other cases, vectorization for /// the pattern currently fails. -/// e. if everything under the root ForInst in the current pattern vectorizes -/// properly, we commit that loop to the IR. Otherwise we discard it and -/// restore a previously cloned version of the loop. Thanks to the -/// recursive scoping nature of matchers and captured patterns, this is -/// transparently achieved by a simple RAII implementation. +/// e. if everything under the root AffineForOp in the current pattern +/// vectorizes properly, we commit that loop to the IR. Otherwise we +/// discard it and restore a previously cloned version of the loop. Thanks +/// to the recursive scoping nature of matchers and captured patterns, +/// this is transparently achieved by a simple RAII implementation. /// f. vectorization is applied on the next pattern in the list. 
Because /// pattern interference avoidance is not yet implemented and that we do /// not support further vectorizing an already vector load we need to @@ -667,12 +668,13 @@ namespace { struct VectorizationStrategy { SmallVector vectorSizes; - DenseMap loopToVectorDim; + DenseMap loopToVectorDim; }; } // end anonymous namespace -static void vectorizeLoopIfProfitable(ForInst *loop, unsigned depthInPattern, +static void vectorizeLoopIfProfitable(Instruction *loop, + unsigned depthInPattern, unsigned patternDepth, VectorizationStrategy *strategy) { assert(patternDepth > depthInPattern && @@ -704,13 +706,13 @@ static bool analyzeProfitability(ArrayRef matches, unsigned depthInPattern, unsigned patternDepth, VectorizationStrategy *strategy) { for (auto m : matches) { - auto *loop = cast(m.getMatchedInstruction()); bool fail = analyzeProfitability(m.getMatchedChildren(), depthInPattern + 1, patternDepth, strategy); if (fail) { return fail; } - vectorizeLoopIfProfitable(loop, depthInPattern, patternDepth, strategy); + vectorizeLoopIfProfitable(m.getMatchedInstruction(), depthInPattern, + patternDepth, strategy); } return false; } @@ -855,8 +857,8 @@ static bool vectorizeRootOrTerminal(Value *iv, LoadOrStoreOpPointer memoryOp, /// Coarsens the loops bounds and transforms all remaining load and store /// operations into the appropriate vector_transfer. -static bool vectorizeForInst(ForInst *loop, int64_t step, - VectorizationState *state) { +static bool vectorizeAffineForOp(AffineForOp *loop, int64_t step, + VectorizationState *state) { using namespace functional; loop->setStep(step); @@ -873,7 +875,7 @@ static bool vectorizeForInst(ForInst *loop, int64_t step, }; auto loadAndStores = matcher::Op(notVectorizedThisPattern); SmallVector loadAndStoresMatches; - loadAndStores.match(loop, &loadAndStoresMatches); + loadAndStores.match(loop->getInstruction(), &loadAndStoresMatches); for (auto ls : loadAndStoresMatches) { auto *opInst = cast(ls.getMatchedInstruction()); auto load = opInst->dyn_cast(); @@ -898,7 +900,7 @@ static bool vectorizeForInst(ForInst *loop, int64_t step, static FilterFunctionType isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension) { return [fastestVaryingMemRefDimension](const Instruction &forInst) { - const auto &loop = cast(forInst); + auto loop = cast(forInst).cast(); return isVectorizableLoopAlongFastestVaryingMemRefDim( loop, fastestVaryingMemRefDimension); }; @@ -912,7 +914,8 @@ static bool vectorizeNonRoot(ArrayRef matches, /// if all vectorizations in `childrenMatches` have already succeeded /// recursively in DFS post-order. static bool doVectorize(NestedMatch oneMatch, VectorizationState *state) { - ForInst *loop = cast(oneMatch.getMatchedInstruction()); + auto *loopInst = oneMatch.getMatchedInstruction(); + auto loop = cast(loopInst)->cast(); auto childrenMatches = oneMatch.getMatchedChildren(); // 1. DFS postorder recursion, if any of my children fails, I fail too. @@ -924,7 +927,7 @@ static bool doVectorize(NestedMatch oneMatch, VectorizationState *state) { // 2. This loop may have been omitted from vectorization for various reasons // (e.g. due to the performance model or pattern depth > vector size). - auto it = state->strategy->loopToVectorDim.find(loop); + auto it = state->strategy->loopToVectorDim.find(loopInst); if (it == state->strategy->loopToVectorDim.end()) { return false; } @@ -939,10 +942,10 @@ static bool doVectorize(NestedMatch oneMatch, VectorizationState *state) { // exploratory tradeoffs (see top of the file). 
Apply coarsening, i.e.: // | ub -> ub // | step -> step * vectorSize - LLVM_DEBUG(dbgs() << "\n[early-vect] vectorizeForInst by " << vectorSize + LLVM_DEBUG(dbgs() << "\n[early-vect] vectorizeForOp by " << vectorSize << " : "); - LLVM_DEBUG(loop->print(dbgs())); - return vectorizeForInst(loop, loop->getStep() * vectorSize, state); + LLVM_DEBUG(loopInst->print(dbgs())); + return vectorizeAffineForOp(loop, loop->getStep() * vectorSize, state); } /// Non-root pattern iterates over the matches at this level, calls doVectorize @@ -1186,7 +1189,8 @@ static bool vectorizeOperations(VectorizationState *state) { /// Each root may succeed independently but will otherwise clean after itself if /// anything below it fails. static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) { - auto *loop = cast(m.getMatchedInstruction()); + auto loop = + cast(m.getMatchedInstruction())->cast(); VectorizationState state; state.strategy = strategy; @@ -1197,17 +1201,20 @@ static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) { // vectorizable. If a pattern is not vectorizable anymore, we just skip it. // TODO(ntv): implement a non-greedy profitability analysis that keeps only // non-intersecting patterns. - if (!isVectorizableLoop(*loop)) { + if (!isVectorizableLoop(loop)) { LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ loop is not vectorizable"); return true; } - FuncBuilder builder(loop); // builder to insert in place of loop - ForInst *clonedLoop = cast(builder.clone(*loop)); + auto *loopInst = loop->getInstruction(); + FuncBuilder builder(loopInst); + auto clonedLoop = + cast(builder.clone(*loopInst))->cast(); + auto fail = doVectorize(m, &state); /// Sets up error handling for this root loop. This is how the root match /// maintains a clone for handling failure and restores the proper state via /// RAII. - ScopeGuard sg2([&fail, loop, clonedLoop]() { + ScopeGuard sg2([&fail, &loop, &clonedLoop]() { if (fail) { loop->getInductionVar()->replaceAllUsesWith( clonedLoop->getInductionVar()); @@ -1291,8 +1298,8 @@ PassResult Vectorize::runOnFunction(Function *f) { if (fail) { continue; } - auto *loop = cast(m.getMatchedInstruction()); - vectorizeLoopIfProfitable(loop, 0, patternDepth, &strategy); + vectorizeLoopIfProfitable(m.getMatchedInstruction(), 0, patternDepth, + &strategy); // TODO(ntv): if pattern does not apply, report it; alter the // cost/benefit. 
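vectorizeRootMatch above keeps a clone of the root loop and relies on a scope guard to restore it when vectorization fails. The RAII rollback pattern it uses can be sketched standalone (plain C++, hypothetical ScopeGuard type standing in for the helper used in the pass):

```cpp
#include <cstdio>
#include <functional>

// Minimal scope guard: the cleanup runs when the guard leaves scope.
struct ScopeGuard {
  std::function<void()> onExit;
  ~ScopeGuard() { if (onExit) onExit(); }
};

int main() {
  bool fail = false;
  int original = 42, clone = original;     // stand-ins for loop and clonedLoop
  {
    ScopeGuard sg{[&] {
      if (fail)
        original = clone;                  // roll back to the cloned state
      // on success the guard leaves the transformed state in place
    }};
    original = 7;                          // attempt the transformation
    fail = true;                           // pretend vectorization failed
  }                                        // guard fires here
  std::printf("original=%d\n", original);  // prints 42: state restored
}
```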
fail = vectorizeRootMatch(m, &strategy); diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index 595991c01097..e41f88c901b4 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -204,7 +204,7 @@ func @illegaltype(i0) // expected-error {{invalid integer width}} // ----- func @malformed_for_percent() { - for i = 1 to 10 { // expected-error {{expected SSA identifier for the loop variable}} + for i = 1 to 10 { // expected-error {{expected SSA operand}} // ----- @@ -222,18 +222,18 @@ func @malformed_for_to() { func @incomplete_for() { for %i = 1 to 10 step 2 -} // expected-error {{expected '{' before instruction list}} +} // expected-error {{expected '{' to begin block list}} // ----- func @nonconstant_step(%1 : i32) { - for %2 = 1 to 5 step %1 { // expected-error {{expected integer}} + for %2 = 1 to 5 step %1 { // expected-error {{expected type}} // ----- func @for_negative_stride() { for %i = 1 to 10 step -1 -} // expected-error {{step has to be a positive integer}} +} // expected-error@-1 {{expected step to be representable as a positive signed integer}} // ----- @@ -510,7 +510,7 @@ func @undefined_function() { func @bound_symbol_mismatch(%N : index) { for %i = #map1(%N) to 100 { - // expected-error@-1 {{symbol operand count and affine map symbol count must match}} + // expected-error@-1 {{symbol operand count and integer set symbol count must match}} } return } @@ -521,78 +521,7 @@ func @bound_symbol_mismatch(%N : index) { func @bound_dim_mismatch(%N : index) { for %i = #map1(%N, %N)[%N] to 100 { - // expected-error@-1 {{dim operand count and affine map dim count must match}} - } - return -} - -// ----- - -#map1 = (i)[j] -> (i+j) - -func @invalid_dim_nested(%N : index) { - for %i = 1 to 100 { - %a = "foo"(%N) : (index)->(index) - for %j = 1 to #map1(%a)[%i] { - // expected-error@-1 {{value '%a' cannot be used as a dimension id}} - } - } - return -} - -// ----- - -#map1 = (i)[j] -> (i+j) - -func @invalid_dim_affine_apply(%N : index) { - for %i = 1 to 100 { - %a = "foo"(%N) : (index)->(index) - %w = affine_apply (i)->(i+1) (%a) - for %j = 1 to #map1(%w)[%i] { - // expected-error@-1 {{value '%w' cannot be used as a dimension id}} - } - } - return -} - -// ----- - -#map1 = (i)[j] -> (i+j) - -func @invalid_symbol_iv(%N : index) { - for %i = 1 to 100 { - %a = "foo"(%N) : (index)->(index) - for %j = 1 to #map1(%N)[%i] { - // expected-error@-1 {{value '%i' cannot be used as a symbol}} - } - } - return -} - -// ----- - -#map1 = (i)[j] -> (i+j) - -func @invalid_symbol_nested(%N : index) { - for %i = 1 to 100 { - %a = "foo"(%N) : (index)->(index) - for %j = 1 to #map1(%N)[%a] { - // expected-error@-1 {{value '%a' cannot be used as a symbol}} - } - } - return -} - -// ----- - -#map1 = (i)[j] -> (i+j) - -func @invalid_symbol_affine_apply(%N : index) { - for %i = 1 to 100 { - %w = affine_apply (i)->(i+1) (%i) - for %j = 1 to #map1(%i)[%w] { - // expected-error@-1 {{value '%w' cannot be used as a symbol}} - } + // expected-error@-1 {{dim operand count and integer set dim count must match}} } return } @@ -601,7 +530,7 @@ func @invalid_symbol_affine_apply(%N : index) { func @large_bound() { for %i = 1 to 9223372036854775810 { - // expected-error@-1 {{bound or step is too large for index}} + // expected-error@-1 {{integer constant out of range for attribute}} } return } @@ -609,7 +538,7 @@ func @large_bound() { // ----- func @max_in_upper_bound(%N : index) { - for %i = 1 to max (i)->(N, 100) { //expected-error {{expected SSA operand}} + for %i = 1 to max (i)->(N, 100) { 
//expected-error {{expected type}} } return } @@ -617,7 +546,7 @@ func @max_in_upper_bound(%N : index) { // ----- func @step_typo() { - for %i = 1 to 100 step -- 1 { //expected-error {{expected integer}} + for %i = 1 to 100 step -- 1 { //expected-error {{expected constant integer}} } return } diff --git a/mlir/test/IR/locations.mlir b/mlir/test/IR/locations.mlir index 8a90d12bd03c..7196e3a5c29c 100644 --- a/mlir/test/IR/locations.mlir +++ b/mlir/test/IR/locations.mlir @@ -12,9 +12,9 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) { // CHECK: constant 4 : index loc(callsite("foo" at "mysource.cc":10:8)) %2 = constant 4 : index loc(callsite("foo" at "mysource.cc":10:8)) - // CHECK: for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) - for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) { - } + // CHECK: } loc(fused["foo", "mysource.cc":10:8]) + for %i0 = 0 to 8 { + } loc(fused["foo", "mysource.cc":10:8]) // CHECK: } loc(fused<"myPass">["foo", "foo2"]) if #set0(%2) { diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 626f24569c68..bee886c0f348 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -230,7 +230,7 @@ func @complex_loops() { func @triang_loop(%arg0: index, %arg1: memref) { %c = constant 0 : i32 // CHECK: %c0_i32 = constant 0 : i32 for %i0 = 1 to %arg0 { // CHECK: for %i0 = 1 to %arg0 { - for %i1 = %i0 to %arg0 { // CHECK: for %i1 = #map{{[0-9]+}}(%i0) to %arg0 { + for %i1 = (d0)[]->(d0)(%i0)[] to %arg0 { // CHECK: for %i1 = #map{{[0-9]+}}(%i0) to %arg0 { store %c, %arg1[%i0, %i1] : memref // CHECK: store %c0_i32, %arg1[%i0, %i1] } // CHECK: } } // CHECK: } @@ -254,7 +254,7 @@ func @loop_bounds(%N : index) { // CHECK: for %i0 = %0 to %arg0 for %i = %s to %N { // CHECK: for %i1 = #map{{[0-9]+}}(%i0) to 0 - for %j = %i to 0 step 1 { + for %j = (d0)[]->(d0)(%i)[] to 0 step 1 { // CHECK: %1 = affine_apply #map{{.*}}(%i0, %i1)[%0] %w1 = affine_apply(d0, d1)[s0] -> (d0+d1) (%i, %j) [%s] // CHECK: %2 = affine_apply #map{{.*}}(%i0, %i1)[%0] @@ -764,23 +764,3 @@ func @verbose_if(%N: index) { } return } - -// CHECK-LABEL: func @verbose_for -func @verbose_for(%arg0 : index, %arg1 : index) { - // CHECK-NEXT: %0 = "for"() {lb: 1, ub: 10} : () -> index { - %a = "for"() {lb: 1, ub: 10 } : () -> index { - - // CHECK-NEXT: %1 = "for"() {lb: 1, step: 2, ub: 100} : () -> index { - %b = "for"() {lb: 1, ub: 100, step: 2 } : () -> index { - - // CHECK-NEXT: %2 = "for"(%arg0, %arg1) : (index, index) -> index { - %c = "for"(%arg0, %arg1) : (index, index) -> index { - - // CHECK-NEXT: %3 = "for"(%arg0) {ub: 100} : (index) -> index { - %d = "for"(%arg0) {ub: 100 } : (index) -> index { - } - } - } - } - return -} diff --git a/mlir/test/IR/pretty-locations.mlir b/mlir/test/IR/pretty-locations.mlir index 69dace451654..4668e7a832b1 100644 --- a/mlir/test/IR/pretty-locations.mlir +++ b/mlir/test/IR/pretty-locations.mlir @@ -17,9 +17,9 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) { // CHECK-NEXT: at mysource3.cc:100:10 %3 = constant 4 : index loc(callsite("foo" at callsite("mysource1.cc":10:8 at callsite("mysource2.cc":13:8 at "mysource3.cc":100:10)))) - // CHECK: for %i0 = 0 to 8 ["foo", mysource.cc:10:8] - for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) { - } + // CHECK: } ["foo", mysource.cc:10:8] + for %i0 = 0 to 8 { + } loc(fused["foo", "mysource.cc":10:8]) // CHECK: } <"myPass">["foo", "foo2"] if #set0(%2) { diff --git a/mlir/test/Transforms/strip-debuginfo.mlir b/mlir/test/Transforms/strip-debuginfo.mlir index 
618cba83f131..5d1572820718 100644 --- a/mlir/test/Transforms/strip-debuginfo.mlir +++ b/mlir/test/Transforms/strip-debuginfo.mlir @@ -9,9 +9,9 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) { // CHECK: "foo"() : () -> i32 loc(unknown) %1 = "foo"() : () -> i32 loc("foo") - // CHECK: for %i0 = 0 to 8 loc(unknown) - for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) { - } + // CHECK: } loc(unknown) + for %i0 = 0 to 8 { + } loc(fused["foo", "mysource.cc":10:8]) // CHECK: } loc(unknown) %2 = constant 4 : index diff --git a/mlir/test/Transforms/unroll.mlir b/mlir/test/Transforms/unroll.mlir index 54c5233430ca..09e55403b7d5 100644 --- a/mlir/test/Transforms/unroll.mlir +++ b/mlir/test/Transforms/unroll.mlir @@ -40,6 +40,9 @@ // UNROLL-BY-4: [[MAP7:#map[0-9]+]] = (d0) -> (d0 + 5) // UNROLL-BY-4: [[MAP8:#map[0-9]+]] = (d0) -> (d0 + 10) // UNROLL-BY-4: [[MAP9:#map[0-9]+]] = (d0) -> (d0 + 15) +// UNROLL-BY-4: [[MAP10:#map[0-9]+]] = (d0) -> (0) +// UNROLL-BY-4: [[MAP11:#map[0-9]+]] = (d0) -> (d0) +// UNROLL-BY-4: [[MAP12:#map[0-9]+]] = ()[s0] -> (0) // CHECK-LABEL: func @loop_nest_simplest() { func @loop_nest_simplest() { @@ -432,7 +435,7 @@ func @loop_nest_single_iteration_after_unroll(%N: index) { // UNROLL-BY-4-LABEL: func @loop_nest_operand1() { func @loop_nest_operand1() { // UNROLL-BY-4: for %i0 = 0 to 100 step 2 { -// UNROLL-BY-4-NEXT: for %i1 = (d0) -> (0)(%i0) to #map{{[0-9]+}}(%i0) step 4 +// UNROLL-BY-4-NEXT: for %i1 = [[MAP10]](%i0) to #map{{[0-9]+}}(%i0) step 4 // UNROLL-BY-4-NEXT: %0 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32 @@ -452,7 +455,7 @@ func @loop_nest_operand1() { // UNROLL-BY-4-LABEL: func @loop_nest_operand2() { func @loop_nest_operand2() { // UNROLL-BY-4: for %i0 = 0 to 100 step 2 { -// UNROLL-BY-4-NEXT: for %i1 = (d0) -> (d0)(%i0) to #map{{[0-9]+}}(%i0) step 4 { +// UNROLL-BY-4-NEXT: for %i1 = [[MAP11]](%i0) to #map{{[0-9]+}}(%i0) step 4 { // UNROLL-BY-4-NEXT: %0 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32 @@ -474,7 +477,7 @@ func @loop_nest_operand2() { func @loop_nest_operand3() { // UNROLL-BY-4: for %i0 = 0 to 100 step 2 { for %i = 0 to 100 step 2 { - // UNROLL-BY-4: for %i1 = (d0) -> (d0)(%i0) to #map{{[0-9]+}}(%i0) step 4 { + // UNROLL-BY-4: for %i1 = [[MAP11]](%i0) to #map{{[0-9]+}}(%i0) step 4 { // UNROLL-BY-4-NEXT: %0 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32 @@ -492,7 +495,7 @@ func @loop_nest_operand3() { func @loop_nest_operand4(%N : index) { // UNROLL-BY-4: for %i0 = 0 to 100 { for %i = 0 to 100 { - // UNROLL-BY-4: for %i1 = ()[s0] -> (0)()[%arg0] to #map{{[0-9]+}}()[%arg0] step 4 { + // UNROLL-BY-4: for %i1 = [[MAP12]]()[%arg0] to #map{{[0-9]+}}()[%arg0] step 4 { // UNROLL-BY-4: %0 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32