diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h index 16b4c8e47756..d1ad0a7ddec3 100644 --- a/mlir/include/mlir/AffineOps/AffineOps.h +++ b/mlir/include/mlir/AffineOps/AffineOps.h @@ -35,7 +35,7 @@ class FuncBuilder; /// A utility function to check if a value is defined at the top level of a /// function. A value defined at the top level is always a valid symbol. -bool isTopLevelSymbol(const Value *value); +bool isTopLevelSymbol(Value *value); class AffineOpsDialect : public Dialect { public: @@ -64,15 +64,15 @@ public: ArrayRef operands); /// Returns the affine map to be applied by this operation. - AffineMap getAffineMap() const { + AffineMap getAffineMap() { return getAttrOfType("map").getValue(); } /// Returns true if the result of this operation can be used as dimension id. - bool isValidDim() const; + bool isValidDim(); /// Returns true if the result of this operation is a symbol. - bool isValidSymbol() const; + bool isValidSymbol(); static StringRef getOperationName() { return "affine.apply"; } @@ -87,7 +87,7 @@ public: private: friend class Instruction; - explicit AffineApplyOp(const Instruction *state) : Op(state) {} + explicit AffineApplyOp(Instruction *state) : Op(state) {} }; /// The "for" instruction represents an affine loop nest, defining an SSA value @@ -141,16 +141,13 @@ public: Block *createBody(); /// Get the body of the AffineForOp. - Block *getBody() const { return &getRegion().front(); } + Block *getBody() { return &getRegion().front(); } /// Get the body region of the AffineForOp. - Region &getRegion() const { return getInstruction()->getRegion(0); } + Region &getRegion() { return getInstruction()->getRegion(0); } /// Returns the induction variable for this loop. Value *getInductionVar(); - const Value *getInductionVar() const { - return const_cast(this)->getInductionVar(); - } //===--------------------------------------------------------------------===// // Bounds and step @@ -161,29 +158,27 @@ public: /// Returns operands for the lower bound map. operand_range getLowerBoundOperands(); - const_operand_range getLowerBoundOperands() const; /// Returns operands for the upper bound map. operand_range getUpperBoundOperands(); - const_operand_range getUpperBoundOperands() const; /// Returns information about the lower bound as a single object. - const AffineBound getLowerBound() const; + AffineBound getLowerBound(); /// Returns information about the upper bound as a single object. - const AffineBound getUpperBound() const; + AffineBound getUpperBound(); /// Returns loop step. - int64_t getStep() const { + int64_t getStep() { return getAttr(getStepAttrName()).cast().getInt(); } /// Returns affine map for the lower bound. - AffineMap getLowerBoundMap() const { + AffineMap getLowerBoundMap() { return getAttr(getLowerBoundAttrName()).cast().getValue(); } /// Returns affine map for the upper bound. The upper bound is exclusive. - AffineMap getUpperBoundMap() const { + AffineMap getUpperBoundMap() { return getAttr(getUpperBoundAttrName()).cast().getValue(); } @@ -209,19 +204,19 @@ public: } /// Returns true if the lower bound is constant. - bool hasConstantLowerBound() const; + bool hasConstantLowerBound(); /// Returns true if the upper bound is constant. - bool hasConstantUpperBound() const; + bool hasConstantUpperBound(); /// Returns true if both bounds are constant. - bool hasConstantBounds() const { + bool hasConstantBounds() { return hasConstantLowerBound() && hasConstantUpperBound(); } /// Returns the value of the constant lower bound. /// Fails assertion if the bound is non-constant. - int64_t getConstantLowerBound() const; + int64_t getConstantLowerBound(); /// Returns the value of the constant upper bound. The upper bound is /// exclusive. Fails assertion if the bound is non-constant. - int64_t getConstantUpperBound() const; + int64_t getConstantUpperBound(); /// Sets the lower bound to the given constant value. void setConstantLowerBound(int64_t value); /// Sets the upper bound to the given constant value. @@ -229,19 +224,19 @@ public: /// Returns true if both the lower and upper bound have the same operand lists /// (same operands in the same order). - bool matchingBoundOperandList() const; + bool matchingBoundOperandList(); private: friend class Instruction; - explicit AffineForOp(const Instruction *state) : Op(state) {} + explicit AffineForOp(Instruction *state) : Op(state) {} }; /// Returns if the provided value is the induction variable of a AffineForOp. -bool isForInductionVar(const Value *val); +bool isForInductionVar(Value *val); /// Returns the loop parent of an induction variable. If the provided value is /// not an induction variable, then return nullptr. -OpPointer getForInductionVarOwner(const Value *val); +OpPointer getForInductionVarOwner(Value *val); /// Extracts the induction variables from a list of AffineForOps and places them /// in the output argument `ivs`. @@ -262,7 +257,7 @@ public: AffineValueMap getAsAffineValueMap(); unsigned getNumOperands() const { return opEnd - opStart; } - const Value *getOperand(unsigned idx) const { + Value *getOperand(unsigned idx) const { return inst->getInstruction()->getOperand(opStart + idx); } @@ -323,20 +318,14 @@ public: static StringRef getOperationName() { return "if"; } static StringRef getConditionAttrName() { return "condition"; } - IntegerSet getIntegerSet() const; + IntegerSet getIntegerSet(); void setIntegerSet(IntegerSet newSet); /// Returns the 'then' region. Region &getThenBlocks(); - Region &getThenBlocks() const { - return const_cast(this)->getThenBlocks(); - } /// Returns the 'else' blocks. Region &getElseBlocks(); - Region &getElseBlocks() const { - return const_cast(this)->getElseBlocks(); - } bool verify(); static bool parse(OpAsmParser *parser, OperationState *result); @@ -344,14 +333,14 @@ public: private: friend class Instruction; - explicit AffineIfOp(const Instruction *state) : Op(state) {} + explicit AffineIfOp(Instruction *state) : Op(state) {} }; /// Returns true if the given Value can be used as a dimension id. -bool isValidDim(const Value *value); +bool isValidDim(Value *value); /// Returns true if the given Value can be used as a symbol. -bool isValidSymbol(const Value *value); +bool isValidSymbol(Value *value); /// Modifies both `map` and `operands` in-place so as to: /// 1. drop duplicate operands diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index 848c9215fe7b..36b82acd1b7d 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -443,16 +443,16 @@ public: /// Sets the identifier corresponding to the specified Value id to a /// constant. Asserts if the 'id' is not found. - void setIdToConstant(const Value &id, int64_t val); + void setIdToConstant(Value &id, int64_t val); /// Looks up the position of the identifier with the specified Value. Returns /// true if found (false otherwise). `pos' is set to the (column) position of /// the identifier. - bool findId(const Value &id, unsigned *pos) const; + bool findId(Value &id, unsigned *pos) const; /// Returns true if an identifier with the specified Value exists, false /// otherwise. - bool containsId(const Value &id) const; + bool containsId(Value &id) const; // Add identifiers of the specified kind - specified positions are relative to // the kind of identifier. The coefficient column corresponding to the added diff --git a/mlir/include/mlir/Analysis/Dominance.h b/mlir/include/mlir/Analysis/Dominance.h index d88c002a2748..4aa8c0463d45 100644 --- a/mlir/include/mlir/Analysis/Dominance.h +++ b/mlir/include/mlir/Analysis/Dominance.h @@ -66,18 +66,18 @@ public: using super::super; /// Return true if instruction A properly dominates instruction B. - bool properlyDominates(const Instruction *a, const Instruction *b); + bool properlyDominates(Instruction *a, Instruction *b); /// Return true if instruction A dominates instruction B. - bool dominates(const Instruction *a, const Instruction *b) { + bool dominates(Instruction *a, Instruction *b) { return a == b || properlyDominates(a, b); } /// Return true if value A properly dominates instruction B. - bool properlyDominates(const Value *a, const Instruction *b); + bool properlyDominates(Value *a, Instruction *b); /// Return true if instruction A dominates instruction B. - bool dominates(const Value *a, const Instruction *b) { + bool dominates(Value *a, Instruction *b) { return (Instruction *)a->getDefiningInst() == b || properlyDominates(a, b); } @@ -98,10 +98,10 @@ public: using super::super; /// Return true if instruction A properly postdominates instruction B. - bool properlyPostDominates(const Instruction *a, const Instruction *b); + bool properlyPostDominates(Instruction *a, Instruction *b); /// Return true if instruction A postdominates instruction B. - bool postDominates(const Instruction *a, const Instruction *b) { + bool postDominates(Instruction *a, Instruction *b) { return a == b || properlyPostDominates(a, b); } diff --git a/mlir/include/mlir/Analysis/LoopAnalysis.h b/mlir/include/mlir/Analysis/LoopAnalysis.h index a5222c58cecb..7d5ebeed054a 100644 --- a/mlir/include/mlir/Analysis/LoopAnalysis.h +++ b/mlir/include/mlir/Analysis/LoopAnalysis.h @@ -72,7 +72,7 @@ uint64_t getLargestDivisorOfTripCount(OpPointer forOp); /// /// Returns false in cases with more than one AffineApplyOp, this is /// conservative. -bool isAccessInvariant(const Value &iv, const Value &index); +bool isAccessInvariant(Value &iv, Value &index); /// Given an induction variable `iv` of type AffineForOp and `indices` of type /// IndexType, returns the set of `indices` that are independent of `iv`. @@ -83,8 +83,8 @@ bool isAccessInvariant(const Value &iv, const Value &index); /// /// Returns false in cases with more than one AffineApplyOp, this is /// conservative. -llvm::DenseSet> -getInvariantAccesses(const Value &iv, llvm::ArrayRef indices); +llvm::DenseSet> +getInvariantAccesses(Value &iv, llvm::ArrayRef indices); /// Checks whether the loop is structurally vectorizable; i.e.: /// 1. the loop has proper dependence semantics (parallel, reduction, etc); diff --git a/mlir/include/mlir/Analysis/NestedMatcher.h b/mlir/include/mlir/Analysis/NestedMatcher.h index 44fe4c0558a9..64bdfb4f9412 100644 --- a/mlir/include/mlir/Analysis/NestedMatcher.h +++ b/mlir/include/mlir/Analysis/NestedMatcher.h @@ -94,8 +94,8 @@ private: /// aggressive unrolling. As experience has shown, it is generally better to use /// a plain walk over instructions to match flat patterns but the current /// implementation is competitive nonetheless. -using FilterFunctionType = std::function; -static bool defaultFilterFunction(const Instruction &) { return true; }; +using FilterFunctionType = std::function; +static bool defaultFilterFunction(Instruction &) { return true; }; struct NestedPattern { NestedPattern(ArrayRef nested, FilterFunctionType filter = defaultFilterFunction); @@ -182,9 +182,9 @@ NestedPattern For(ArrayRef nested = {}); NestedPattern For(FilterFunctionType filter, ArrayRef nested = {}); -bool isParallelLoop(const Instruction &inst); -bool isReductionLoop(const Instruction &inst); -bool isLoadOrStore(const Instruction &inst); +bool isParallelLoop(Instruction &inst); +bool isReductionLoop(Instruction &inst); +bool isLoadOrStore(Instruction &inst); } // end namespace matcher } // end namespace mlir diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index aa5b5c54720d..0982849302c1 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -48,12 +48,12 @@ class Value; /// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from /// the outermost 'for' instruction to the innermost one. // TODO(bondhugula): handle 'if' inst's. -void getLoopIVs(const Instruction &inst, +void getLoopIVs(Instruction &inst, SmallVectorImpl> *loops); /// Returns the nesting depth of this instruction, i.e., the number of loops /// surrounding this instruction. -unsigned getNestingDepth(const Instruction &stmt); +unsigned getNestingDepth(Instruction &inst); /// Returns in 'sequentialLoops' all sequential loops in loop nest rooted /// at 'forOp'. @@ -231,8 +231,7 @@ LogicalResult boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, bool emitError = true); /// Returns the number of surrounding loops common to both A and B. -unsigned getNumCommonSurroundingLoops(const Instruction &A, - const Instruction &B); +unsigned getNumCommonSurroundingLoops(Instruction &A, Instruction &B); /// Gets the memory footprint of all data touched in the specified memory space /// in bytes; if the memory space is unspecified, considers all memory spaces. diff --git a/mlir/include/mlir/Analysis/VectorAnalysis.h b/mlir/include/mlir/Analysis/VectorAnalysis.h index 4982481bf6cd..f8ed1dd28198 100644 --- a/mlir/include/mlir/Analysis/VectorAnalysis.h +++ b/mlir/include/mlir/Analysis/VectorAnalysis.h @@ -135,7 +135,7 @@ namespace matcher { /// TODO(ntv): this could all be much simpler if we added a bit that a vector /// type to mark that a vector is a strict super-vector but it still does not /// warrant adding even 1 extra bit in the IR for now. -bool operatesOnSuperVectors(const Instruction &inst, VectorType subVectorType); +bool operatesOnSuperVectors(Instruction &inst, VectorType subVectorType); } // end namespace matcher } // end namespace mlir diff --git a/mlir/include/mlir/Dialect/Traits.h b/mlir/include/mlir/Dialect/Traits.h index c25f2151ba30..ffaf56617697 100644 --- a/mlir/include/mlir/Dialect/Traits.h +++ b/mlir/include/mlir/Dialect/Traits.h @@ -32,7 +32,7 @@ namespace OpTrait { // corresponding trait classes. This avoids them being template // instantiated/duplicated. namespace impl { -bool verifyCompatibleOperandBroadcast(const Instruction *op); +bool verifyCompatibleOperandBroadcast(Instruction *op); } // namespace impl namespace util { @@ -78,7 +78,7 @@ template class BroadcastableTwoOperandsOneResult : public TraitBase { public: - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifyCompatibleOperandBroadcast(op); } }; diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h index f373f73bf566..d964878237f0 100644 --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -173,7 +173,7 @@ public: /// the latter fails. /// TODO: This is very specific functionality that should live somewhere else, /// probably in Dominance.cpp. - Instruction *findAncestorInstInBlock(const Instruction &inst); + Instruction *findAncestorInstInBlock(Instruction &inst); /// This drops all operand uses from instructions within this block, which is /// an essential step in breaking cyclic dependences between references when diff --git a/mlir/include/mlir/IR/BlockAndValueMapping.h b/mlir/include/mlir/IR/BlockAndValueMapping.h index 2bac95bc39db..8f4c4ce651f4 100644 --- a/mlir/include/mlir/IR/BlockAndValueMapping.h +++ b/mlir/include/mlir/IR/BlockAndValueMapping.h @@ -37,7 +37,7 @@ public: /// Inserts a new mapping for 'from' to 'to'. If there is an existing mapping, /// it is overwritten. void map(Block *from, Block *to) { valueMap[from] = to; } - void map(const Value *from, Value *to) { valueMap[from] = to; } + void map(Value *from, Value *to) { valueMap[from] = to; } /// Erases a mapping for 'from'. void erase(const IRObjectWithUseList *from) { valueMap.erase(from); } @@ -52,7 +52,7 @@ public: Block *lookupOrNull(Block *from) const { return lookupOrValue(from, (Block *)nullptr); } - Value *lookupOrNull(const Value *from) const { + Value *lookupOrNull(Value *from) const { return lookupOrValue(from, (Value *)nullptr); } diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 44df05f7380f..13b58c40ab32 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -268,12 +268,12 @@ public: /// ( leaving them alone if no entry is present). Replaces references to /// cloned sub-instructions to the corresponding instruction that is copied, /// and adds those mappings to the map. - Instruction *clone(const Instruction &inst, BlockAndValueMapping &mapper) { + Instruction *clone(Instruction &inst, BlockAndValueMapping &mapper) { Instruction *cloneInst = inst.clone(mapper, getContext()); block->getInstructions().insert(insertPoint, cloneInst); return cloneInst; } - Instruction *clone(const Instruction &inst) { + Instruction *clone(Instruction &inst) { Instruction *cloneInst = inst.clone(getContext()); block->getInstructions().insert(insertPoint, cloneInst); return cloneInst; diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h index cf1a6e5cb72d..beabbb112378 100644 --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -32,7 +32,7 @@ class Type; using DialectConstantDecodeHook = std::function; using DialectConstantFoldHook = std::function, SmallVectorImpl &)>; + Instruction *, ArrayRef, SmallVectorImpl &)>; using DialectExtractElementHook = std::function)>; @@ -57,7 +57,7 @@ public: /// `results` vector. If not, this returns failure and `results` is /// unspecified. DialectConstantFoldHook constantFoldHook = - [](const Instruction *op, ArrayRef operands, + [](Instruction *op, ArrayRef operands, SmallVectorImpl &results) { return failure(); }; /// Registered hook to decode opaque constants associated with this @@ -117,7 +117,7 @@ public: /// Verify an attribute from this dialect on the given instruction. Returns /// true if the verification failed, false otherwise. - virtual bool verifyInstructionAttribute(const Instruction *, NamedAttribute) { + virtual bool verifyInstructionAttribute(Instruction *, NamedAttribute) { return false; } diff --git a/mlir/include/mlir/IR/Instruction.h b/mlir/include/mlir/IR/Instruction.h index ccdb3fba7da7..4b164627a2aa 100644 --- a/mlir/include/mlir/IR/Instruction.h +++ b/mlir/include/mlir/IR/Instruction.h @@ -68,11 +68,11 @@ public: bool resizableOperandList, MLIRContext *context); /// The name of an operation is the key identifier for it. - OperationName getName() const { return name; } + OperationName getName() { return name; } /// If this operation has a registered operation description, return it. /// Otherwise return null. - const AbstractOperation *getAbstractOperation() const { + const AbstractOperation *getAbstractOperation() { return getName().getAbstractOperation(); } @@ -84,29 +84,29 @@ public: /// them alone if no entry is present). Replaces references to cloned /// sub-instructions to the corresponding instruction that is copied, and adds /// those mappings to the map. - Instruction *clone(BlockAndValueMapping &mapper, MLIRContext *context) const; - Instruction *clone(MLIRContext *context) const; + Instruction *clone(BlockAndValueMapping &mapper, MLIRContext *context); + Instruction *clone(MLIRContext *context); /// Returns the instruction block that contains this instruction. - Block *getBlock() const { return block; } + Block *getBlock() { return block; } /// Return the context this operation is associated with. - MLIRContext *getContext() const; + MLIRContext *getContext(); /// The source location the operation was defined or derived from. - Location getLoc() const { return location; } + Location getLoc() { return location; } /// Set the source location the operation was defined or derived from. void setLoc(Location loc) { location = loc; } /// Returns the closest surrounding instruction that contains this instruction /// or nullptr if this is a top-level instruction. - Instruction *getParentInst() const; + Instruction *getParentInst(); /// Returns the function that this instruction is part of. /// The function is determined by traversing the chain of parent instructions. /// Returns nullptr if the instruction is unlinked. - Function *getFunction() const; + Function *getFunction(); /// Destroys this instruction and its subclass data. void destroy(); @@ -130,10 +130,10 @@ public: /// of the parent block. /// Note: This function has an average complexity of O(1), but worst case may /// take O(N) where N is the number of instructions within the parent block. - bool isBeforeInBlock(const Instruction *other) const; + bool isBeforeInBlock(Instruction *other); - void print(raw_ostream &os) const; - void dump() const; + void print(raw_ostream &os); + void dump(); //===--------------------------------------------------------------------===// // Operands @@ -141,9 +141,7 @@ public: /// Returns if the operation has a resizable operation list, i.e. operands can /// be added. - bool hasResizableOperandsList() const { - return getOperandStorage().isResizable(); - } + bool hasResizableOperandsList() { return getOperandStorage().isResizable(); } /// Replace the current operands of this operation with the ones provided in /// 'operands'. If the operands list is not resizable, the size of 'operands' @@ -152,12 +150,9 @@ public: getOperandStorage().setOperands(this, operands); } - unsigned getNumOperands() const { return getOperandStorage().size(); } + unsigned getNumOperands() { return getOperandStorage().size(); } Value *getOperand(unsigned idx) { return getInstOperand(idx).get(); } - const Value *getOperand(unsigned idx) const { - return getInstOperand(idx).get(); - } void setOperand(unsigned idx, Value *value) { return getInstOperand(idx).set(value); } @@ -172,77 +167,40 @@ public: /// Returns an iterator on the underlying Value's (Value *). operand_range getOperands(); - // Support const operand iteration. - using const_operand_iterator = - OperandIterator; - using const_operand_range = llvm::iterator_range; - - const_operand_iterator operand_begin() const; - const_operand_iterator operand_end() const; - - /// Returns a const iterator on the underlying Value's (Value *). - llvm::iterator_range getOperands() const; - - ArrayRef getInstOperands() const { - return getOperandStorage().getInstOperands(); - } MutableArrayRef getInstOperands() { return getOperandStorage().getInstOperands(); } InstOperand &getInstOperand(unsigned idx) { return getInstOperands()[idx]; } - const InstOperand &getInstOperand(unsigned idx) const { - return getInstOperands()[idx]; - } //===--------------------------------------------------------------------===// // Results //===--------------------------------------------------------------------===// /// Return true if there are no users of any results of this operation. - bool use_empty() const; + bool use_empty(); - unsigned getNumResults() const { return numResults; } + unsigned getNumResults() { return numResults; } Value *getResult(unsigned idx) { return &getInstResult(idx); } - const Value *getResult(unsigned idx) const { return &getInstResult(idx); } - // Support non-const result iteration. + // Support result iteration. using result_iterator = ResultIterator; result_iterator result_begin(); result_iterator result_end(); llvm::iterator_range getResults(); - // Support const result iteration. - using const_result_iterator = ResultIterator; - const_result_iterator result_begin() const; - - const_result_iterator result_end() const; - - llvm::iterator_range getResults() const; - - ArrayRef getInstResults() const { - return {getTrailingObjects(), numResults}; - } - MutableArrayRef getInstResults() { return {getTrailingObjects(), numResults}; } InstResult &getInstResult(unsigned idx) { return getInstResults()[idx]; } - const InstResult &getInstResult(unsigned idx) const { - return getInstResults()[idx]; - } - // Support result type iteration. - using result_type_iterator = - ResultTypeIterator; - result_type_iterator result_type_begin() const; - - result_type_iterator result_type_end() const; - - llvm::iterator_range getResultTypes() const; + using result_type_iterator = ResultTypeIterator; + result_type_iterator result_type_begin(); + result_type_iterator result_type_end(); + llvm::iterator_range getResultTypes(); //===--------------------------------------------------------------------===// // Attributes @@ -253,17 +211,17 @@ public: // the lifetime of an instruction. /// Return all of the attributes on this instruction. - ArrayRef getAttrs() const { return attrs.getAttrs(); } + ArrayRef getAttrs() { return attrs.getAttrs(); } /// Return the specified attribute if present, null otherwise. - Attribute getAttr(Identifier name) const { return attrs.get(name); } - Attribute getAttr(StringRef name) const { return attrs.get(name); } + Attribute getAttr(Identifier name) { return attrs.get(name); } + Attribute getAttr(StringRef name) { return attrs.get(name); } - template AttrClass getAttrOfType(Identifier name) const { + template AttrClass getAttrOfType(Identifier name) { return getAttr(name).dyn_cast_or_null(); } - template AttrClass getAttrOfType(StringRef name) const { + template AttrClass getAttrOfType(StringRef name) { return getAttr(name).dyn_cast_or_null(); } @@ -287,16 +245,16 @@ public: //===--------------------------------------------------------------------===// /// Returns the number of regions held by this operation. - unsigned getNumRegions() const { return numRegions; } + unsigned getNumRegions() { return numRegions; } /// Returns the regions held by this operation. - MutableArrayRef getRegions() const { + MutableArrayRef getRegions() { auto *regions = getTrailingObjects(); - return {const_cast(regions), numRegions}; + return {regions, numRegions}; } /// Returns the region held by this operation at position 'index'. - Region &getRegion(unsigned index) const { + Region &getRegion(unsigned index) { assert(index < numRegions && "invalid region index"); return getRegions()[index]; } @@ -308,15 +266,10 @@ public: MutableArrayRef getBlockOperands() { return {getTrailingObjects(), numSuccs}; } - ArrayRef getBlockOperands() const { - return const_cast(this)->getBlockOperands(); - } /// Return the operands of this operation that are *not* successor arguments. - const_operand_range getNonSuccessorOperands() const; operand_range getNonSuccessorOperands(); - const_operand_range getSuccessorOperands(unsigned index) const; operand_range getSuccessorOperands(unsigned index); Value *getSuccessorOperand(unsigned succIndex, unsigned opIndex) { @@ -324,19 +277,15 @@ public: assert(opIndex < getNumSuccessorOperands(succIndex)); return getOperand(getSuccessorOperandIndex(succIndex) + opIndex); } - const Value *getSuccessorOperand(unsigned succIndex, unsigned index) const { - return const_cast(this)->getSuccessorOperand(succIndex, - index); - } - unsigned getNumSuccessors() const { return numSuccs; } - unsigned getNumSuccessorOperands(unsigned index) const { + unsigned getNumSuccessors() { return numSuccs; } + unsigned getNumSuccessorOperands(unsigned index) { assert(!isKnownNonTerminator() && "only terminators may have successors"); assert(index < getNumSuccessors()); return getTrailingObjects()[index]; } - Block *getSuccessor(unsigned index) const { + Block *getSuccessor(unsigned index) { assert(index < getNumSuccessors()); return getBlockOperands()[index].get(); } @@ -354,21 +303,21 @@ public: /// Get the index of the first operand of the successor at the provided /// index. - unsigned getSuccessorOperandIndex(unsigned index) const; + unsigned getSuccessorOperandIndex(unsigned index); //===--------------------------------------------------------------------===// // Accessors for various properties of operations //===--------------------------------------------------------------------===// /// Returns whether the operation is commutative. - bool isCommutative() const { + bool isCommutative() { if (auto *absOp = getAbstractOperation()) return absOp->hasProperty(OperationProperty::Commutative); return false; } /// Returns whether the operation has side-effects. - bool hasNoSideEffect() const { + bool hasNoSideEffect() { if (auto *absOp = getAbstractOperation()) return absOp->hasProperty(OperationProperty::NoSideEffect); return false; @@ -380,7 +329,7 @@ public: enum class TerminatorStatus { Terminator, NonTerminator, Unknown }; /// Returns the status of whether this operation is a terminator or not. - TerminatorStatus getTerminatorStatus() const { + TerminatorStatus getTerminatorStatus() { if (auto *absOp = getAbstractOperation()) { return absOp->hasProperty(OperationProperty::Terminator) ? TerminatorStatus::Terminator @@ -390,12 +339,12 @@ public: } /// Returns if the operation is known to be a terminator. - bool isKnownTerminator() const { + bool isKnownTerminator() { return getTerminatorStatus() == TerminatorStatus::Terminator; } /// Returns if the operation is known to *not* be a terminator. - bool isKnownNonTerminator() const { + bool isKnownNonTerminator() { return getTerminatorStatus() == TerminatorStatus::NonTerminator; } @@ -405,7 +354,7 @@ public: /// constant folding is successful, this fills in the `results` vector. If /// not, `results` is unspecified. LogicalResult constantFold(ArrayRef operands, - SmallVectorImpl &results) const; + SmallVectorImpl &results); /// Attempt to fold this operation using the Op's registered foldHook. LogicalResult fold(SmallVectorImpl &results); @@ -421,7 +370,7 @@ public: /// The dyn_cast methods perform a dynamic cast from an Instruction to a typed /// Op like DimOp. This returns a null OpPointer on failure. - template OpPointer dyn_cast() const { + template OpPointer dyn_cast() { if (isa()) { return cast(); } else { @@ -432,16 +381,14 @@ public: /// The cast methods perform a cast from an Instruction to a typed Op like /// DimOp. This aborts if the parameter to the template isn't an instance of /// the template type argument. - template OpPointer cast() const { + template OpPointer cast() { assert(isa() && "cast() argument of incompatible type!"); return OpPointer(OpClass(this)); } /// The is methods return true if the operation is a typed op (like DimOp) of /// of the given class. - template bool isa() const { - return OpClass::isClassFor(const_cast(this)); - } + template bool isa() { return OpClass::isClassFor(this); } //===--------------------------------------------------------------------===// // Instruction Walkers @@ -479,21 +426,21 @@ public: /// Emit an error with the op name prefixed, like "'dim' op " which is /// convenient for verifiers. This function always returns true. - bool emitOpError(const Twine &message) const; + bool emitOpError(const Twine &message); /// Emit an error about fatal conditions with this operation, reporting up to /// any diagnostic handlers that may be listening. This function always /// returns true. NOTE: This may terminate the containing application, only /// use when the IR is in an inconsistent state. - bool emitError(const Twine &message) const; + bool emitError(const Twine &message); /// Emit a warning about this operation, reporting up to any diagnostic /// handlers that may be listening. - void emitWarning(const Twine &message) const; + void emitWarning(const Twine &message); /// Emit a note about this operation, reporting up to any diagnostic /// handlers that may be listening. - void emitNote(const Twine &message) const; + void emitNote(const Twine &message); private: Instruction(Location location, OperationName name, unsigned numResults, @@ -508,12 +455,9 @@ private: detail::OperandStorage &getOperandStorage() { return *getTrailingObjects(); } - const detail::OperandStorage &getOperandStorage() const { - return *getTrailingObjects(); - } // Provide a 'getParent' method for ilist_node_with_parent methods. - Block *getParent() const { return getBlock(); } + Block *getParent() { return getBlock(); } /// The instruction block that containts this instruction. Block *block = nullptr; @@ -556,7 +500,7 @@ private: size_t numTrailingObjects(OverloadToken) const { return numSuccs; } }; -inline raw_ostream &operator<<(raw_ostream &os, const Instruction &inst) { +inline raw_ostream &operator<<(raw_ostream &os, Instruction &inst) { inst.print(os); return os; } @@ -573,13 +517,6 @@ public: : IndexedAccessorIterator, ObjectType, ElementType>(object, index) {} - /// Support converting to the const variant. This will be a no-op for const - /// variant. - operator OperandIterator() const { - return OperandIterator(this->object, - this->index); - } - ElementType *operator*() const { return this->object->getOperand(this->index); } @@ -598,18 +535,6 @@ inline auto Instruction::getOperands() -> operand_range { return {operand_begin(), operand_end()}; } -inline auto Instruction::operand_begin() const -> const_operand_iterator { - return const_operand_iterator(this, 0); -} - -inline auto Instruction::operand_end() const -> const_operand_iterator { - return const_operand_iterator(this, getNumOperands()); -} - -inline auto Instruction::getOperands() const -> const_operand_range { - return {operand_begin(), operand_end()}; -} - /// This template implements the result iterators for the Instruction class /// in terms of getResult(idx). template @@ -622,13 +547,6 @@ public: : IndexedAccessorIterator, ObjectType, ElementType>(object, index) {} - /// Support converting to the const variant. This will be a no-op for const - /// variant. - operator ResultIterator() const { - return ResultIterator(this->object, - this->index); - } - ElementType *operator*() const { return this->object->getResult(this->index); } @@ -672,28 +590,15 @@ inline auto Instruction::getResults() -> llvm::iterator_range { return {result_begin(), result_end()}; } -inline auto Instruction::result_begin() const -> const_result_iterator { - return const_result_iterator(this, 0); -} - -inline auto Instruction::result_end() const -> const_result_iterator { - return const_result_iterator(this, getNumResults()); -} - -inline auto Instruction::getResults() const - -> llvm::iterator_range { - return {result_begin(), result_end()}; -} - -inline auto Instruction::result_type_begin() const -> result_type_iterator { +inline auto Instruction::result_type_begin() -> result_type_iterator { return result_type_iterator(this, 0); } -inline auto Instruction::result_type_end() const -> result_type_iterator { +inline auto Instruction::result_type_end() -> result_type_iterator { return result_type_iterator(this, getNumResults()); } -inline auto Instruction::getResultTypes() const +inline auto Instruction::getResultTypes() -> llvm::iterator_range { return {result_type_begin(), result_type_end()}; } diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index dc6ddd3df88c..d21e40818d59 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -103,7 +103,7 @@ private: class OpState { public: /// Return the operation that this refers to. - const Instruction *getInstruction() const { return state; } + Instruction *getInstruction() const { return state; } Instruction *getInstruction() { return state; } /// The source location the operation was defined or derived from. @@ -176,7 +176,7 @@ protected: /// Mutability management is handled by the OpWrapper/OpConstWrapper classes, /// so we can cast it away here. - explicit OpState(const Instruction *state) + explicit OpState(Instruction *state) : state(const_cast(state)) {} private: @@ -327,22 +327,22 @@ namespace OpTrait { // corresponding trait classes. This avoids them being template // instantiated/duplicated. namespace impl { -bool verifyZeroOperands(const Instruction *op); -bool verifyOneOperand(const Instruction *op); -bool verifyNOperands(const Instruction *op, unsigned numOperands); -bool verifyAtLeastNOperands(const Instruction *op, unsigned numOperands); -bool verifyOperandsAreIntegerLike(const Instruction *op); -bool verifySameTypeOperands(const Instruction *op); -bool verifyZeroResult(const Instruction *op); -bool verifyOneResult(const Instruction *op); -bool verifyNResults(const Instruction *op, unsigned numOperands); -bool verifyAtLeastNResults(const Instruction *op, unsigned numOperands); -bool verifySameOperandsAndResultShape(const Instruction *op); -bool verifySameOperandsAndResultType(const Instruction *op); -bool verifyResultsAreBoolLike(const Instruction *op); -bool verifyResultsAreFloatLike(const Instruction *op); -bool verifyResultsAreIntegerLike(const Instruction *op); -bool verifyIsTerminator(const Instruction *op); +bool verifyZeroOperands(Instruction *op); +bool verifyOneOperand(Instruction *op); +bool verifyNOperands(Instruction *op, unsigned numOperands); +bool verifyAtLeastNOperands(Instruction *op, unsigned numOperands); +bool verifyOperandsAreIntegerLike(Instruction *op); +bool verifySameTypeOperands(Instruction *op); +bool verifyZeroResult(Instruction *op); +bool verifyOneResult(Instruction *op); +bool verifyNResults(Instruction *op, unsigned numOperands); +bool verifyAtLeastNResults(Instruction *op, unsigned numOperands); +bool verifySameOperandsAndResultShape(Instruction *op); +bool verifySameOperandsAndResultType(Instruction *op); +bool verifyResultsAreBoolLike(Instruction *op); +bool verifyResultsAreFloatLike(Instruction *op); +bool verifyResultsAreIntegerLike(Instruction *op); +bool verifyIsTerminator(Instruction *op); } // namespace impl /// Helper class for implementing traits. Clients are not expected to interact @@ -361,13 +361,13 @@ protected: auto *base = static_cast(concrete); return base->getInstruction(); } - const Instruction *getInstruction() const { + Instruction *getInstruction() const { return const_cast(this)->getInstruction(); } /// Provide default implementations of trait hooks. This allows traits to /// provide exactly the overrides they care about. - static bool verifyTrait(const Instruction *op) { return false; } + static bool verifyTrait(Instruction *op) { return false; } static AbstractOperation::OperationProperties getTraitProperties() { return 0; } @@ -378,7 +378,7 @@ protected: template class ZeroOperands : public TraitBase { public: - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifyZeroOperands(op); } @@ -393,17 +393,13 @@ private: template class OneOperand : public TraitBase { public: - const Value *getOperand() const { - return this->getInstruction()->getOperand(0); - } - - Value *getOperand() { return this->getInstruction()->getOperand(0); } + Value *getOperand() const { return this->getInstruction()->getOperand(0); } void setOperand(Value *value) { this->getInstruction()->setOperand(0, value); } - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifyOneOperand(op); } }; @@ -418,11 +414,7 @@ public: template class Impl : public TraitBase::Impl> { public: - const Value *getOperand(unsigned i) const { - return this->getInstruction()->getOperand(i); - } - - Value *getOperand(unsigned i) { + Value *getOperand(unsigned i) const { return this->getInstruction()->getOperand(i); } @@ -430,7 +422,7 @@ public: this->getInstruction()->setOperand(i, value); } - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifyNOperands(op, N); } }; @@ -449,11 +441,7 @@ public: unsigned getNumOperands() const { return this->getInstruction()->getNumOperands(); } - const Value *getOperand(unsigned i) const { - return this->getInstruction()->getOperand(i); - } - - Value *getOperand(unsigned i) { + Value *getOperand(unsigned i) const { return this->getInstruction()->getOperand(i); } @@ -473,19 +461,7 @@ public: return this->getInstruction()->getOperands(); } - // Support const operand iteration. - using const_operand_iterator = Instruction::const_operand_iterator; - const_operand_iterator operand_begin() const { - return this->getInstruction()->operand_begin(); - } - const_operand_iterator operand_end() const { - return this->getInstruction()->operand_end(); - } - llvm::iterator_range getOperands() const { - return this->getInstruction()->getOperands(); - } - - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifyAtLeastNOperands(op, N); } }; @@ -500,11 +476,7 @@ public: return this->getInstruction()->getNumOperands(); } - const Value *getOperand(unsigned i) const { - return this->getInstruction()->getOperand(i); - } - - Value *getOperand(unsigned i) { + Value *getOperand(unsigned i) const { return this->getInstruction()->getOperand(i); } @@ -522,19 +494,6 @@ public: return this->getInstruction()->operand_end(); } operand_range getOperands() { return this->getInstruction()->getOperands(); } - - // Support const operand iteration. - using const_operand_iterator = Instruction::const_operand_iterator; - using const_operand_range = Instruction::const_operand_range; - const_operand_iterator operand_begin() const { - return this->getInstruction()->operand_begin(); - } - const_operand_iterator operand_end() const { - return this->getInstruction()->operand_end(); - } - const_operand_range getOperands() const { - return this->getInstruction()->getOperands(); - } }; /// This class provides return value APIs for ops that are known to have @@ -542,7 +501,7 @@ public: template class ZeroResult : public TraitBase { public: - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifyZeroResult(op); } }; @@ -552,10 +511,7 @@ public: template class OneResult : public TraitBase { public: - Value *getResult() { return this->getInstruction()->getResult(0); } - const Value *getResult() const { - return this->getInstruction()->getResult(0); - } + Value *getResult() const { return this->getInstruction()->getResult(0); } Type getType() const { return getResult()->getType(); } @@ -566,9 +522,7 @@ public: getResult()->replaceAllUsesWith(newValue); } - static bool verifyTrait(const Instruction *op) { - return impl::verifyOneResult(op); - } + static bool verifyTrait(Instruction *op) { return impl::verifyOneResult(op); } }; /// This class provides the API for ops that are known to have a specified @@ -583,17 +537,13 @@ public: public: static unsigned getNumResults() { return N; } - const Value *getResult(unsigned i) const { - return this->getInstruction()->getResult(i); - } - - Value *getResult(unsigned i) { + Value *getResult(unsigned i) const { return this->getInstruction()->getResult(i); } Type getType(unsigned i) const { return getResult(i)->getType(); } - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifyNResults(op, N); } }; @@ -609,17 +559,13 @@ public: template class Impl : public TraitBase::Impl> { public: - const Value *getResult(unsigned i) const { - return this->getInstruction()->getResult(i); - } - - Value *getResult(unsigned i) { + Value *getResult(unsigned i) const { return this->getInstruction()->getResult(i); } Type getType(unsigned i) const { return getResult(i)->getType(); } - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifyAtLeastNResults(op, N); } }; @@ -634,12 +580,10 @@ public: return this->getInstruction()->getNumResults(); } - const Value *getResult(unsigned i) const { + Value *getResult(unsigned i) const { return this->getInstruction()->getResult(i); } - Value *getResult(unsigned i) { return this->getInstruction()->getResult(i); } - void setResult(unsigned i, Value *value) { this->getInstruction()->setResult(i, value); } @@ -653,18 +597,6 @@ public: llvm::iterator_range getResults() { return this->getInstruction()->getResults(); } - - // Support const result iteration. - using const_result_iterator = Instruction::const_result_iterator; - const_result_iterator result_begin() const { - return this->getInstruction()->result_begin(); - } - const_result_iterator result_end() const { - return this->getInstruction()->result_end(); - } - llvm::iterator_range getResults() const { - return this->getInstruction()->getResults(); - } }; /// This class provides verification for ops that are known to have the same @@ -674,7 +606,7 @@ template class SameOperandsAndResultShape : public TraitBase { public: - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifySameOperandsAndResultShape(op); } }; @@ -689,7 +621,7 @@ template class SameOperandsAndResultType : public TraitBase { public: - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifySameOperandsAndResultType(op); } }; @@ -699,7 +631,7 @@ public: template class ResultsAreBoolLike : public TraitBase { public: - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifyResultsAreBoolLike(op); } }; @@ -710,7 +642,7 @@ template class ResultsAreFloatLike : public TraitBase { public: - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifyResultsAreFloatLike(op); } }; @@ -721,7 +653,7 @@ template class ResultsAreIntegerLike : public TraitBase { public: - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifyResultsAreIntegerLike(op); } }; @@ -752,7 +684,7 @@ template class OperandsAreIntegerLike : public TraitBase { public: - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifyOperandsAreIntegerLike(op); } }; @@ -762,7 +694,7 @@ public: template class SameTypeOperands : public TraitBase { public: - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifySameTypeOperands(op); } }; @@ -775,7 +707,7 @@ public: return static_cast( OperationProperty::Terminator); } - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifyIsTerminator(op); } @@ -820,10 +752,7 @@ class Op : public OpState, Traits...>::value> { public: /// Return the operation that this refers to. - const Instruction *getInstruction() const { - return OpState::getInstruction(); - } - Instruction *getInstruction() { return OpState::getInstruction(); } + Instruction *getInstruction() const { return OpState::getInstruction(); } /// Return true if this "op class" can match against the specified operation. /// This hook can be overridden with a more specific implementation in @@ -875,20 +804,20 @@ public: using ConcreteOpType = ConcreteType; protected: - explicit Op(const Instruction *state) : OpState(state) {} + explicit Op(Instruction *state) : OpState(state) {} private: template struct BaseVerifier; template struct BaseVerifier { - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return First::verifyTrait(op) || BaseVerifier::verifyTrait(op); } }; template struct BaseVerifier { - static bool verifyTrait(const Instruction *op) { return false; } + static bool verifyTrait(Instruction *op) { return false; } }; template struct BaseProperties; @@ -917,7 +846,7 @@ bool parseBinaryOp(OpAsmParser *parser, OperationState *result); // Prints the given binary `op` in custom assembly form if both the two operands // and the result have the same time. Otherwise, prints the generic assembly // form. -void printBinaryOp(const Instruction *op, OpAsmPrinter *p); +void printBinaryOp(Instruction *op, OpAsmPrinter *p); } // namespace impl // These functions are out-of-line implementations of the methods in CastOp, @@ -926,7 +855,7 @@ namespace impl { void buildCastOp(Builder *builder, OperationState *result, Value *source, Type destType); bool parseCastOp(OpAsmParser *parser, OperationState *result); -void printCastOp(const Instruction *op, OpAsmPrinter *p); +void printCastOp(Instruction *op, OpAsmPrinter *p); } // namespace impl /// This template is used for operations that are cast operations, that have a @@ -951,7 +880,7 @@ public: } protected: - explicit CastOp(const Instruction *state) + explicit CastOp(Instruction *state) : Op(state) {} }; diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index f8465afe0ea7..ae63f485d321 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -48,7 +48,7 @@ public: virtual raw_ostream &getStream() const = 0; /// Print implementations for various things an operation contains. - virtual void printOperand(const Value *value) = 0; + virtual void printOperand(Value *value) = 0; /// Print a comma separated list of operands. template @@ -76,8 +76,7 @@ public: /// Print a successor, and use list, of a terminator operation given the /// terminator and the successor index. - virtual void printSuccessorAndUseList(const Instruction *term, - unsigned index) = 0; + virtual void printSuccessorAndUseList(Instruction *term, unsigned index) = 0; /// If the specified operation has attributes, print out an attribute /// dictionary with their values. elidedAttrs allows the client to ignore @@ -87,7 +86,7 @@ public: ArrayRef elidedAttrs = {}) = 0; /// Print the entire operation with the default generic assembly form. - virtual void printGenericOp(const Instruction *op) = 0; + virtual void printGenericOp(Instruction *op) = 0; /// Prints a region. virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true) = 0; @@ -98,7 +97,7 @@ private: }; // Make the implementations convenient to use. -inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const Value &value) { +inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value &value) { p.printOperand(&value); return p; } diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index 8be6b34bacc2..6796131c0525 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -86,8 +86,8 @@ public: /// This hook implements the AsmPrinter for this operation. void (&printAssembly)(Instruction *op, OpAsmPrinter *p); - /// This hook implements the verifier for this operation. It should emit an - /// error message and returns true if a problem is detected, or return false + /// This hook implements the verifier for this operation. It should emits an + /// error message and returns true if a problem is detected, or returns false /// if everything is ok. bool (&verifyInvariants)(Instruction *op); diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h index 3d4493b0d6b9..0529a5a2ebd1 100644 --- a/mlir/include/mlir/IR/UseDefLists.h +++ b/mlir/include/mlir/IR/UseDefLists.h @@ -97,7 +97,7 @@ public: /// Return the owner of this operand. Instruction *getOwner() { return owner; } - const Instruction *getOwner() const { return owner; } + Instruction *getOwner() const { return owner; } /// \brief Remove this use of the operand. void drop() { @@ -176,13 +176,13 @@ public: : IROperand(owner, value) {} /// Return the current value being used by this operand. - IRValueTy *get() const { return (IRValueTy *)IROperand::get(); } + IRValueTy *get() { return (IRValueTy *)IROperand::get(); } /// Set the current value being used by this operand. void set(IRValueTy *newValue) { IROperand::set(newValue); } /// Return which operand this is in the operand list of the User. - unsigned getOperandNumber() const; + unsigned getOperandNumber(); }; /// An iterator over all uses of a ValueBase. diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h index fde405305961..dada49bc58f2 100644 --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -48,9 +48,9 @@ public: ~Value() {} - Kind getKind() const { return typeAndKind.getInt(); } + Kind getKind() { return typeAndKind.getInt(); } - Type getType() const { return typeAndKind.getPointer(); } + Type getType() { return typeAndKind.getPointer(); } /// Replace all uses of 'this' value with the new value, updating anything in /// the IR that uses 'this' to use the other value instead. When this returns @@ -60,26 +60,23 @@ public: } /// Return the function that this Value is defined in. - Function *getFunction() const; + Function *getFunction(); /// If this value is the result of an operation, return the instruction /// that defines it. Instruction *getDefiningInst(); - const Instruction *getDefiningInst() const { - return const_cast(this)->getDefiningInst(); - } using use_iterator = ValueUseIterator; using use_range = llvm::iterator_range; - inline use_iterator use_begin() const; - inline use_iterator use_end() const; + inline use_iterator use_begin(); + inline use_iterator use_end(); /// Returns a range of all uses, which is useful for iterating over all uses. - inline use_range getUses() const; + inline use_range getUses(); - void print(raw_ostream &os) const; - void dump() const; + void print(raw_ostream &os); + void dump(); protected: Value(Kind kind, Type type) : typeAndKind(type, kind) {} @@ -88,21 +85,19 @@ private: const llvm::PointerIntPair typeAndKind; }; -inline raw_ostream &operator<<(raw_ostream &os, const Value &value) { +inline raw_ostream &operator<<(raw_ostream &os, Value &value) { value.print(os); return os; } // Utility functions for iterating through Value uses. -inline auto Value::use_begin() const -> use_iterator { +inline auto Value::use_begin() -> use_iterator { return use_iterator((InstOperand *)getFirstUse()); } -inline auto Value::use_end() const -> use_iterator { - return use_iterator(nullptr); -} +inline auto Value::use_end() -> use_iterator { return use_iterator(nullptr); } -inline auto Value::getUses() const -> llvm::iterator_range { +inline auto Value::getUses() -> llvm::iterator_range { return {use_begin(), use_end()}; } @@ -110,19 +105,19 @@ inline auto Value::getUses() const -> llvm::iterator_range { class BlockArgument : public Value { public: static bool classof(const Value *value) { - return value->getKind() == Kind::BlockArgument; + return const_cast(value)->getKind() == Kind::BlockArgument; } /// Return the function that this argument is defined in. - Function *getFunction() const; + Function *getFunction(); - Block *getOwner() const { return owner; } + Block *getOwner() { return owner; } /// Returns the number of this argument. - unsigned getArgNumber() const; + unsigned getArgNumber(); /// Returns if the current argument is a function argument. - bool isFunctionArgument() const; + bool isFunctionArgument(); private: friend class Block; // For access to private constructor. @@ -142,14 +137,13 @@ public: : Value(Value::Kind::InstResult, type), owner(owner) {} static bool classof(const Value *value) { - return value->getKind() == Kind::InstResult; + return const_cast(value)->getKind() == Kind::InstResult; } Instruction *getOwner() { return owner; } - const Instruction *getOwner() const { return owner; } /// Returns the number of this result. - unsigned getResultNumber() const; + unsigned getResultNumber(); private: /// The owner of this operand. diff --git a/mlir/include/mlir/StandardOps/Ops.h b/mlir/include/mlir/StandardOps/Ops.h index 827335bf5a97..1cd9392b984e 100644 --- a/mlir/include/mlir/StandardOps/Ops.h +++ b/mlir/include/mlir/StandardOps/Ops.h @@ -35,7 +35,7 @@ class Builder; namespace detail { /// A custom binary operation printer that omits the "std." prefix from the /// operation names. -void printStandardBinaryOp(const Instruction *op, OpAsmPrinter *p); +void printStandardBinaryOp(Instruction *op, OpAsmPrinter *p); } // namespace detail class StandardOpsDialect : public Dialect { @@ -69,9 +69,7 @@ class AllocOp : public Op { public: /// The result of an alloc is always a MemRefType. - MemRefType getType() const { - return getResult()->getType().cast(); - } + MemRefType getType() { return getResult()->getType().cast(); } static StringRef getOperationName() { return "std.alloc"; } @@ -86,7 +84,7 @@ public: private: friend class Instruction; - explicit AllocOp(const Instruction *state) : Op(state) {} + explicit AllocOp(Instruction *state) : Op(state) {} }; /// The "br" operation represents a branch instruction in a function. @@ -113,7 +111,6 @@ public: /// Return the block this branch jumps to. Block *getDest(); - Block *getDest() const { return const_cast(this)->getDest(); } void setDest(Block *block); /// Erase the operand at 'index' from the operand list. @@ -121,7 +118,7 @@ public: private: friend class Instruction; - explicit BranchOp(const Instruction *state) : Op(state) {} + explicit BranchOp(Instruction *state) : Op(state) {} }; /// The "call" operation represents a direct call to a function. The operands @@ -138,21 +135,15 @@ public: static void build(Builder *builder, OperationState *result, Function *callee, ArrayRef operands); - Function *getCallee() const { + Function *getCallee() { return getAttrOfType("callee").getValue(); } /// Get the argument operands to the called function. - llvm::iterator_range getArgOperands() const { - return {arg_operand_begin(), arg_operand_end()}; - } llvm::iterator_range getArgOperands() { return {arg_operand_begin(), arg_operand_end()}; } - const_operand_iterator arg_operand_begin() const { return operand_begin(); } - const_operand_iterator arg_operand_end() const { return operand_end(); } - operand_iterator arg_operand_begin() { return operand_begin(); } operand_iterator arg_operand_end() { return operand_end(); } @@ -163,7 +154,7 @@ public: protected: friend class Instruction; - explicit CallOp(const Instruction *state) : Op(state) {} + explicit CallOp(Instruction *state) : Op(state) {} }; /// The "call_indirect" operation represents an indirect call to a value of @@ -182,20 +173,13 @@ public: static void build(Builder *builder, OperationState *result, Value *callee, ArrayRef operands); - const Value *getCallee() const { return getOperand(0); } Value *getCallee() { return getOperand(0); } /// Get the argument operands to the called function. - llvm::iterator_range getArgOperands() const { - return {arg_operand_begin(), arg_operand_end()}; - } llvm::iterator_range getArgOperands() { return {arg_operand_begin(), arg_operand_end()}; } - const_operand_iterator arg_operand_begin() const { return ++operand_begin(); } - const_operand_iterator arg_operand_end() const { return operand_end(); } - operand_iterator arg_operand_begin() { return ++operand_begin(); } operand_iterator arg_operand_end() { return operand_end(); } @@ -208,7 +192,7 @@ public: protected: friend class Instruction; - explicit CallIndirectOp(const Instruction *state) : Op(state) {} + explicit CallIndirectOp(Instruction *state) : Op(state) {} }; /// The predicate indicates the type of the comparison to perform: @@ -274,7 +258,7 @@ public: private: friend class Instruction; - explicit CmpIOp(const Instruction *state) : Op(state) {} + explicit CmpIOp(Instruction *state) : Op(state) {} }; /// The "cond_br" operation represents a conditional branch instruction in a @@ -314,29 +298,20 @@ public: MLIRContext *context); // The condition operand is the first operand in the list. - Value *getCondition() { return getOperand(0); } - const Value *getCondition() const { return getOperand(0); } + Value *getCondition() const { return getOperand(0); } /// Return the destination if the condition is true. Block *getTrueDest(); - Block *getTrueDest() const { - return const_cast(this)->getTrueDest(); - } /// Return the destination if the condition is false. Block *getFalseDest(); - Block *getFalseDest() const { - return const_cast(this)->getFalseDest(); - } // Accessors for operands to the 'true' destination. Value *getTrueOperand(unsigned idx) { assert(idx < getNumTrueOperands()); return getOperand(getTrueDestOperandIndex() + idx); } - const Value *getTrueOperand(unsigned idx) const { - return const_cast(this)->getTrueOperand(idx); - } + void setTrueOperand(unsigned idx, Value *value) { assert(idx < getNumTrueOperands()); setOperand(getTrueDestOperandIndex() + idx, value); @@ -352,16 +327,6 @@ public: return {true_operand_begin(), true_operand_end()}; } - const_operand_iterator true_operand_begin() const { - return operand_begin() + getTrueDestOperandIndex(); - } - const_operand_iterator true_operand_end() const { - return true_operand_begin() + getNumTrueOperands(); - } - llvm::iterator_range getTrueOperands() const { - return {true_operand_begin(), true_operand_end()}; - } - unsigned getNumTrueOperands() const; /// Erase the operand at 'index' from the true operand list. @@ -372,7 +337,7 @@ public: assert(idx < getNumFalseOperands()); return getOperand(getFalseDestOperandIndex() + idx); } - const Value *getFalseOperand(unsigned idx) const { + Value *getFalseOperand(unsigned idx) const { return const_cast(this)->getFalseOperand(idx); } void setFalseOperand(unsigned idx, Value *value) { @@ -388,16 +353,6 @@ public: return {false_operand_begin(), false_operand_end()}; } - const_operand_iterator false_operand_begin() const { - return true_operand_end(); - } - const_operand_iterator false_operand_end() const { - return false_operand_begin() + getNumFalseOperands(); - } - llvm::iterator_range getFalseOperands() const { - return {false_operand_begin(), false_operand_end()}; - } - unsigned getNumFalseOperands() const; /// Erase the operand at 'index' from the false operand list. @@ -413,7 +368,7 @@ private: } friend class Instruction; - explicit CondBranchOp(const Instruction *state) : Op(state) {} + explicit CondBranchOp(Instruction *state) : Op(state) {} }; /// The "constant" operation requires a single attribute named "value". @@ -445,7 +400,7 @@ public: protected: friend class Instruction; - explicit ConstantOp(const Instruction *state) : Op(state) {} + explicit ConstantOp(Instruction *state) : Op(state) {} }; /// This is a refinement of the "constant" op for the case where it is @@ -467,7 +422,7 @@ public: private: friend class Instruction; - explicit ConstantFloatOp(const Instruction *state) : ConstantOp(state) {} + explicit ConstantFloatOp(Instruction *state) : ConstantOp(state) {} }; /// This is a refinement of the "constant" op for the case where it is @@ -494,7 +449,7 @@ public: private: friend class Instruction; - explicit ConstantIntOp(const Instruction *state) : ConstantOp(state) {} + explicit ConstantIntOp(Instruction *state) : ConstantOp(state) {} }; /// This is a refinement of the "constant" op for the case where it is @@ -515,7 +470,7 @@ public: private: friend class Instruction; - explicit ConstantIndexOp(const Instruction *state) : ConstantOp(state) {} + explicit ConstantIndexOp(Instruction *state) : ConstantOp(state) {} }; /// The "dealloc" operation frees the region of memory referenced by a memref @@ -531,8 +486,7 @@ private: class DeallocOp : public Op { public: - Value *getMemRef() { return getOperand(); } - const Value *getMemRef() const { return getOperand(); } + Value *getMemRef() const { return getOperand(); } void setMemRef(Value *value) { setOperand(value); } static StringRef getOperationName() { return "std.dealloc"; } @@ -547,7 +501,7 @@ public: private: friend class Instruction; - explicit DeallocOp(const Instruction *state) : Op(state) {} + explicit DeallocOp(Instruction *state) : Op(state) {} }; /// The "dim" operation takes a memref or tensor operand and returns an @@ -578,7 +532,7 @@ public: private: friend class Instruction; - explicit DimOp(const Instruction *state) : Op(state) {} + explicit DimOp(Instruction *state) : Op(state) {} }; // DmaStartOp starts a non-blocking DMA operation that transfers data from a @@ -629,22 +583,19 @@ public: Value *elementsPerStride = nullptr); // Returns the source MemRefType for this DMA operation. - const Value *getSrcMemRef() const { return getOperand(0); } + Value *getSrcMemRef() const { return getOperand(0); } // Returns the rank (number of indices) of the source MemRefType. unsigned getSrcMemRefRank() const { return getSrcMemRef()->getType().cast().getRank(); } // Returns the source memerf indices for this DMA operation. - llvm::iterator_range - getSrcIndices() const { + llvm::iterator_range getSrcIndices() { return {getInstruction()->operand_begin() + 1, getInstruction()->operand_begin() + 1 + getSrcMemRefRank()}; } // Returns the destination MemRefType for this DMA operations. - const Value *getDstMemRef() const { - return getOperand(1 + getSrcMemRefRank()); - } + Value *getDstMemRef() const { return getOperand(1 + getSrcMemRefRank()); } // Returns the rank (number of indices) of the destination MemRefType. unsigned getDstMemRefRank() const { return getDstMemRef()->getType().cast().getRank(); @@ -657,20 +608,19 @@ public: } // Returns the destination memref indices for this DMA operation. - llvm::iterator_range - getDstIndices() const { + llvm::iterator_range getDstIndices() { return {getInstruction()->operand_begin() + 1 + getSrcMemRefRank() + 1, getInstruction()->operand_begin() + 1 + getSrcMemRefRank() + 1 + getDstMemRefRank()}; } // Returns the number of elements being transferred by this DMA operation. - const Value *getNumElements() const { + Value *getNumElements() const { return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank()); } // Returns the Tag MemRef for this DMA operation. - const Value *getTagMemRef() const { + Value *getTagMemRef() const { return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1); } // Returns the rank (number of indices) of the tag MemRefType. @@ -679,8 +629,7 @@ public: } // Returns the tag memref index for this DMA operation. - llvm::iterator_range - getTagIndices() const { + llvm::iterator_range getTagIndices() const { unsigned tagIndexStartPos = 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1; return {getInstruction()->operand_begin() + tagIndexStartPos, @@ -725,22 +674,16 @@ public: return nullptr; return getOperand(getNumOperands() - 1 - 1); } - const Value *getStride() const { - return const_cast(this)->getStride(); - } Value *getNumElementsPerStride() { if (!isStrided()) return nullptr; return getOperand(getNumOperands() - 1); } - const Value *getNumElementsPerStride() const { - return const_cast(this)->getNumElementsPerStride(); - } protected: friend class Instruction; - explicit DmaStartOp(const Instruction *state) : Op(state) {} + explicit DmaStartOp(Instruction *state) : Op(state) {} }; // DmaWaitOp blocks until the completion of a DMA operation associated with the @@ -765,12 +708,10 @@ public: static StringRef getOperationName() { return "std.dma_wait"; } // Returns the Tag MemRef associated with the DMA operation being waited on. - const Value *getTagMemRef() const { return getOperand(0); } - Value *getTagMemRef() { return getOperand(0); } + Value *getTagMemRef() const { return getOperand(0); } // Returns the tag memref index for this DMA operation. - llvm::iterator_range - getTagIndices() const { + llvm::iterator_range getTagIndices() const { return {getInstruction()->operand_begin() + 1, getInstruction()->operand_begin() + 1 + getTagMemRefRank()}; } @@ -781,9 +722,7 @@ public: } // Returns the number of elements transferred in the associated DMA operation. - const Value *getNumElements() const { - return getOperand(1 + getTagMemRefRank()); - } + Value *getNumElements() const { return getOperand(1 + getTagMemRefRank()); } static bool parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); @@ -792,7 +731,7 @@ public: protected: friend class Instruction; - explicit DmaWaitOp(const Instruction *state) : Op(state) {} + explicit DmaWaitOp(Instruction *state) : Op(state) {} }; /// The "extract_element" op reads a tensor or vector and returns one element @@ -813,19 +752,13 @@ public: static void build(Builder *builder, OperationState *result, Value *aggregate, ArrayRef indices = {}); - Value *getAggregate() { return getOperand(0); } - const Value *getAggregate() const { return getOperand(0); } + Value *getAggregate() const { return getOperand(0); } llvm::iterator_range getIndices() { return {getInstruction()->operand_begin() + 1, getInstruction()->operand_end()}; } - llvm::iterator_range getIndices() const { - return {getInstruction()->operand_begin() + 1, - getInstruction()->operand_end()}; - } - static StringRef getOperationName() { return "std.extract_element"; } // Hooks to customize behavior of this op. @@ -836,7 +769,7 @@ public: private: friend class Instruction; - explicit ExtractElementOp(const Instruction *state) : Op(state) {} + explicit ExtractElementOp(Instruction *state) : Op(state) {} }; /// The "load" op reads an element from a memref specified by an index list. The @@ -854,8 +787,7 @@ public: static void build(Builder *builder, OperationState *result, Value *memref, ArrayRef indices = {}); - Value *getMemRef() { return getOperand(0); } - const Value *getMemRef() const { return getOperand(0); } + Value *getMemRef() const { return getOperand(0); } void setMemRef(Value *value) { setOperand(0, value); } MemRefType getMemRefType() const { return getMemRef()->getType().cast(); @@ -866,11 +798,6 @@ public: getInstruction()->operand_end()}; } - llvm::iterator_range getIndices() const { - return {getInstruction()->operand_begin() + 1, - getInstruction()->operand_end()}; - } - static StringRef getOperationName() { return "std.load"; } bool verify(); @@ -881,7 +808,7 @@ public: private: friend class Instruction; - explicit LoadOp(const Instruction *state) : Op(state) {} + explicit LoadOp(Instruction *state) : Op(state) {} }; /// The "memref_cast" operation converts a memref from one type to an equivalent @@ -914,7 +841,7 @@ public: private: friend class Instruction; - explicit MemRefCastOp(const Instruction *state) : CastOp(state) {} + explicit MemRefCastOp(Instruction *state) : CastOp(state) {} }; /// The "return" operation represents a return instruction within a function. @@ -941,7 +868,7 @@ public: private: friend class Instruction; - explicit ReturnOp(const Instruction *state) : Op(state) {} + explicit ReturnOp(Instruction *state) : Op(state) {} }; /// The "select" operation chooses one value based on a binary condition @@ -965,18 +892,15 @@ public: void print(OpAsmPrinter *p); bool verify(); - Value *getCondition() { return getOperand(0); } - const Value *getCondition() const { return getOperand(0); } - Value *getTrueValue() { return getOperand(1); } - const Value *getTrueValue() const { return getOperand(1); } - Value *getFalseValue() { return getOperand(2); } - const Value *getFalseValue() const { return getOperand(2); } + Value *getCondition() const { return getOperand(0); } + Value *getTrueValue() const { return getOperand(1); } + Value *getFalseValue() const { return getOperand(2); } Value *fold(); private: friend class Instruction; - explicit SelectOp(const Instruction *state) : Op(state) {} + explicit SelectOp(Instruction *state) : Op(state) {} }; /// The "store" op writes an element to a memref specified by an index list. @@ -997,13 +921,11 @@ public: Value *valueToStore, Value *memref, ArrayRef indices = {}); - Value *getValueToStore() { return getOperand(0); } - const Value *getValueToStore() const { return getOperand(0); } + Value *getValueToStore() const { return getOperand(0); } Value *getMemRef() { return getOperand(1); } - const Value *getMemRef() const { return getOperand(1); } void setMemRef(Value *value) { setOperand(1, value); } - MemRefType getMemRefType() const { + MemRefType getMemRefType() { return getMemRef()->getType().cast(); } @@ -1012,11 +934,6 @@ public: getInstruction()->operand_end()}; } - llvm::iterator_range getIndices() const { - return {getInstruction()->operand_begin() + 2, - getInstruction()->operand_end()}; - } - static StringRef getOperationName() { return "std.store"; } bool verify(); @@ -1028,7 +945,7 @@ public: private: friend class Instruction; - explicit StoreOp(const Instruction *state) : Op(state) {} + explicit StoreOp(Instruction *state) : Op(state) {} }; /// The "tensor_cast" operation converts a tensor from one type to an equivalent @@ -1046,9 +963,7 @@ public: static StringRef getOperationName() { return "std.tensor_cast"; } /// The result of a tensor_cast is always a tensor. - TensorType getType() const { - return getResult()->getType().cast(); - } + TensorType getType() { return getResult()->getType().cast(); } void print(OpAsmPrinter *p); @@ -1056,13 +971,13 @@ public: private: friend class Instruction; - explicit TensorCastOp(const Instruction *state) : CastOp(state) {} + explicit TensorCastOp(Instruction *state) : CastOp(state) {} }; /// Prints dimension and symbol list. -void printDimAndSymbolList(Instruction::const_operand_iterator begin, - Instruction::const_operand_iterator end, - unsigned numDims, OpAsmPrinter *p); +void printDimAndSymbolList(Instruction::operand_iterator begin, + Instruction::operand_iterator end, unsigned numDims, + OpAsmPrinter *p); /// Parses dimension and symbol list and returns true if parsing failed. bool parseDimAndSymbolList(OpAsmParser *parser, diff --git a/mlir/include/mlir/SuperVectorOps/SuperVectorOps.h b/mlir/include/mlir/SuperVectorOps/SuperVectorOps.h index 286842338d88..bb9fb8c5b660 100644 --- a/mlir/include/mlir/SuperVectorOps/SuperVectorOps.h +++ b/mlir/include/mlir/SuperVectorOps/SuperVectorOps.h @@ -102,22 +102,18 @@ public: VectorType vectorType, Value *srcMemRef, ArrayRef srcIndices, AffineMap permutationMap, Optional paddingValue = None); - VectorType getResultType() const { + VectorType getResultType() { return getResult()->getType().cast(); } Value *getVector() { return getResult(); } - const Value *getVector() const { return getResult(); } Value *getMemRef() { return getOperand(Offsets::MemRefOffset); } - const Value *getMemRef() const { return getOperand(Offsets::MemRefOffset); } - VectorType getVectorType() const { return getResultType(); } - MemRefType getMemRefType() const { + VectorType getVectorType() { return getResultType(); } + MemRefType getMemRefType() { return getMemRef()->getType().cast(); } llvm::iterator_range getIndices(); - llvm::iterator_range getIndices() const; Optional getPaddingValue(); - Optional getPaddingValue() const; - AffineMap getPermutationMap() const; + AffineMap getPermutationMap(); static bool parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); @@ -125,7 +121,7 @@ public: private: friend class Instruction; - explicit VectorTransferReadOp(const Instruction *state) : Op(state) {} + explicit VectorTransferReadOp(Instruction *state) : Op(state) {} }; /// VectorTransferWriteOp performs a blocking write from a super-vector to @@ -172,18 +168,15 @@ public: Value *dstMemRef, ArrayRef dstIndices, AffineMap permutationMap); Value *getVector() { return getOperand(Offsets::VectorOffset); } - const Value *getVector() const { return getOperand(Offsets::VectorOffset); } - VectorType getVectorType() const { + VectorType getVectorType() { return getVector()->getType().cast(); } Value *getMemRef() { return getOperand(Offsets::MemRefOffset); } - const Value *getMemRef() const { return getOperand(Offsets::MemRefOffset); } - MemRefType getMemRefType() const { + MemRefType getMemRefType() { return getMemRef()->getType().cast(); } llvm::iterator_range getIndices(); - llvm::iterator_range getIndices() const; - AffineMap getPermutationMap() const; + AffineMap getPermutationMap(); static bool parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); @@ -191,7 +184,7 @@ public: private: friend class Instruction; - explicit VectorTransferWriteOp(const Instruction *state) : Op(state) {} + explicit VectorTransferWriteOp(Instruction *state) : Op(state) {} }; /// VectorTypeCastOp performs a conversion from a memref with scalar element to @@ -215,7 +208,7 @@ public: private: friend class Instruction; - explicit VectorTypeCastOp(const Instruction *state) : Op(state) {} + explicit VectorTypeCastOp(Instruction *state) : Op(state) {} }; } // end namespace mlir diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h index 78968ae2a7db..0fc076d1a65d 100644 --- a/mlir/include/mlir/Transforms/Utils.h +++ b/mlir/include/mlir/Transforms/Utils.h @@ -69,12 +69,12 @@ class Function; // extra operands, note that 'indexRemap' would just be applied to existing // indices (%i, %j). // TODO(bondhugula): allow extraIndices to be added at any position. -bool replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, +bool replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, ArrayRef extraIndices = {}, AffineMap indexRemap = AffineMap(), ArrayRef extraOperands = {}, - const Instruction *domInstFilter = nullptr, - const Instruction *postDomInstFilter = nullptr); + Instruction *domInstFilter = nullptr, + Instruction *postDomInstFilter = nullptr); /// Creates and inserts into 'builder' a new AffineApplyOp, with the number of /// its results equal to the number of operands, as a composition diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 6fe6f1d63a77..9cb74187cc12 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -42,7 +42,7 @@ AffineOpsDialect::AffineOpsDialect(MLIRContext *context) /// A utility function to check if a value is defined at the top level of a /// function. A value defined at the top level is always a valid symbol. -bool mlir::isTopLevelSymbol(const Value *value) { +bool mlir::isTopLevelSymbol(Value *value) { if (auto *arg = dyn_cast(value)) return arg->getOwner()->getParent()->getContainingFunction(); return value->getDefiningInst()->getParentInst() == nullptr; @@ -51,7 +51,7 @@ bool mlir::isTopLevelSymbol(const Value *value) { // Value can be used as a dimension id if it is valid as a symbol, or // it is an induction variable, or it is a result of affine apply operation // with dimension id arguments. -bool mlir::isValidDim(const Value *value) { +bool mlir::isValidDim(Value *value) { // The value must be an index type. if (!value->getType().isIndex()) return false; @@ -76,7 +76,7 @@ bool mlir::isValidDim(const Value *value) { // Value can be used as a symbol if it is a constant, or it is defined at // the top level, or it is a result of affine apply operation with symbol // arguments. -bool mlir::isValidSymbol(const Value *value) { +bool mlir::isValidSymbol(Value *value) { // The value must be an index type. if (!value->getType().isIndex()) return false; @@ -105,10 +105,9 @@ bool mlir::isValidSymbol(const Value *value) { /// was an invalid operand. An operation is provided to emit any necessary /// errors. template -static bool -verifyDimAndSymbolIdentifiers(const OpTy &op, - Instruction::const_operand_range operands, - unsigned numDims) { +static bool verifyDimAndSymbolIdentifiers(OpTy &op, + Instruction::operand_range operands, + unsigned numDims) { unsigned opIt = 0; for (auto *operand : operands) { if (opIt++ < numDims) { @@ -189,16 +188,16 @@ bool AffineApplyOp::verify() { // The result of the affine apply operation can be used as a dimension id if it // is a CFG value or if it is an Value, and all the operands are valid // dimension ids. -bool AffineApplyOp::isValidDim() const { +bool AffineApplyOp::isValidDim() { return llvm::all_of(getOperands(), - [](const Value *op) { return mlir::isValidDim(op); }); + [](Value *op) { return mlir::isValidDim(op); }); } // The result of the affine apply operation can be used as a symbol if it is // a CFG value or if it is an Value, and all the operands are symbols. -bool AffineApplyOp::isValidSymbol() const { +bool AffineApplyOp::isValidSymbol() { return llvm::all_of(getOperands(), - [](const Value *op) { return mlir::isValidSymbol(op); }); + [](Value *op) { return mlir::isValidSymbol(op); }); } Attribute AffineApplyOp::constantFold(ArrayRef operands, @@ -1069,13 +1068,13 @@ Block *AffineForOp::createBody() { return body; } -const AffineBound AffineForOp::getLowerBound() const { +AffineBound AffineForOp::getLowerBound() { auto lbMap = getLowerBoundMap(); return AffineBound(OpPointer(*this), 0, lbMap.getNumInputs(), lbMap); } -const AffineBound AffineForOp::getUpperBound() const { +AffineBound AffineForOp::getUpperBound() { auto lbMap = getLowerBoundMap(); auto ubMap = getUpperBoundMap(); return AffineBound(OpPointer(*this), lbMap.getNumInputs(), @@ -1124,19 +1123,19 @@ void AffineForOp::setUpperBoundMap(AffineMap map) { setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map)); } -bool AffineForOp::hasConstantLowerBound() const { +bool AffineForOp::hasConstantLowerBound() { return getLowerBoundMap().isSingleConstant(); } -bool AffineForOp::hasConstantUpperBound() const { +bool AffineForOp::hasConstantUpperBound() { return getUpperBoundMap().isSingleConstant(); } -int64_t AffineForOp::getConstantLowerBound() const { +int64_t AffineForOp::getConstantLowerBound() { return getLowerBoundMap().getSingleConstantResult(); } -int64_t AffineForOp::getConstantUpperBound() const { +int64_t AffineForOp::getConstantUpperBound() { return getUpperBoundMap().getSingleConstantResult(); } @@ -1154,19 +1153,11 @@ AffineForOp::operand_range AffineForOp::getLowerBoundOperands() { return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()}; } -AffineForOp::const_operand_range AffineForOp::getLowerBoundOperands() const { - return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()}; -} - AffineForOp::operand_range AffineForOp::getUpperBoundOperands() { return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()}; } -AffineForOp::const_operand_range AffineForOp::getUpperBoundOperands() const { - return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()}; -} - -bool AffineForOp::matchingBoundOperandList() const { +bool AffineForOp::matchingBoundOperandList() { auto lbMap = getLowerBoundMap(); auto ubMap = getUpperBoundMap(); if (lbMap.getNumDims() != ubMap.getNumDims() || @@ -1186,14 +1177,14 @@ bool AffineForOp::matchingBoundOperandList() const { Value *AffineForOp::getInductionVar() { return getBody()->getArgument(0); } /// Returns if the provided value is the induction variable of a AffineForOp. -bool mlir::isForInductionVar(const Value *val) { +bool mlir::isForInductionVar(Value *val) { return getForInductionVarOwner(val) != nullptr; } /// Returns the loop parent of an induction variable. If the provided value is /// not an induction variable, then return nullptr. -OpPointer mlir::getForInductionVarOwner(const Value *val) { - const BlockArgument *ivArg = dyn_cast(val); +OpPointer mlir::getForInductionVarOwner(Value *val) { + auto *ivArg = dyn_cast(val); if (!ivArg || !ivArg->getOwner()) return OpPointer(); auto *containingInst = ivArg->getOwner()->getParent()->getContainingInst(); @@ -1320,7 +1311,7 @@ void AffineIfOp::print(OpAsmPrinter *p) { /*elidedAttrs=*/getConditionAttrName()); } -IntegerSet AffineIfOp::getIntegerSet() const { +IntegerSet AffineIfOp::getIntegerSet() { return getAttrOfType(getConditionAttrName()).getValue(); } void AffineIfOp::setIntegerSet(IntegerSet newSet) { diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index c24a7688a4dc..0b7d9f831e49 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -118,7 +118,7 @@ LogicalResult mlir::getIndexSet(MutableArrayRef> forOps, // 'indexSet' correspond to the loops surounding 'inst' from outermost to // innermost. // TODO(andydavis) Add support to handle IfInsts surrounding 'inst'. -static LogicalResult getInstIndexSet(const Instruction *inst, +static LogicalResult getInstIndexSet(Instruction *inst, FlatAffineConstraints *indexSet) { // TODO(andydavis) Extend this to gather enclosing IfInsts and consider // factoring it out into a utility function. @@ -147,25 +147,25 @@ static LogicalResult getInstIndexSet(const Instruction *inst, // of maps to check. So getSrcDimOrSymPos would be "getPos(value, {0, 2})". class ValuePositionMap { public: - void addSrcValue(const Value *value) { + void addSrcValue(Value *value) { if (addValueAt(value, &srcDimPosMap, numSrcDims)) ++numSrcDims; } - void addDstValue(const Value *value) { + void addDstValue(Value *value) { if (addValueAt(value, &dstDimPosMap, numDstDims)) ++numDstDims; } - void addSymbolValue(const Value *value) { + void addSymbolValue(Value *value) { if (addValueAt(value, &symbolPosMap, numSymbols)) ++numSymbols; } - unsigned getSrcDimOrSymPos(const Value *value) const { + unsigned getSrcDimOrSymPos(Value *value) const { return getDimOrSymPos(value, srcDimPosMap, 0); } - unsigned getDstDimOrSymPos(const Value *value) const { + unsigned getDstDimOrSymPos(Value *value) const { return getDimOrSymPos(value, dstDimPosMap, numSrcDims); } - unsigned getSymPos(const Value *value) const { + unsigned getSymPos(Value *value) const { auto it = symbolPosMap.find(value); assert(it != symbolPosMap.end()); return numSrcDims + numDstDims + it->second; @@ -177,7 +177,7 @@ public: unsigned getNumSymbols() const { return numSymbols; } private: - bool addValueAt(const Value *value, DenseMap *posMap, + bool addValueAt(Value *value, DenseMap *posMap, unsigned position) { auto it = posMap->find(value); if (it == posMap->end()) { @@ -186,8 +186,8 @@ private: } return false; } - unsigned getDimOrSymPos(const Value *value, - const DenseMap &dimPosMap, + unsigned getDimOrSymPos(Value *value, + const DenseMap &dimPosMap, unsigned dimPosOffset) const { auto it = dimPosMap.find(value); if (it != dimPosMap.end()) { @@ -201,9 +201,9 @@ private: unsigned numSrcDims = 0; unsigned numDstDims = 0; unsigned numSymbols = 0; - DenseMap srcDimPosMap; - DenseMap dstDimPosMap; - DenseMap symbolPosMap; + DenseMap srcDimPosMap; + DenseMap dstDimPosMap; + DenseMap symbolPosMap; }; // Builds a map from Value to identifier position in a new merged identifier @@ -451,7 +451,7 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap, } // Add equality constraints for any operands that are defined by constant ops. - auto addEqForConstOperands = [&](ArrayRef operands) { + auto addEqForConstOperands = [&](ArrayRef operands) { for (unsigned i = 0, e = operands.size(); i < e; ++i) { if (isForInductionVar(operands[i])) continue; diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index bc4c751dd77d..3de26589b123 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -677,7 +677,7 @@ LogicalResult FlatAffineConstraints::composeMap(AffineValueMap *vMap) { } // Turn a dimension into a symbol. -static void turnDimIntoSymbol(FlatAffineConstraints *cst, const Value &id) { +static void turnDimIntoSymbol(FlatAffineConstraints *cst, Value &id) { unsigned pos; if (cst->findId(id, &pos) && pos < cst->getNumDimIds()) { swapId(cst, pos, cst->getNumDimIds() - 1); @@ -686,7 +686,7 @@ static void turnDimIntoSymbol(FlatAffineConstraints *cst, const Value &id) { } // Turn a symbol into a dimension. -static void turnSymbolIntoDim(FlatAffineConstraints *cst, const Value &id) { +static void turnSymbolIntoDim(FlatAffineConstraints *cst, Value &id) { unsigned pos; if (cst->findId(id, &pos) && pos >= cst->getNumDimIds() && pos < cst->getNumDimAndSymbolIds()) { @@ -1669,7 +1669,7 @@ FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap, if (localVarCst.getNumLocalIds() > 0) { // Set values for localVarCst. localVarCst.setIdValues(0, localVarCst.getNumDimAndSymbolIds(), operands); - for (const auto *operand : operands) { + for (auto *operand : operands) { unsigned pos; if (findId(*operand, &pos)) { if (pos >= getNumDimIds() && pos < getNumDimAndSymbolIds()) { @@ -1689,7 +1689,7 @@ FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap, // this here since the constraint system changes after a bound is added. SmallVector positions; unsigned numOperands = operands.size(); - for (const auto *operand : operands) { + for (auto *operand : operands) { unsigned pos; if (!findId(*operand, &pos)) assert(0 && "expected to be found"); @@ -1859,7 +1859,7 @@ void FlatAffineConstraints::addLocalFloorDiv(ArrayRef dividend, addInequality(bound); } -bool FlatAffineConstraints::findId(const Value &id, unsigned *pos) const { +bool FlatAffineConstraints::findId(Value &id, unsigned *pos) const { unsigned i = 0; for (const auto &mayBeId : ids) { if (mayBeId.hasValue() && mayBeId.getValue() == &id) { @@ -1871,7 +1871,7 @@ bool FlatAffineConstraints::findId(const Value &id, unsigned *pos) const { return false; } -bool FlatAffineConstraints::containsId(const Value &id) const { +bool FlatAffineConstraints::containsId(Value &id) const { return llvm::any_of(ids, [&](const Optional &mayBeId) { return mayBeId.hasValue() && mayBeId.getValue() == &id; }); @@ -1896,7 +1896,7 @@ void FlatAffineConstraints::setIdToConstant(unsigned pos, int64_t val) { /// Sets the specified identifer to a constant value; asserts if the id is not /// found. -void FlatAffineConstraints::setIdToConstant(const Value &id, int64_t val) { +void FlatAffineConstraints::setIdToConstant(Value &id, int64_t val) { unsigned pos; if (!findId(id, &pos)) // This is a pre-condition for this method. diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp index 50fb2586f7d0..84d0782f7d68 100644 --- a/mlir/lib/Analysis/Dominance.cpp +++ b/mlir/lib/Analysis/Dominance.cpp @@ -101,8 +101,7 @@ template class mlir::detail::DominanceInfoBase; //===----------------------------------------------------------------------===// /// Return true if instruction A properly dominates instruction B. -bool DominanceInfo::properlyDominates(const Instruction *a, - const Instruction *b) { +bool DominanceInfo::properlyDominates(Instruction *a, Instruction *b) { auto *aBlock = a->getBlock(), *bBlock = b->getBlock(); // If the blocks are the same, then check if b is before a in the block. @@ -122,7 +121,7 @@ bool DominanceInfo::properlyDominates(const Instruction *a, } /// Return true if value A properly dominates instruction B. -bool DominanceInfo::properlyDominates(const Value *a, const Instruction *b) { +bool DominanceInfo::properlyDominates(Value *a, Instruction *b) { if (auto *aInst = a->getDefiningInst()) return properlyDominates(aInst, b); @@ -136,8 +135,7 @@ bool DominanceInfo::properlyDominates(const Value *a, const Instruction *b) { //===----------------------------------------------------------------------===// /// Returns true if statement 'a' properly postdominates statement b. -bool PostDominanceInfo::properlyPostDominates(const Instruction *a, - const Instruction *b) { +bool PostDominanceInfo::properlyPostDominates(Instruction *a, Instruction *b) { auto *aBlock = a->getBlock(), *bBlock = b->getBlock(); // If the blocks are the same, check if b is before a in the block. @@ -145,7 +143,7 @@ bool PostDominanceInfo::properlyPostDominates(const Instruction *a, return b->isBeforeInBlock(a); // Traverse up b's hierarchy to check if b's block is contained in a's. - if (const auto *bAncestor = a->getBlock()->findAncestorInstInBlock(*b)) + if (auto *bAncestor = a->getBlock()->findAncestorInstInBlock(*b)) // Since we already know that aBlock != bBlock, here bAncestor != b. // a and bAncestor are in the same block; check if 'a' postdominates // bAncestor. diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 521dc5151e77..28b0f75909c6 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -179,7 +179,7 @@ uint64_t mlir::getLargestDivisorOfTripCount(OpPointer forOp) { return gcd.getValue(); } -bool mlir::isAccessInvariant(const Value &iv, const Value &index) { +bool mlir::isAccessInvariant(Value &iv, Value &index) { assert(isForInductionVar(&iv) && "iv must be a AffineForOp"); assert(index.getType().isa() && "index must be of IndexType"); SmallVector affineApplyOps; @@ -203,10 +203,9 @@ bool mlir::isAccessInvariant(const Value &iv, const Value &index) { return !(AffineValueMap(composeOp).isFunctionOf(0, const_cast(&iv))); } -llvm::DenseSet -mlir::getInvariantAccesses(const Value &iv, - llvm::ArrayRef indices) { - llvm::DenseSet res; +llvm::DenseSet +mlir::getInvariantAccesses(Value &iv, llvm::ArrayRef indices) { + llvm::DenseSet res; for (unsigned idx = 0, n = indices.size(); idx < n; ++idx) { auto *val = indices[idx]; if (isAccessInvariant(iv, *val)) { @@ -236,29 +235,29 @@ mlir::getInvariantAccesses(const Value &iv, /// // TODO(ntv): check strides. template -static bool isContiguousAccess(const Value &iv, const LoadOrStoreOp &memoryOp, +static bool isContiguousAccess(Value &iv, OpPointer memoryOp, unsigned fastestVaryingDim) { static_assert(std::is_same::value || std::is_same::value, "Must be called on either const LoadOp & or const StoreOp &"); - auto memRefType = memoryOp.getMemRefType(); + auto memRefType = memoryOp->getMemRefType(); if (fastestVaryingDim >= memRefType.getRank()) { - memoryOp.emitError("fastest varying dim out of bounds"); + memoryOp->emitError("fastest varying dim out of bounds"); return false; } auto layoutMap = memRefType.getAffineMaps(); // TODO(ntv): remove dependence on Builder once we support non-identity // layout map. - Builder b(memoryOp.getInstruction()->getContext()); + Builder b(memoryOp->getInstruction()->getContext()); if (layoutMap.size() >= 2 || (layoutMap.size() == 1 && !(layoutMap[0] == b.getMultiDimIdentityMap(layoutMap[0].getNumDims())))) { - return memoryOp.emitError("NYI: non-trivial layoutMap"), false; + return memoryOp->emitError("NYI: non-trivial layoutMap"), false; } - auto indices = memoryOp.getIndices(); + auto indices = memoryOp->getIndices(); auto numIndices = llvm::size(indices); unsigned d = 0; for (auto index : indices) { @@ -278,12 +277,12 @@ static bool isVectorElement(LoadOrStoreOpPointer memoryOp) { return memRefType.getElementType().template isa(); } -static bool isVectorTransferReadOrWrite(const Instruction &inst) { +static bool isVectorTransferReadOrWrite(Instruction &inst) { return inst.isa() || inst.isa(); } using VectorizableInstFun = - std::function, const Instruction &)>; + std::function, Instruction &)>; static bool isVectorizableLoopWithCond(OpPointer loop, VectorizableInstFun isVectorizableInst) { @@ -302,7 +301,7 @@ static bool isVectorizableLoopWithCond(OpPointer loop, } // No vectorization across unknown regions. - auto regions = matcher::Op([](const Instruction &inst) -> bool { + auto regions = matcher::Op([](Instruction &inst) -> bool { return inst.getNumRegions() != 0 && !(inst.isa() || inst.isa()); }); @@ -342,22 +341,22 @@ static bool isVectorizableLoopWithCond(OpPointer loop, bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim( OpPointer loop, unsigned fastestVaryingDim) { - VectorizableInstFun fun([fastestVaryingDim](OpPointer loop, - const Instruction &op) { - auto load = op.dyn_cast(); - auto store = op.dyn_cast(); - return load ? isContiguousAccess(*loop->getInductionVar(), *load, - fastestVaryingDim) - : isContiguousAccess(*loop->getInductionVar(), *store, - fastestVaryingDim); - }); + VectorizableInstFun fun( + [fastestVaryingDim](OpPointer loop, Instruction &op) { + auto load = op.dyn_cast(); + auto store = op.dyn_cast(); + return load ? isContiguousAccess(*loop->getInductionVar(), load, + fastestVaryingDim) + : isContiguousAccess(*loop->getInductionVar(), store, + fastestVaryingDim); + }); return isVectorizableLoopWithCond(loop, fun); } bool mlir::isVectorizableLoop(OpPointer loop) { VectorizableInstFun fun( // TODO: implement me - [](OpPointer loop, const Instruction &op) { return true; }); + [](OpPointer loop, Instruction &op) { return true; }); return isVectorizableLoopWithCond(loop, fun); } @@ -373,9 +372,9 @@ bool mlir::isInstwiseShiftValid(OpPointer forOp, // Work backwards over the body of the block so that the shift of a use's // ancestor instruction in the block gets recorded before it's looked up. - DenseMap forBodyShift; + DenseMap forBodyShift; for (auto it : llvm::enumerate(llvm::reverse(forBody->getInstructions()))) { - const auto &inst = it.value(); + auto &inst = it.value(); // Get the index of the current instruction, note that we are iterating in // reverse so we need to fix it up. @@ -387,7 +386,7 @@ bool mlir::isInstwiseShiftValid(OpPointer forOp, // Validate the results of this instruction if it were to be shifted. for (unsigned i = 0, e = inst.getNumResults(); i < e; ++i) { - const Value *result = inst.getResult(i); + Value *result = inst.getResult(i); for (const InstOperand &use : result->getUses()) { // If an ancestor instruction doesn't lie in the block of forOp, // there is no shift to check. diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp index 3e55291972b7..83b3591ce5cf 100644 --- a/mlir/lib/Analysis/NestedMatcher.cpp +++ b/mlir/lib/Analysis/NestedMatcher.cpp @@ -110,13 +110,9 @@ void NestedPattern::matchOne(Instruction *inst, } } -static bool isAffineForOp(const Instruction &inst) { - return inst.isa(); -} +static bool isAffineForOp(Instruction &inst) { return inst.isa(); } -static bool isAffineIfOp(const Instruction &inst) { - return inst.isa(); -} +static bool isAffineIfOp(Instruction &inst) { return inst.isa(); } namespace mlir { namespace matcher { @@ -129,7 +125,7 @@ NestedPattern If(NestedPattern child) { return NestedPattern(child, isAffineIfOp); } NestedPattern If(FilterFunctionType filter, NestedPattern child) { - return NestedPattern(child, [filter](const Instruction &inst) { + return NestedPattern(child, [filter](Instruction &inst) { return isAffineIfOp(inst) && filter(inst); }); } @@ -137,7 +133,7 @@ NestedPattern If(ArrayRef nested) { return NestedPattern(nested, isAffineIfOp); } NestedPattern If(FilterFunctionType filter, ArrayRef nested) { - return NestedPattern(nested, [filter](const Instruction &inst) { + return NestedPattern(nested, [filter](Instruction &inst) { return isAffineIfOp(inst) && filter(inst); }); } @@ -146,7 +142,7 @@ NestedPattern For(NestedPattern child) { return NestedPattern(child, isAffineForOp); } NestedPattern For(FilterFunctionType filter, NestedPattern child) { - return NestedPattern(child, [=](const Instruction &inst) { + return NestedPattern(child, [=](Instruction &inst) { return isAffineForOp(inst) && filter(inst); }); } @@ -154,24 +150,24 @@ NestedPattern For(ArrayRef nested) { return NestedPattern(nested, isAffineForOp); } NestedPattern For(FilterFunctionType filter, ArrayRef nested) { - return NestedPattern(nested, [=](const Instruction &inst) { + return NestedPattern(nested, [=](Instruction &inst) { return isAffineForOp(inst) && filter(inst); }); } // TODO(ntv): parallel annotation on loops. -bool isParallelLoop(const Instruction &inst) { +bool isParallelLoop(Instruction &inst) { auto loop = inst.cast(); return loop || true; // loop->isParallel(); }; // TODO(ntv): reduction annotation on loops. -bool isReductionLoop(const Instruction &inst) { +bool isReductionLoop(Instruction &inst) { auto loop = inst.cast(); return loop || true; // loop->isReduction(); }; -bool isLoadOrStore(const Instruction &inst) { +bool isLoadOrStore(Instruction &inst) { return inst.isa() || inst.isa(); }; diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 8918dd03f809..2cd0a83296b8 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -39,7 +39,7 @@ using llvm::SmallDenseMap; /// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from /// the outermost 'for' instruction to the innermost one. -void mlir::getLoopIVs(const Instruction &inst, +void mlir::getLoopIVs(Instruction &inst, SmallVectorImpl> *loops) { auto *currInst = inst.getParentInst(); OpPointer currAffineForOp; @@ -431,7 +431,7 @@ template LogicalResult mlir::boundCheckLoadOrStoreOp(OpPointer storeOp, // Returns in 'positions' the Block positions of 'inst' in each ancestor // Block from the Block containing instruction, stopping at 'limitBlock'. -static void findInstPosition(const Instruction *inst, Block *limitBlock, +static void findInstPosition(Instruction *inst, Block *limitBlock, SmallVectorImpl *positions) { Block *block = inst->getBlock(); while (block != limitBlock) { @@ -653,8 +653,8 @@ bool MemRefAccess::isStore() const { return opInst->isa(); } /// Returns the nesting depth of this statement, i.e., the number of loops /// surrounding this statement. -unsigned mlir::getNestingDepth(const Instruction &inst) { - const Instruction *currInst = &inst; +unsigned mlir::getNestingDepth(Instruction &inst) { + Instruction *currInst = &inst; unsigned depth = 0; while ((currInst = currInst->getParentInst())) { if (currInst->isa()) @@ -665,8 +665,7 @@ unsigned mlir::getNestingDepth(const Instruction &inst) { /// Returns the number of surrounding loops common to 'loopsA' and 'loopsB', /// where each lists loops from outer-most to inner-most in loop nest. -unsigned mlir::getNumCommonSurroundingLoops(const Instruction &A, - const Instruction &B) { +unsigned mlir::getNumCommonSurroundingLoops(Instruction &A, Instruction &B) { SmallVector, 4> loopsA, loopsB; getLoopIVs(A, &loopsA); getLoopIVs(B, &loopsB); diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index 5ca3a829cbdf..5df31affe31f 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -180,7 +180,7 @@ AffineMap mlir::makePermutationMap( enclosingLoopToVectorDim); } -bool mlir::matcher::operatesOnSuperVectors(const Instruction &opInst, +bool mlir::matcher::operatesOnSuperVectors(Instruction &opInst, VectorType subVectorType) { // First, extract the vector type and ditinguish between: // a. ops that *must* lower a super-vector (i.e. vector_transfer_read, diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index d92aaedad179..b72731ed5cb1 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -52,7 +52,7 @@ namespace { /// class FuncVerifier { public: - bool failure(const Twine &message, const Instruction &value) { + bool failure(const Twine &message, Instruction &value) { return value.emitError(message); } @@ -108,9 +108,9 @@ public: bool verify(); bool verifyBlock(Block &block, bool isTopLevel); - bool verifyOperation(const Instruction &op); + bool verifyOperation(Instruction &op); bool verifyDominance(Block &block); - bool verifyInstDominance(const Instruction &inst); + bool verifyInstDominance(Instruction &inst); explicit FuncVerifier(Function &fn) : fn(fn), identifierRegex("^[a-zA-Z_][a-zA-Z_0-9\\.\\$]*$") {} @@ -270,12 +270,12 @@ bool FuncVerifier::verifyBlock(Block &block, bool isTopLevel) { } /// Check the invariants of the specified operation. -bool FuncVerifier::verifyOperation(const Instruction &op) { +bool FuncVerifier::verifyOperation(Instruction &op) { if (op.getFunction() != &fn) return failure("operation in the wrong function", op); // Check that operands are non-nil and structurally ok. - for (const auto *operand : op.getOperands()) { + for (auto *operand : op.getOperands()) { if (!operand) return failure("null operand found", op); @@ -322,7 +322,7 @@ bool FuncVerifier::verifyDominance(Block &block) { return false; } -bool FuncVerifier::verifyInstDominance(const Instruction &inst) { +bool FuncVerifier::verifyInstDominance(Instruction &inst) { // Check that operands properly dominate this use. for (unsigned operandNo = 0, e = inst.getNumOperands(); operandNo != e; ++operandNo) { diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp index af172fcb542b..685a7a07a69c 100644 --- a/mlir/lib/Dialect/Traits.cpp +++ b/mlir/lib/Dialect/Traits.cpp @@ -184,7 +184,7 @@ static bool isSameShapedVectorOrTensor(Type type1, Type type2) { return false; } -bool OpTrait::impl::verifyCompatibleOperandBroadcast(const Instruction *op) { +bool OpTrait::impl::verifyCompatibleOperandBroadcast(Instruction *op) { assert(op->getNumOperands() == 2 && "only support broadcast check on two operands"); assert(op->getNumResults() == 1 && diff --git a/mlir/lib/EDSC/MLIREmitter.cpp b/mlir/lib/EDSC/MLIREmitter.cpp index 202d254aee01..6430796bcc16 100644 --- a/mlir/lib/EDSC/MLIREmitter.cpp +++ b/mlir/lib/EDSC/MLIREmitter.cpp @@ -45,8 +45,8 @@ using namespace mlir; using namespace mlir::edsc; using namespace mlir::edsc::detail; -static void printDefininingStatement(llvm::raw_ostream &os, const Value &v) { - const auto *inst = v.getDefiningInst(); +static void printDefininingStatement(llvm::raw_ostream &os, Value &v) { + auto *inst = v.getDefiningInst(); if (inst) { inst->print(os); return; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index f4b49497cb26..b62a279fa29f 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -138,7 +138,7 @@ private: void recordTypeReference(Type ty) { usedTypes.insert(ty); } // Visit functions. - void visitInstruction(const Instruction *inst); + void visitInstruction(Instruction *inst); void visitType(Type type); void visitAttribute(Attribute attr); @@ -189,7 +189,7 @@ void ModuleState::visitAttribute(Attribute attr) { } } -void ModuleState::visitInstruction(const Instruction *inst) { +void ModuleState::visitInstruction(Instruction *inst) { // Visit all the types used in the operation. for (auto *operand : inst->getOperands()) visitType(operand->getType()); @@ -1060,11 +1060,11 @@ public: void printFunctionSignature(); // Methods to print instructions. - void print(const Instruction *inst); + void print(Instruction *inst); void print(Block *block, bool printBlockArgs = true); - void printOperation(const Instruction *op); - void printGenericOp(const Instruction *op); + void printOperation(Instruction *op); + void printGenericOp(Instruction *op); // Implement OpAsmPrinter. raw_ostream &getStream() const { return os; } @@ -1085,7 +1085,7 @@ public: void printFunctionReference(Function *func) { return ModulePrinter::printFunctionReference(func); } - void printOperand(const Value *value) { printValueID(value); } + void printOperand(Value *value) { printValueID(value); } void printOptionalAttrDict(ArrayRef attrs, ArrayRef elidedAttrs = {}) { @@ -1107,8 +1107,7 @@ public: return it != blockIDs.end() ? it->second : ~0U; } - void printSuccessorAndUseList(const Instruction *term, - unsigned index) override; + void printSuccessorAndUseList(Instruction *term, unsigned index) override; /// Print a region. void printRegion(Region &blocks, bool printEntryBlockArgs) override { @@ -1127,17 +1126,17 @@ public: const static unsigned indentWidth = 2; protected: - void numberValueID(const Value *value); + void numberValueID(Value *value); void numberValuesInBlock(Block &block); - void printValueID(const Value *value, bool printResultNo = true) const; + void printValueID(Value *value, bool printResultNo = true) const; private: Function *function; /// This is the value ID for each SSA value in the current function. If this /// returns ~0, then the valueID has an entry in valueNames. - DenseMap valueIDs; - DenseMap valueNames; + DenseMap valueIDs; + DenseMap valueNames; /// This is the block ID for each block in the current function. DenseMap blockIDs; @@ -1191,7 +1190,7 @@ void FunctionPrinter::numberValuesInBlock(Block &block) { } } -void FunctionPrinter::numberValueID(const Value *value) { +void FunctionPrinter::numberValueID(Value *value) { assert(!valueIDs.count(value) && "Value numbered multiple times"); SmallString<32> specialNameBuffer; @@ -1389,14 +1388,13 @@ void FunctionPrinter::print(Block *block, bool printBlockArgs) { currentIndent -= indentWidth; } -void FunctionPrinter::print(const Instruction *inst) { +void FunctionPrinter::print(Instruction *inst) { os.indent(currentIndent); printOperation(inst); printTrailingLocation(inst->getLoc()); } -void FunctionPrinter::printValueID(const Value *value, - bool printResultNo) const { +void FunctionPrinter::printValueID(Value *value, bool printResultNo) const { int resultNo = -1; auto lookupValue = value; @@ -1434,7 +1432,7 @@ void FunctionPrinter::printValueID(const Value *value, os << '#' << resultNo; } -void FunctionPrinter::printOperation(const Instruction *op) { +void FunctionPrinter::printOperation(Instruction *op) { if (op->getNumResults()) { printValueID(op->getResult(0), /*printResultNo=*/false); os << " = "; @@ -1454,7 +1452,7 @@ void FunctionPrinter::printOperation(const Instruction *op) { printGenericOp(op); } -void FunctionPrinter::printGenericOp(const Instruction *op) { +void FunctionPrinter::printGenericOp(Instruction *op) { os << '"'; printEscapedString(op->getName().getStringRef(), os); os << "\"("; @@ -1465,11 +1463,10 @@ void FunctionPrinter::printGenericOp(const Instruction *op) { for (unsigned i = 0; i < numSuccessors; ++i) totalNumSuccessorOperands += op->getNumSuccessorOperands(i); unsigned numProperOperands = op->getNumOperands() - totalNumSuccessorOperands; - SmallVector properOperands( + SmallVector properOperands( op->operand_begin(), std::next(op->operand_begin(), numProperOperands)); - interleaveComma(properOperands, - [&](const Value *value) { printValueID(value); }); + interleaveComma(properOperands, [&](Value *value) { printValueID(value); }); os << ')'; @@ -1490,7 +1487,7 @@ void FunctionPrinter::printGenericOp(const Instruction *op) { // Print the type signature of the operation. os << " : ("; interleaveComma(properOperands, - [&](const Value *value) { printType(value->getType()); }); + [&](Value *value) { printType(value->getType()); }); os << ") -> "; if (op->getNumResults() == 1 && @@ -1499,7 +1496,7 @@ void FunctionPrinter::printGenericOp(const Instruction *op) { } else { os << '('; interleaveComma(op->getResults(), - [&](const Value *result) { printType(result->getType()); }); + [&](Value *result) { printType(result->getType()); }); os << ')'; } @@ -1508,7 +1505,7 @@ void FunctionPrinter::printGenericOp(const Instruction *op) { printRegion(region, /*printEntryBlockArgs=*/true); } -void FunctionPrinter::printSuccessorAndUseList(const Instruction *term, +void FunctionPrinter::printSuccessorAndUseList(Instruction *term, unsigned index) { printBlockName(term->getSuccessor(index)); @@ -1518,11 +1515,10 @@ void FunctionPrinter::printSuccessorAndUseList(const Instruction *term, os << '('; interleaveComma(succOperands, - [this](const Value *operand) { printValueID(operand); }); + [this](Value *operand) { printValueID(operand); }); os << " : "; - interleaveComma(succOperands, [this](const Value *operand) { - printType(operand->getType()); - }); + interleaveComma(succOperands, + [this](Value *operand) { printType(operand->getType()); }); os << ')'; } @@ -1585,7 +1581,7 @@ void IntegerSet::print(raw_ostream &os) const { ModulePrinter(os, state).printIntegerSet(*this); } -void Value::print(raw_ostream &os) const { +void Value::print(raw_ostream &os) { switch (getKind()) { case Value::Kind::BlockArgument: // TODO: Improve this. @@ -1596,9 +1592,9 @@ void Value::print(raw_ostream &os) const { } } -void Value::dump() const { print(llvm::errs()); } +void Value::dump() { print(llvm::errs()); } -void Instruction::print(raw_ostream &os) const { +void Instruction::print(raw_ostream &os) { auto *function = getFunction(); if (!function) { os << "<>\n"; @@ -1610,7 +1606,7 @@ void Instruction::print(raw_ostream &os) const { FunctionPrinter(function, modulePrinter).print(this); } -void Instruction::dump() const { +void Instruction::dump() { print(llvm::errs()); llvm::errs() << "\n"; } diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp index 0470eb5e13bd..4782f92c5089 100644 --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -26,7 +26,7 @@ using namespace mlir; //===----------------------------------------------------------------------===// /// Returns the number of this argument. -unsigned BlockArgument::getArgNumber() const { +unsigned BlockArgument::getArgNumber() { // Arguments are not stored in place, so we have to find it within the list. auto argList = getOwner()->getArguments(); return std::distance(argList.begin(), llvm::find(argList, this)); @@ -78,7 +78,7 @@ void Block::eraseFromFunction() { /// Returns 'inst' if 'inst' lies in this block, or otherwise finds the /// ancestor instruction of 'inst' that lies in this block. Returns nullptr if /// the latter fails. -Instruction *Block::findAncestorInstInBlock(const Instruction &inst) { +Instruction *Block::findAncestorInstInBlock(Instruction &inst) { // Traverse up the instruction hierarchy starting from the owner of operand to // find the ancestor instruction that resides in the block of 'forInst'. auto *currInst = const_cast(&inst); @@ -109,7 +109,7 @@ bool Block::verifyInstOrder() { std::next(instructions.begin()) == instructions.end()) return false; - const Instruction *prev = nullptr; + Instruction *prev = nullptr; for (auto &i : *this) { // The previous instruction must have a smaller order index than the next as // it appears earlier in the list. @@ -306,12 +306,12 @@ void Region::cloneInto(Region *dest, BlockAndValueMapping &mapper, // Clone the block arguments. The user might be deleting arguments to the // block by specifying them in the mapper. If so, we don't add the // argument to the cloned block. - for (const auto *arg : block.getArguments()) + for (auto *arg : block.getArguments()) if (!mapper.contains(arg)) mapper.map(arg, newBlock->addArgument(arg->getType())); // Clone and remap the instructions within this block. - for (const auto &inst : block) + for (auto &inst : block) newBlock->push_back(inst.clone(mapper, context)); dest->push_back(newBlock); diff --git a/mlir/lib/IR/Instruction.cpp b/mlir/lib/IR/Instruction.cpp index 4ebf2a798a24..698b4cd19265 100644 --- a/mlir/lib/IR/Instruction.cpp +++ b/mlir/lib/IR/Instruction.cpp @@ -33,7 +33,7 @@ using namespace mlir; //===----------------------------------------------------------------------===// /// Return the result number of this result. -unsigned InstResult::getResultNumber() const { +unsigned InstResult::getResultNumber() { // Results are always stored consecutively, so use pointer subtraction to // figure out what number this is. return this - &getOwner()->getInstResults()[0]; @@ -44,7 +44,7 @@ unsigned InstResult::getResultNumber() const { //===----------------------------------------------------------------------===// /// Return which operand this is in the operand list. -template <> unsigned InstOperand::getOperandNumber() const { +template <> unsigned InstOperand::getOperandNumber() { return this - &getOwner()->getInstOperands()[0]; } @@ -53,7 +53,7 @@ template <> unsigned InstOperand::getOperandNumber() const { //===----------------------------------------------------------------------===// /// Return which operand this is in the operand list. -template <> unsigned BlockOperand::getOperandNumber() const { +template <> unsigned BlockOperand::getOperandNumber() { return this - &getOwner()->getBlockOperands()[0]; } @@ -287,7 +287,7 @@ void Instruction::destroy() { } /// Return the context this operation is associated with. -MLIRContext *Instruction::getContext() const { +MLIRContext *Instruction::getContext() { // If we have a result or operand type, that is a constant time way to get // to the context. if (getNumResults()) @@ -300,11 +300,11 @@ MLIRContext *Instruction::getContext() const { return getFunction()->getContext(); } -Instruction *Instruction::getParentInst() const { +Instruction *Instruction::getParentInst() { return block ? block->getContainingInst() : nullptr; } -Function *Instruction::getFunction() const { +Function *Instruction::getFunction() { return block ? block->getFunction() : nullptr; } @@ -339,14 +339,14 @@ void Instruction::walkPostOrder( /// Emit a note about this instruction, reporting up to any diagnostic /// handlers that may be listening. -void Instruction::emitNote(const Twine &message) const { +void Instruction::emitNote(const Twine &message) { getContext()->emitDiagnostic(getLoc(), message, MLIRContext::DiagnosticKind::Note); } /// Emit a warning about this instruction, reporting up to any diagnostic /// handlers that may be listening. -void Instruction::emitWarning(const Twine &message) const { +void Instruction::emitWarning(const Twine &message) { getContext()->emitDiagnostic(getLoc(), message, MLIRContext::DiagnosticKind::Warning); } @@ -355,7 +355,7 @@ void Instruction::emitWarning(const Twine &message) const { /// any diagnostic handlers that may be listening. This function always /// returns true. NOTE: This may terminate the containing application, only /// use when the IR is in an inconsistent state. -bool Instruction::emitError(const Twine &message) const { +bool Instruction::emitError(const Twine &message) { return getContext()->emitError(getLoc(), message); } @@ -364,7 +364,7 @@ bool Instruction::emitError(const Twine &message) const { /// of the parent block. /// Note: This function has an average complexity of O(1), but worst case may /// take O(N) where N is the number of instructions within the parent block. -bool Instruction::isBeforeInBlock(const Instruction *other) const { +bool Instruction::isBeforeInBlock(Instruction *other) { assert(block && "Instructions without parent blocks have no order."); assert(other && other->block == block && "Expected other instruction to have the same parent block."); @@ -490,7 +490,7 @@ void Instruction::dropAllReferences() { } /// Return true if there are no users of any results of this operation. -bool Instruction::use_empty() const { +bool Instruction::use_empty() { for (auto *result : getResults()) if (!result->use_empty()) return false; @@ -502,10 +502,6 @@ void Instruction::setSuccessor(Block *block, unsigned index) { getBlockOperands()[index].set(block); } -auto Instruction::getNonSuccessorOperands() const -> const_operand_range { - return {const_operand_iterator(this, 0), - const_operand_iterator(this, getSuccessorOperandIndex(0))}; -} auto Instruction::getNonSuccessorOperands() -> operand_range { return {operand_iterator(this, 0), operand_iterator(this, getSuccessorOperandIndex(0))}; @@ -513,7 +509,7 @@ auto Instruction::getNonSuccessorOperands() -> operand_range { /// Get the index of the first operand of the successor at the provided /// index. -unsigned Instruction::getSuccessorOperandIndex(unsigned index) const { +unsigned Instruction::getSuccessorOperandIndex(unsigned index) { assert(!isKnownNonTerminator() && "only terminators may have successors"); assert(index < getNumSuccessors()); @@ -527,13 +523,6 @@ unsigned Instruction::getSuccessorOperandIndex(unsigned index) const { return getNumOperands() - postSuccessorOpCount; } -auto Instruction::getSuccessorOperands(unsigned index) const - -> const_operand_range { - unsigned succOperandIndex = getSuccessorOperandIndex(index); - return {const_operand_iterator(this, succOperandIndex), - const_operand_iterator(this, succOperandIndex + - getNumSuccessorOperands(index))}; -} auto Instruction::getSuccessorOperands(unsigned index) -> operand_range { unsigned succOperandIndex = getSuccessorOperandIndex(index); return {operand_iterator(this, succOperandIndex), @@ -544,19 +533,16 @@ auto Instruction::getSuccessorOperands(unsigned index) -> operand_range { /// Attempt to constant fold this operation with the specified constant /// operand values. If successful, this fills in the results vector. If not, /// results is unspecified. -LogicalResult -Instruction::constantFold(ArrayRef operands, - SmallVectorImpl &results) const { - auto *inst = const_cast(this); - +LogicalResult Instruction::constantFold(ArrayRef operands, + SmallVectorImpl &results) { if (auto *abstractOp = getAbstractOperation()) { // If we have a registered operation definition matching this one, use it to // try to constant fold the operation. - if (succeeded(abstractOp->constantFoldHook(inst, operands, results))) + if (succeeded(abstractOp->constantFoldHook(this, operands, results))) return success(); // Otherwise, fall back on the dialect hook to handle it. - return abstractOp->dialect.constantFoldHook(inst, operands, results); + return abstractOp->dialect.constantFoldHook(this, operands, results); } // If this operation hasn't been registered or doesn't have abstract @@ -564,7 +550,7 @@ Instruction::constantFold(ArrayRef operands, auto opName = getName().getStringRef(); auto dialectPrefix = opName.split('.').first; if (auto *dialect = getContext()->getRegisteredDialect(dialectPrefix)) - return dialect->constantFoldHook(inst, operands, results); + return dialect->constantFoldHook(this, operands, results); return failure(); } @@ -582,7 +568,7 @@ LogicalResult Instruction::fold(SmallVectorImpl &results) { /// Emit an error with the op name prefixed, like "'dim' op " which is /// convenient for verifiers. -bool Instruction::emitOpError(const Twine &message) const { +bool Instruction::emitOpError(const Twine &message) { return emitError(Twine('\'') + getName().getStringRef() + "' op " + message); } @@ -596,7 +582,7 @@ bool Instruction::emitOpError(const Twine &message) const { /// sub-instructions to the corresponding instruction that is copied, and adds /// those mappings to the map. Instruction *Instruction::clone(BlockAndValueMapping &mapper, - MLIRContext *context) const { + MLIRContext *context) { SmallVector operands; SmallVector successors; @@ -605,7 +591,7 @@ Instruction *Instruction::clone(BlockAndValueMapping &mapper, if (getNumSuccessors() == 0) { // Non-branching operations can just add all the operands. for (auto *opValue : getOperands()) - operands.push_back(mapper.lookupOrDefault(const_cast(opValue))); + operands.push_back(mapper.lookupOrDefault(opValue)); } else { // We add the operands separated by nullptr's for each successor. unsigned firstSuccOperand = @@ -614,21 +600,18 @@ Instruction *Instruction::clone(BlockAndValueMapping &mapper, unsigned i = 0; for (; i != firstSuccOperand; ++i) - operands.push_back( - mapper.lookupOrDefault(const_cast(InstOperands[i].get()))); + operands.push_back(mapper.lookupOrDefault(InstOperands[i].get())); successors.reserve(getNumSuccessors()); for (unsigned succ = 0, e = getNumSuccessors(); succ != e; ++succ) { - successors.push_back( - mapper.lookupOrDefault(const_cast(getSuccessor(succ)))); + successors.push_back(mapper.lookupOrDefault(getSuccessor(succ))); // Add sentinel to delineate successor operands. operands.push_back(nullptr); // Remap the successors operands. for (auto *operand : getSuccessorOperands(succ)) - operands.push_back( - mapper.lookupOrDefault(const_cast(operand))); + operands.push_back(mapper.lookupOrDefault(operand)); } } @@ -652,7 +635,7 @@ Instruction *Instruction::clone(BlockAndValueMapping &mapper, return newOp; } -Instruction *Instruction::clone(MLIRContext *context) const { +Instruction *Instruction::clone(MLIRContext *context) { BlockAndValueMapping mapper; return clone(mapper, context); } diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index a2605a9f9103..78cc18480cfe 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -93,20 +93,19 @@ void OpState::emitNote(const Twine &message) const { // Op Trait implementations //===----------------------------------------------------------------------===// -bool OpTrait::impl::verifyZeroOperands(const Instruction *op) { +bool OpTrait::impl::verifyZeroOperands(Instruction *op) { if (op->getNumOperands() != 0) return op->emitOpError("requires zero operands"); return false; } -bool OpTrait::impl::verifyOneOperand(const Instruction *op) { +bool OpTrait::impl::verifyOneOperand(Instruction *op) { if (op->getNumOperands() != 1) return op->emitOpError("requires a single operand"); return false; } -bool OpTrait::impl::verifyNOperands(const Instruction *op, - unsigned numOperands) { +bool OpTrait::impl::verifyNOperands(Instruction *op, unsigned numOperands) { if (op->getNumOperands() != numOperands) { return op->emitOpError("expected " + Twine(numOperands) + " operands, but found " + @@ -115,7 +114,7 @@ bool OpTrait::impl::verifyNOperands(const Instruction *op, return false; } -bool OpTrait::impl::verifyAtLeastNOperands(const Instruction *op, +bool OpTrait::impl::verifyAtLeastNOperands(Instruction *op, unsigned numOperands) { if (op->getNumOperands() < numOperands) return op->emitOpError("expected " + Twine(numOperands) + @@ -135,7 +134,7 @@ static Type getTensorOrVectorElementType(Type type) { return type; } -bool OpTrait::impl::verifyOperandsAreIntegerLike(const Instruction *op) { +bool OpTrait::impl::verifyOperandsAreIntegerLike(Instruction *op) { for (auto *operand : op->getOperands()) { auto type = getTensorOrVectorElementType(operand->getType()); if (!type.isIntOrIndex()) @@ -144,7 +143,7 @@ bool OpTrait::impl::verifyOperandsAreIntegerLike(const Instruction *op) { return false; } -bool OpTrait::impl::verifySameTypeOperands(const Instruction *op) { +bool OpTrait::impl::verifySameTypeOperands(Instruction *op) { // Zero or one operand always have the "same" type. unsigned nOperands = op->getNumOperands(); if (nOperands < 2) @@ -158,26 +157,25 @@ bool OpTrait::impl::verifySameTypeOperands(const Instruction *op) { return false; } -bool OpTrait::impl::verifyZeroResult(const Instruction *op) { +bool OpTrait::impl::verifyZeroResult(Instruction *op) { if (op->getNumResults() != 0) return op->emitOpError("requires zero results"); return false; } -bool OpTrait::impl::verifyOneResult(const Instruction *op) { +bool OpTrait::impl::verifyOneResult(Instruction *op) { if (op->getNumResults() != 1) return op->emitOpError("requires one result"); return false; } -bool OpTrait::impl::verifyNResults(const Instruction *op, - unsigned numOperands) { +bool OpTrait::impl::verifyNResults(Instruction *op, unsigned numOperands) { if (op->getNumResults() != numOperands) return op->emitOpError("expected " + Twine(numOperands) + " results"); return false; } -bool OpTrait::impl::verifyAtLeastNResults(const Instruction *op, +bool OpTrait::impl::verifyAtLeastNResults(Instruction *op, unsigned numOperands) { if (op->getNumResults() < numOperands) return op->emitOpError("expected " + Twine(numOperands) + @@ -206,7 +204,7 @@ static bool verifyShapeMatch(Type type1, Type type2) { return false; } -bool OpTrait::impl::verifySameOperandsAndResultShape(const Instruction *op) { +bool OpTrait::impl::verifySameOperandsAndResultShape(Instruction *op) { if (op->getNumOperands() == 0 || op->getNumResults() == 0) return true; @@ -224,7 +222,7 @@ bool OpTrait::impl::verifySameOperandsAndResultShape(const Instruction *op) { return false; } -bool OpTrait::impl::verifySameOperandsAndResultType(const Instruction *op) { +bool OpTrait::impl::verifySameOperandsAndResultType(Instruction *op) { if (op->getNumOperands() == 0 || op->getNumResults() == 0) return true; @@ -242,9 +240,9 @@ bool OpTrait::impl::verifySameOperandsAndResultType(const Instruction *op) { return false; } -static bool verifyBBArguments( - llvm::iterator_range operands, - Block *destBB, const Instruction *op) { +static bool +verifyBBArguments(llvm::iterator_range operands, + Block *destBB, Instruction *op) { unsigned operandCount = std::distance(operands.begin(), operands.end()); if (operandCount != destBB->getNumArguments()) return op->emitError("branch has " + Twine(operandCount) + @@ -260,7 +258,7 @@ static bool verifyBBArguments( return false; } -static bool verifyTerminatorSuccessors(const Instruction *op) { +static bool verifyTerminatorSuccessors(Instruction *op) { // Verify that the operands lines up with the BB arguments in the successor. Function *fn = op->getFunction(); for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) { @@ -273,7 +271,7 @@ static bool verifyTerminatorSuccessors(const Instruction *op) { return false; } -bool OpTrait::impl::verifyIsTerminator(const Instruction *op) { +bool OpTrait::impl::verifyIsTerminator(Instruction *op) { Block *block = op->getBlock(); // Verify that the operation is at the end of the respective parent block. if (!block || &block->back() != op) @@ -285,7 +283,7 @@ bool OpTrait::impl::verifyIsTerminator(const Instruction *op) { return false; } -bool OpTrait::impl::verifyResultsAreBoolLike(const Instruction *op) { +bool OpTrait::impl::verifyResultsAreBoolLike(Instruction *op) { for (auto *result : op->getResults()) { auto elementType = getTensorOrVectorElementType(result->getType()); bool isBoolType = elementType.isInteger(1); @@ -296,7 +294,7 @@ bool OpTrait::impl::verifyResultsAreBoolLike(const Instruction *op) { return false; } -bool OpTrait::impl::verifyResultsAreFloatLike(const Instruction *op) { +bool OpTrait::impl::verifyResultsAreFloatLike(Instruction *op) { for (auto *result : op->getResults()) { if (!getTensorOrVectorElementType(result->getType()).isa()) return op->emitOpError("requires a floating point type"); @@ -305,7 +303,7 @@ bool OpTrait::impl::verifyResultsAreFloatLike(const Instruction *op) { return false; } -bool OpTrait::impl::verifyResultsAreIntegerLike(const Instruction *op) { +bool OpTrait::impl::verifyResultsAreIntegerLike(Instruction *op) { for (auto *result : op->getResults()) { auto type = getTensorOrVectorElementType(result->getType()); if (!type.isIntOrIndex()) @@ -338,7 +336,7 @@ bool impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) { parser->addTypeToList(type, result->types); } -void impl::printBinaryOp(const Instruction *op, OpAsmPrinter *p) { +void impl::printBinaryOp(Instruction *op, OpAsmPrinter *p) { assert(op->getNumOperands() == 2 && "binary op should have two operands"); assert(op->getNumResults() == 1 && "binary op should have one result"); @@ -377,7 +375,7 @@ bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) { parser->addTypeToList(dstType, result->types); } -void impl::printCastOp(const Instruction *op, OpAsmPrinter *p) { +void impl::printCastOp(Instruction *op, OpAsmPrinter *p) { *p << op->getName() << ' ' << *op->getOperand(0) << " : " << op->getOperand(0)->getType() << " to " << op->getResult(0)->getType(); } diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp index 2b6eea80a4aa..6ac1711229c0 100644 --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -29,7 +29,7 @@ Instruction *Value::getDefiningInst() { } /// Return the function that this Value is defined in. -Function *Value::getFunction() const { +Function *Value::getFunction() { switch (getKind()) { case Value::Kind::BlockArgument: return cast(this)->getFunction(); @@ -64,14 +64,14 @@ void IRObjectWithUseList::dropAllUses() { //===----------------------------------------------------------------------===// /// Return the function that this argument is defined in. -Function *BlockArgument::getFunction() const { +Function *BlockArgument::getFunction() { if (auto *owner = getOwner()) return owner->getFunction(); return nullptr; } /// Returns if the current argument is a function argument. -bool BlockArgument::isFunctionArgument() const { +bool BlockArgument::isFunctionArgument() { auto *containingFn = getFunction(); return containingFn && &containingFn->front() == getOwner(); } diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 55da9c6ed6b2..963362871a2d 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -36,7 +36,7 @@ using namespace mlir; /// A custom binary operation printer that omits the "std." prefix from the /// operation names. -void detail::printStandardBinaryOp(const Instruction *op, OpAsmPrinter *p) { +void detail::printStandardBinaryOp(Instruction *op, OpAsmPrinter *p) { assert(op->getNumOperands() == 2 && "binary op should have two operands"); assert(op->getNumResults() == 1 && "binary op should have one result"); @@ -68,8 +68,8 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context) >(); } -void mlir::printDimAndSymbolList(Instruction::const_operand_iterator begin, - Instruction::const_operand_iterator end, +void mlir::printDimAndSymbolList(Instruction::operand_iterator begin, + Instruction::operand_iterator end, unsigned numDims, OpAsmPrinter *p) { *p << '('; p->printOperands(begin, begin + numDims); @@ -1803,8 +1803,7 @@ void ReturnOp::print(OpAsmPrinter *p) { *p << " : "; interleave( operand_begin(), operand_end(), - [&](const Value *e) { p->printType(e->getType()); }, - [&]() { *p << ", "; }); + [&](Value *e) { p->printType(e->getType()); }, [&]() { *p << ", "; }); } } diff --git a/mlir/lib/SuperVectorOps/SuperVectorOps.cpp b/mlir/lib/SuperVectorOps/SuperVectorOps.cpp index 0320e7823244..1e0c01a5df13 100644 --- a/mlir/lib/SuperVectorOps/SuperVectorOps.cpp +++ b/mlir/lib/SuperVectorOps/SuperVectorOps.cpp @@ -92,13 +92,6 @@ VectorTransferReadOp::getIndices() { return {begin, end}; } -llvm::iterator_range -VectorTransferReadOp::getIndices() const { - auto begin = getInstruction()->operand_begin() + Offsets::FirstIndexOffset; - auto end = begin + getMemRefType().getRank(); - return {begin, end}; -} - Optional VectorTransferReadOp::getPaddingValue() { auto memRefRank = getMemRefType().getRank(); if (getNumOperands() <= Offsets::FirstIndexOffset + memRefRank) { @@ -107,16 +100,7 @@ Optional VectorTransferReadOp::getPaddingValue() { return Optional(getOperand(Offsets::FirstIndexOffset + memRefRank)); } -Optional VectorTransferReadOp::getPaddingValue() const { - auto memRefRank = getMemRefType().getRank(); - if (getNumOperands() <= Offsets::FirstIndexOffset + memRefRank) { - return None; - } - return Optional( - getOperand(Offsets::FirstIndexOffset + memRefRank)); -} - -AffineMap VectorTransferReadOp::getPermutationMap() const { +AffineMap VectorTransferReadOp::getPermutationMap() { return getAttrOfType(getPermutationMapAttrName()).getValue(); } @@ -134,7 +118,7 @@ void VectorTransferReadOp::print(OpAsmPrinter *p) { // Construct the FunctionType and print it. llvm::SmallVector inputs{getMemRefType()}; // Must have at least one actual index, see verify. - const Value *firstIndex = *(getIndices().begin()); + Value *firstIndex = *getIndices().begin(); Type indexType = firstIndex->getType(); inputs.append(getMemRefType().getRank(), indexType); if (optionalPaddingValue) { @@ -309,14 +293,7 @@ VectorTransferWriteOp::getIndices() { return {begin, end}; } -llvm::iterator_range -VectorTransferWriteOp::getIndices() const { - auto begin = getInstruction()->operand_begin() + Offsets::FirstIndexOffset; - auto end = begin + getMemRefType().getRank(); - return {begin, end}; -} - -AffineMap VectorTransferWriteOp::getPermutationMap() const { +AffineMap VectorTransferWriteOp::getPermutationMap() { return getAttrOfType(getPermutationMapAttrName()).getValue(); } diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp index 7c74c2fb2f65..76d484ac4025 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp @@ -59,7 +59,7 @@ private: bool convertOneFunction(Function &func); void connectPHINodes(Function &func); bool convertBlock(Block &bb, bool ignoreArguments); - bool convertInstruction(const Instruction &inst, llvm::IRBuilder<> &builder); + bool convertInstruction(Instruction &inst, llvm::IRBuilder<> &builder); template SmallVector lookupValues(Range &&values); @@ -73,7 +73,7 @@ private: // Mappings between original and translated values, used for lookups. llvm::DenseMap functionMapping; - llvm::DenseMap valueMapping; + llvm::DenseMap valueMapping; llvm::DenseMap blockMapping; }; } // end anonymous namespace @@ -185,7 +185,7 @@ template SmallVector ModuleTranslation::lookupValues(Range &&values) { SmallVector remapped; remapped.reserve(llvm::size(values)); - for (const Value *v : values) { + for (Value *v : values) { remapped.push_back(valueMapping.lookup(v)); } return remapped; @@ -195,7 +195,7 @@ SmallVector ModuleTranslation::lookupValues(Range &&values) { // using the `builder`. LLVM IR Builder does not have a generic interface so // this has to be a long chain of `if`s calling different functions with a // different number of arguments. -bool ModuleTranslation::convertInstruction(const Instruction &inst, +bool ModuleTranslation::convertInstruction(Instruction &inst, llvm::IRBuilder<> &builder) { auto extractPosition = [](ArrayAttr attr) { SmallVector position; @@ -212,8 +212,7 @@ bool ModuleTranslation::convertInstruction(const Instruction &inst, // itself. Otherwise, this is an indirect call and the callee is the first // operand, look it up as a normal value. Return the llvm::Value representing // the function result, which may be of llvm::VoidTy type. - auto convertCall = [this, - &builder](const Instruction &inst) -> llvm::Value * { + auto convertCall = [this, &builder](Instruction &inst) -> llvm::Value * { auto operands = lookupValues(inst.getOperands()); ArrayRef operandsRef(operands); if (auto attr = inst.getAttrOfType("callee")) { @@ -270,7 +269,7 @@ bool ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) { auto predecessors = bb.getPredecessors(); unsigned numPredecessors = std::distance(predecessors.begin(), predecessors.end()); - for (const auto *arg : bb.getArguments()) { + for (auto *arg : bb.getArguments()) { auto wrappedType = arg->getType().dyn_cast(); if (!wrappedType) { arg->getType().getContext()->emitError( @@ -284,7 +283,7 @@ bool ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) { } // Traverse instructions. - for (const auto &inst : bb) { + for (auto &inst : bb) { if (convertInstruction(inst, builder)) return true; } @@ -294,8 +293,8 @@ bool ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) { // Get the SSA value passed to the current block from the terminator instruction // of its predecessor. -static const Value *getPHISourceValue(Block *current, Block *pred, - unsigned numArguments, unsigned index) { +static Value *getPHISourceValue(Block *current, Block *pred, + unsigned numArguments, unsigned index) { auto &terminator = *pred->getTerminator(); if (terminator.isa()) { return terminator.getOperand(index); diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index 31f4d48e4ed8..05760f187611 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -39,7 +39,8 @@ using namespace mlir; namespace { // TODO(riverriddle) Handle commutative operations. struct SimpleOperationInfo : public llvm::DenseMapInfo { - static unsigned getHashValue(const Instruction *op) { + static unsigned getHashValue(const Instruction *opC) { + auto *op = const_cast(opC); // Hash the operations based upon their: // - Instruction Name // - Attributes @@ -50,7 +51,9 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo { hash_combine_range(op->result_type_begin(), op->result_type_end()), hash_combine_range(op->operand_begin(), op->operand_end())); } - static bool isEqual(const Instruction *lhs, const Instruction *rhs) { + static bool isEqual(const Instruction *lhsC, const Instruction *rhsC) { + auto *lhs = const_cast(lhsC); + auto *rhs = const_cast(rhsC); if (lhs == rhs) return true; if (lhs == getTombstoneKey() || lhs == getEmptyKey() || diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index b033dadfe516..a659b2e480b4 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -49,9 +49,8 @@ private: // Utility that looks up a list of value in the value remapping table. Returns // an empty vector if one of the values is not mapped yet. - SmallVector - lookupValues(const llvm::iterator_range - &operands); + SmallVector lookupValues( + const llvm::iterator_range &operands); // Converts the given function to the dialect using hooks defined in // `dialectConversion`. Returns the converted function or `nullptr` on error. @@ -102,10 +101,10 @@ private: } // end namespace mlir SmallVector impl::FunctionConversion::lookupValues( - const llvm::iterator_range &operands) { + const llvm::iterator_range &operands) { SmallVector remapped; remapped.reserve(llvm::size(operands)); - for (const Value *operand : operands) { + for (Value *operand : operands) { Value *value = mapping.lookupOrNull(operand); if (!value) return {}; diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index d97538734d1e..954135d2a4ff 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -477,7 +477,7 @@ bool DmaGeneration::runOnBlock(Block *block) { // Get to the first load, store, or for op. auto curBegin = - std::find_if(block->begin(), block->end(), [&](const Instruction &inst) { + std::find_if(block->begin(), block->end(), [&](Instruction &inst) { return inst.isa() || inst.isa() || inst.isa(); }); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 6d4ea7206b7a..95bdc3ca2d22 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -142,7 +142,7 @@ struct LoopNestStateCollector { }; // TODO(b/117228571) Replace when this is modeled through side-effects/op traits -static bool isMemRefDereferencingOp(const Instruction &op) { +static bool isMemRefDereferencingOp(Instruction &op) { if (op.isa() || op.isa() || op.isa() || op.isa()) return true; diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 804991a7b8bc..6208eee5d626 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -192,7 +192,7 @@ struct MaterializationState { VectorType superVectorType; VectorType hwVectorType; SmallVector hwVectorInstance; - DenseMap *substitutionsMap; + DenseMap *substitutionsMap; }; struct MaterializeVectorsPass : public FunctionPass { @@ -239,9 +239,9 @@ static SmallVector delinearize(unsigned linearIndex, return res; } -static Instruction * -instantiate(FuncBuilder *b, Instruction *opInst, VectorType hwVectorType, - DenseMap *substitutionsMap); +static Instruction *instantiate(FuncBuilder *b, Instruction *opInst, + VectorType hwVectorType, + DenseMap *substitutionsMap); /// Not all Values belong to a program slice scoped within the immediately /// enclosing loop. @@ -253,7 +253,7 @@ instantiate(FuncBuilder *b, Instruction *opInst, VectorType hwVectorType, /// /// If substitution fails, returns nullptr. static Value *substitute(Value *v, VectorType hwVectorType, - DenseMap *substitutionsMap) { + DenseMap *substitutionsMap) { auto it = substitutionsMap->find(v); if (it == substitutionsMap->end()) { auto *opInst = v->getDefiningInst(); @@ -404,9 +404,9 @@ materializeAttributes(Instruction *opInst, VectorType hwVectorType) { /// substitutionsMap. /// /// If the underlying substitution fails, this fails too and returns nullptr. -static Instruction * -instantiate(FuncBuilder *b, Instruction *opInst, VectorType hwVectorType, - DenseMap *substitutionsMap) { +static Instruction *instantiate(FuncBuilder *b, Instruction *opInst, + VectorType hwVectorType, + DenseMap *substitutionsMap) { assert(!opInst->isa() && "Should call the function specialized for VectorTransferReadOp"); assert(!opInst->isa() && @@ -481,10 +481,10 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy *transfer, /// `hwVectorType` int the covering of the super-vector type. For a more /// detailed description of the problem, see the description of /// reindexAffineIndices. -static Instruction * -instantiate(FuncBuilder *b, VectorTransferReadOp *read, VectorType hwVectorType, - ArrayRef hwVectorInstance, - DenseMap *substitutionsMap) { +static Instruction *instantiate(FuncBuilder *b, VectorTransferReadOp *read, + VectorType hwVectorType, + ArrayRef hwVectorInstance, + DenseMap *substitutionsMap) { SmallVector indices = map(makePtrDynCaster(), read->getIndices()); auto affineIndices = @@ -505,10 +505,10 @@ instantiate(FuncBuilder *b, VectorTransferReadOp *read, VectorType hwVectorType, /// `hwVectorType` int the covering of th3e super-vector type. For a more /// detailed description of the problem, see the description of /// reindexAffineIndices. -static Instruction * -instantiate(FuncBuilder *b, VectorTransferWriteOp *write, - VectorType hwVectorType, ArrayRef hwVectorInstance, - DenseMap *substitutionsMap) { +static Instruction *instantiate(FuncBuilder *b, VectorTransferWriteOp *write, + VectorType hwVectorType, + ArrayRef hwVectorInstance, + DenseMap *substitutionsMap) { SmallVector indices = map(makePtrDynCaster(), write->getIndices()); auto affineIndices = @@ -624,7 +624,7 @@ static bool emitSlice(MaterializationState *state, // Fresh RAII instanceIndices and substitutionsMap. MaterializationState scopedState = *state; scopedState.hwVectorInstance = delinearize(idx, *ratio); - DenseMap substitutionMap; + DenseMap substitutionMap; scopedState.substitutionsMap = &substitutionMap; // slice are topologically sorted, we can just clone them in order. for (auto *inst : *slice) { @@ -749,7 +749,7 @@ void MaterializeVectorsPass::runOnFunction() { // Capture terminators; i.e. vector_transfer_write ops involving a strict // super-vector of subVectorType. - auto filter = [subVectorType](const Instruction &inst) { + auto filter = [subVectorType](Instruction &inst) { if (!inst.isa()) { return false; } diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 97532fdbe948..1dfc4e7dc172 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -56,7 +56,7 @@ FunctionPassBase *mlir::createPipelineDataTransferPass() { // Returns the position of the tag memref operand given a DMA instruction. // Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are // added. TODO(b/117228571) -static unsigned getTagMemRefPos(const Instruction &dmaInst) { +static unsigned getTagMemRefPos(Instruction &dmaInst) { assert(dmaInst.isa() || dmaInst.isa()); if (dmaInst.isa()) { // Second to last operand. @@ -323,7 +323,7 @@ void PipelineDataTransfer::runOnAffineForOp(OpPointer forOp) { findMatchingStartFinishInsts(forOp, startWaitPairs); // Store shift for instruction for later lookup for AffineApplyOp's. - DenseMap instShiftMap; + DenseMap instShiftMap; for (auto &pair : startWaitPairs) { auto *dmaStartInst = pair.first; assert(dmaStartInst->isa()); @@ -341,13 +341,13 @@ void PipelineDataTransfer::runOnAffineForOp(OpPointer forOp) { SmallVector affineApplyInsts; SmallVector operands(dmaStartInst->getOperands()); getReachableAffineApplyOps(operands, affineApplyInsts); - for (const auto *inst : affineApplyInsts) { + for (auto *inst : affineApplyInsts) { instShiftMap[inst] = 0; } } } // Everything else (including compute ops and dma finish) are shifted by one. - for (const auto &inst : *forOp->getBody()) { + for (auto &inst : *forOp->getBody()) { if (instShiftMap.find(&inst) == instShiftMap.end()) { instShiftMap[&inst] = 1; } diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index cbf68056eb91..2f10b898502f 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -37,19 +37,19 @@ using namespace mlir; /// Return true if this operation dereferences one or more memref's. // Temporary utility: will be replaced when this is modeled through // side-effects/op traits. TODO(b/117228571) -static bool isMemRefDereferencingOp(const Instruction &op) { +static bool isMemRefDereferencingOp(Instruction &op) { if (op.isa() || op.isa() || op.isa() || op.isa()) return true; return false; } -bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, +bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, ArrayRef extraIndices, AffineMap indexRemap, ArrayRef extraOperands, - const Instruction *domInstFilter, - const Instruction *postDomInstFilter) { + Instruction *domInstFilter, + Instruction *postDomInstFilter) { unsigned newMemRefRank = newMemRef->getType().cast().getRank(); (void)newMemRefRank; // unused in opt mode unsigned oldMemRefRank = oldMemRef->getType().cast().getRank(); @@ -167,7 +167,7 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, // Result types don't change. Both memref's are of the same elemental type. state.types.reserve(opInst->getNumResults()); - for (const auto *result : opInst->getResults()) + for (auto *result : opInst->getResults()) state.types.push_back(result->getType()); // Attributes also do not change. diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index af6fc581cfdf..9c9f8593f318 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -105,7 +105,7 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) { VectorType::get(shape, FloatType::getF32(f->getContext())); // Only filter instructions that operate on a strict super-vector and have one // return. This makes testing easier. - auto filter = [subVectorType](const Instruction &inst) { + auto filter = [subVectorType](Instruction &inst) { assert(subVectorType.getElementType() == FloatType::getF32(subVectorType.getContext()) && "Only f32 supported for now"); @@ -150,7 +150,7 @@ static NestedPattern patternTestSlicingOps() { using functional::map; using matcher::Op; // Match all OpInstructions with the kTestSlicingOpName name. - auto filter = [](const Instruction &inst) { + auto filter = [](Instruction &inst) { return inst.getName().getStringRef() == kTestSlicingOpName; }; return Op(filter); @@ -199,7 +199,7 @@ void VectorizerTestPass::testSlicing(Function *f) { } } -static bool customOpWithAffineMapAttribute(const Instruction &inst) { +static bool customOpWithAffineMapAttribute(Instruction &inst) { return inst.getName().getStringRef() == VectorizerTestPass::kTestAffineMapOpName; } @@ -225,11 +225,11 @@ void VectorizerTestPass::testComposeMaps(Function *f) { simplifyAffineMap(res).print(outs() << "\nComposed map: "); } -static bool affineApplyOp(const Instruction &inst) { +static bool affineApplyOp(Instruction &inst) { return inst.isa(); } -static bool singleResultAffineApplyOpWithoutUses(const Instruction &inst) { +static bool singleResultAffineApplyOpWithoutUses(Instruction &inst) { auto app = inst.dyn_cast(); return app && app->use_empty(); } diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 5c5045b668d3..1834b2db0cf0 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -734,7 +734,7 @@ struct VectorizationState { // Map of old scalar Instruction to new vectorized Instruction. DenseMap vectorizationMap; // Map of old scalar Value to new vectorized Value. - DenseMap replacementMap; + DenseMap replacementMap; // The strategy drives which loop to vectorize by which amount. const VectorizationStrategy *strategy; // Use-def roots. These represent the starting points for the worklist in the @@ -755,7 +755,7 @@ struct VectorizationState { void registerTerminal(Instruction *inst); private: - void registerReplacement(const Value *key, Value *value); + void registerReplacement(Value *key, Value *value); }; } // end namespace @@ -796,7 +796,7 @@ void VectorizationState::finishVectorizationPattern() { } } -void VectorizationState::registerReplacement(const Value *key, Value *value) { +void VectorizationState::registerReplacement(Value *key, Value *value) { assert(replacementMap.count(key) == 0 && "replacement already registered"); replacementMap.insert(std::make_pair(key, value)); } @@ -858,8 +858,7 @@ static LogicalResult vectorizeAffineForOp(AffineForOp *loop, int64_t step, using namespace functional; loop->setStep(step); - FilterFunctionType notVectorizedThisPattern = [state]( - const Instruction &inst) { + FilterFunctionType notVectorizedThisPattern = [state](Instruction &inst) { if (!matcher::isLoadOrStore(inst)) { return false; } @@ -893,7 +892,7 @@ static LogicalResult vectorizeAffineForOp(AffineForOp *loop, int64_t step, /// we can build a cost model and a search procedure. static FilterFunctionType isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension) { - return [fastestVaryingMemRefDimension](const Instruction &forInst) { + return [fastestVaryingMemRefDimension](Instruction &forInst) { auto loop = forInst.cast(); return isVectorizableLoopAlongFastestVaryingMemRefDim( loop, fastestVaryingMemRefDimension); diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td index 0003069117e8..015a889beb22 100644 --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl.td @@ -47,5 +47,5 @@ def NS_AOp : Op<"a_op", [NoSideEffect]> { // CHECK: bool fold(SmallVectorImpl &results); // CHECK: private: // CHECK: friend class ::mlir::Instruction; -// CHECK: explicit AOp(const Instruction *state) : Op(state) {} +// CHECK: explicit AOp(Instruction *state) : Op(state) {} // CHECK: }; diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 93418cff9a02..7d31dde91566 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -321,8 +321,7 @@ void OpClass::writeDeclTo(raw_ostream &os) const { } os << "\nprivate:\n" << " friend class ::mlir::Instruction;\n"; - os << " explicit " << className - << "(const Instruction *state) : Op(state) {}\n" + os << " explicit " << className << "(Instruction *state) : Op(state) {}\n" << "};"; }