From 3d6c74fff5347cf11b9cc4149601499d048e51b8 Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Thu, 21 Mar 2019 17:53:00 -0700 Subject: [PATCH] Remove const from mlir::Block. This also eliminates some incorrect reinterpret_cast logic working around it, and numerous const-incorrect issues (like block argument iteration). PiperOrigin-RevId: 239712029 --- mlir/include/mlir/AffineOps/AffineOps.h | 4 +- mlir/include/mlir/Analysis/Dominance.h | 10 +- mlir/include/mlir/EDSC/Builders.h | 4 +- mlir/include/mlir/IR/Block.h | 151 +++++--------------- mlir/include/mlir/IR/BlockAndValueMapping.h | 4 +- mlir/include/mlir/IR/FunctionGraphTraits.h | 29 ---- mlir/include/mlir/IR/Instruction.h | 10 +- mlir/include/mlir/IR/OpDefinition.h | 5 +- mlir/include/mlir/IR/Value.h | 3 +- mlir/include/mlir/StandardOps/Ops.h | 8 +- mlir/lib/AffineOps/AffineOps.cpp | 4 +- mlir/lib/Analysis/AffineAnalysis.cpp | 8 +- mlir/lib/Analysis/Dominance.cpp | 3 +- mlir/lib/Analysis/Utils.cpp | 22 ++- mlir/lib/Analysis/Verifier.cpp | 14 +- mlir/lib/IR/AsmPrinter.cpp | 26 ++-- mlir/lib/IR/Block.cpp | 27 ++-- mlir/lib/IR/Operation.cpp | 4 +- mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp | 29 ++-- mlir/lib/Transforms/DmaGeneration.cpp | 13 +- mlir/lib/Transforms/ViewFunctionGraph.cpp | 4 +- 21 files changed, 131 insertions(+), 251 deletions(-) diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h index a9b93ba29183..63dc20f6b433 100644 --- a/mlir/include/mlir/AffineOps/AffineOps.h +++ b/mlir/include/mlir/AffineOps/AffineOps.h @@ -142,11 +142,9 @@ public: Block *createBody(); /// Get the body of the AffineForOp. - Block *getBody() { return &getRegion().front(); } - const Block *getBody() const { return &getRegion().front(); } + Block *getBody() const { return &getRegion().front(); } /// Get the body region of the AffineForOp. - Region &getRegion() { return getInstruction()->getRegion(0); } Region &getRegion() const { return getInstruction()->getRegion(0); } /// Returns the induction variable for this loop. diff --git a/mlir/include/mlir/Analysis/Dominance.h b/mlir/include/mlir/Analysis/Dominance.h index 39c95b8fbe90..d88c002a2748 100644 --- a/mlir/include/mlir/Analysis/Dominance.h +++ b/mlir/include/mlir/Analysis/Dominance.h @@ -53,7 +53,7 @@ protected: using super = DominanceInfoBase; /// Return true if the specified block A properly dominates block B. - bool properlyDominates(const Block *a, const Block *b); + bool properlyDominates(Block *a, Block *b); /// A mapping of regions to their base dominator tree. llvm::DenseMap> dominanceInfos; @@ -82,12 +82,12 @@ public: } /// Return true if the specified block A dominates block B. - bool dominates(const Block *a, const Block *b) { + bool dominates(Block *a, Block *b) { return a == b || properlyDominates(a, b); } /// Return true if the specified block A properly dominates block B. - bool properlyDominates(const Block *a, const Block *b) { + bool properlyDominates(Block *a, Block *b) { return super::properlyDominates(a, b); } }; @@ -106,12 +106,12 @@ public: } /// Return true if the specified block A properly postdominates block B. - bool properlyPostDominates(const Block *a, const Block *b) { + bool properlyPostDominates(Block *a, Block *b) { return super::properlyDominates(a, b); } /// Return true if the specified block A postdominates block B. - bool postDominates(const Block *a, const Block *b) { + bool postDominates(Block *a, Block *b) { return a == b || properlyPostDominates(a, b); } }; diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h index 10dc81bcd08b..38d3bf32dbcf 100644 --- a/mlir/include/mlir/EDSC/Builders.h +++ b/mlir/include/mlir/EDSC/Builders.h @@ -230,8 +230,8 @@ public: void operator()(ArrayRef stmts); private: - BlockBuilder(const BlockBuilder &) = delete; - BlockBuilder &operator=(const BlockBuilder &other) = delete; + BlockBuilder(BlockBuilder &) = delete; + BlockBuilder &operator=(BlockBuilder &other) = delete; }; /// Base class for ValueHandle, InstructionHandle and BlockHandle. diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h index 6f1196ba8829..f373f73bf566 100644 --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -76,8 +76,8 @@ class Function; using BlockOperand = IROperandImpl; -template class PredecessorIterator; -template class SuccessorIterator; +class PredecessorIterator; +class SuccessorIterator; /// `Block` represents an ordered list of `Instruction`s. class Block : public IRObjectWithUseList, @@ -97,19 +97,15 @@ public: } /// Blocks are maintained in a Region. - Region *getParent() const { return parentValidInstOrderPair.getPointer(); } + Region *getParent() { return parentValidInstOrderPair.getPointer(); } /// Returns the closest surrounding instruction that contains this block or /// nullptr if this is a top-level block. Instruction *getContainingInst(); - const Instruction *getContainingInst() const { - return const_cast(this)->getContainingInst(); - } - /// Returns the function that this block is part of, even if the block is /// nested under an operation region. - Function *getFunction() const; + Function *getFunction(); /// Insert this block (which must not already be in a function) right before /// the specified block. @@ -125,17 +121,16 @@ public: // This is the list of arguments to the block. using BlockArgListType = ArrayRef; - // FIXME: Not const correct. - BlockArgListType getArguments() const { return arguments; } + BlockArgListType getArguments() { return arguments; } using args_iterator = BlockArgListType::iterator; using reverse_args_iterator = BlockArgListType::reverse_iterator; - args_iterator args_begin() const { return getArguments().begin(); } - args_iterator args_end() const { return getArguments().end(); } - reverse_args_iterator args_rbegin() const { return getArguments().rbegin(); } - reverse_args_iterator args_rend() const { return getArguments().rend(); } + args_iterator args_begin() { return getArguments().begin(); } + args_iterator args_end() { return getArguments().end(); } + reverse_args_iterator args_rbegin() { return getArguments().rbegin(); } + reverse_args_iterator args_rend() { return getArguments().rend(); } - bool args_empty() const { return arguments.empty(); } + bool args_empty() { return arguments.empty(); } /// Add one value to the argument list. BlockArgument *addArgument(Type type); @@ -146,9 +141,8 @@ public: /// Erase the argument at 'index' and remove it from the argument list. void eraseArgument(unsigned index); - unsigned getNumArguments() const { return arguments.size(); } + unsigned getNumArguments() { return arguments.size(); } BlockArgument *getArgument(unsigned i) { return arguments[i]; } - const BlockArgument *getArgument(unsigned i) const { return arguments[i]; } //===--------------------------------------------------------------------===// // Instruction list management @@ -157,44 +151,29 @@ public: /// This is the list of instructions in the block. using InstListType = llvm::iplist; InstListType &getInstructions() { return instructions; } - const InstListType &getInstructions() const { return instructions; } // Iteration over the instructions in the block. using iterator = InstListType::iterator; - using const_iterator = InstListType::const_iterator; using reverse_iterator = InstListType::reverse_iterator; - using const_reverse_iterator = InstListType::const_reverse_iterator; iterator begin() { return instructions.begin(); } iterator end() { return instructions.end(); } - const_iterator begin() const { return instructions.begin(); } - const_iterator end() const { return instructions.end(); } reverse_iterator rbegin() { return instructions.rbegin(); } reverse_iterator rend() { return instructions.rend(); } - const_reverse_iterator rbegin() const { return instructions.rbegin(); } - const_reverse_iterator rend() const { return instructions.rend(); } - bool empty() const { return instructions.empty(); } + bool empty() { return instructions.empty(); } void push_back(Instruction *inst) { instructions.push_back(inst); } void push_front(Instruction *inst) { instructions.push_front(inst); } Instruction &back() { return instructions.back(); } - const Instruction &back() const { return const_cast(this)->back(); } Instruction &front() { return instructions.front(); } - const Instruction &front() const { - return const_cast(this)->front(); - } /// Returns 'inst' if 'inst' lies in this block, or otherwise finds the /// ancestor instruction of 'inst' that lies in this block. Returns nullptr if /// the latter fails. /// TODO: This is very specific functionality that should live somewhere else, /// probably in Dominance.cpp. - Instruction *findAncestorInstInBlock(Instruction *inst); - const Instruction *findAncestorInstInBlock(const Instruction &inst) const { - return const_cast(this)->findAncestorInstInBlock( - const_cast(&inst)); - } + Instruction *findAncestorInstInBlock(const Instruction &inst); /// This drops all operand uses from instructions within this block, which is /// an essential step in breaking cyclic dependences between references when @@ -203,7 +182,7 @@ public: /// Returns true if the ordering of the child instructions is valid, false /// otherwise. - bool isInstOrderValid() const { return parentValidInstOrderPair.getInt(); } + bool isInstOrderValid() { return parentValidInstOrderPair.getInt(); } /// Invalidates the current ordering of instructions. void invalidateInstOrder() { @@ -214,7 +193,7 @@ public: /// Verifies the current ordering of child instructions matches the /// validInstOrder flag. Returns false if the order is valid, true otherwise. - bool verifyInstOrder() const; + bool verifyInstOrder(); /// Recomputes the ordering of child instructions within the block. void recomputeInstOrder(); @@ -227,27 +206,18 @@ public: /// the block has a valid terminator instruction. Instruction *getTerminator(); - const Instruction *getTerminator() const { - return const_cast(this)->getTerminator(); - } - //===--------------------------------------------------------------------===// // Predecessors and successors. //===--------------------------------------------------------------------===// // Predecessor iteration. - using const_pred_iterator = PredecessorIterator; - const_pred_iterator pred_begin() const; - const_pred_iterator pred_end() const; - llvm::iterator_range getPredecessors() const; - - using pred_iterator = PredecessorIterator; + using pred_iterator = PredecessorIterator; pred_iterator pred_begin(); pred_iterator pred_end(); llvm::iterator_range getPredecessors(); /// Return true if this block has no predecessors. - bool hasNoPredecessors() const; + bool hasNoPredecessors(); /// If this block has exactly one predecessor, return it. Otherwise, return /// null. @@ -257,24 +227,12 @@ public: /// destinations) is not considered to be a single predecessor. Block *getSinglePredecessor(); - const Block *getSinglePredecessor() const { - return const_cast(this)->getSinglePredecessor(); - } - // Indexed successor access. - unsigned getNumSuccessors() const; - const Block *getSuccessor(unsigned i) const { - return const_cast(this)->getSuccessor(i); - } + unsigned getNumSuccessors(); Block *getSuccessor(unsigned i); // Successor iteration. - using const_succ_iterator = SuccessorIterator; - const_succ_iterator succ_begin() const; - const_succ_iterator succ_end() const; - llvm::iterator_range getSuccessors() const; - - using succ_iterator = SuccessorIterator; + using succ_iterator = SuccessorIterator; succ_iterator succ_begin(); succ_iterator succ_end(); llvm::iterator_range getSuccessors(); @@ -325,8 +283,8 @@ public: return &Block::instructions; } - void print(raw_ostream &os) const; - void dump() const; + void print(raw_ostream &os); + void dump(); /// Print out the name of the block without printing its body. /// NOTE: The printType argument is ignored. We keep it for compatibility @@ -344,8 +302,8 @@ private: /// This is the list of arguments to the block. std::vector arguments; - Block(const Block &) = delete; - void operator=(const Block &) = delete; + Block(Block &) = delete; + void operator=(Block &) = delete; friend struct llvm::ilist_traits; }; @@ -437,28 +395,23 @@ private: /// BlockOperands that are embedded into terminator instructions. From the /// operand, we can get the terminator that contains it, and it's parent block /// is the predecessor. -template class PredecessorIterator - : public llvm::iterator_facade_base, - std::forward_iterator_tag, - BlockType *> { + : public llvm::iterator_facade_base { public: PredecessorIterator(BlockOperand *firstOperand) : bbUseIterator(firstOperand) {} PredecessorIterator &operator=(const PredecessorIterator &rhs) { bbUseIterator = rhs.bbUseIterator; + return *this; } bool operator==(const PredecessorIterator &rhs) const { return bbUseIterator == rhs.bbUseIterator; } - BlockType *operator*() const { - // The use iterator points to an operand of a terminator. The predecessor - // we return is the block that the terminator is embedded into. - return bbUseIterator.getUser()->getBlock(); - } + Block *operator*() const; PredecessorIterator &operator++() { ++bbUseIterator; @@ -466,28 +419,13 @@ public: } /// Get the successor number in the predecessor terminator. - unsigned getSuccessorIndex() const { - return bbUseIterator->getOperandNumber(); - } + unsigned getSuccessorIndex() const; private: using BBUseIterator = ValueUseIterator; BBUseIterator bbUseIterator; }; -inline auto Block::pred_begin() const -> const_pred_iterator { - return const_pred_iterator((BlockOperand *)getFirstUse()); -} - -inline auto Block::pred_end() const -> const_pred_iterator { - return const_pred_iterator(nullptr); -} - -inline auto Block::getPredecessors() const - -> llvm::iterator_range { - return {pred_begin(), pred_end()}; -} - inline auto Block::pred_begin() -> pred_iterator { return pred_iterator((BlockOperand *)getFirstUse()); } @@ -505,46 +443,23 @@ inline auto Block::getPredecessors() -> llvm::iterator_range { //===----------------------------------------------------------------------===// /// This template implements the successor iterators for Block. -template class SuccessorIterator final - : public IndexedAccessorIterator, BlockType, - BlockType> { + : public IndexedAccessorIterator { public: /// Initializes the result iterator to the specified index. - SuccessorIterator(BlockType *object, unsigned index) - : IndexedAccessorIterator, BlockType, - BlockType>(object, index) {} + SuccessorIterator(Block *object, unsigned index) + : IndexedAccessorIterator(object, + index) {} SuccessorIterator(const SuccessorIterator &other) : SuccessorIterator(other.object, other.index) {} - /// Support converting to the const variant. This will be a no-op for const - /// variant. - operator SuccessorIterator() const { - return SuccessorIterator(this->object, this->index); - } - - BlockType *operator*() const { - return this->object->getSuccessor(this->index); - } + Block *operator*() const { return this->object->getSuccessor(this->index); } /// Get the successor number in the terminator. unsigned getSuccessorIndex() const { return this->index; } }; -inline auto Block::succ_begin() const -> const_succ_iterator { - return const_succ_iterator(this, 0); -} - -inline auto Block::succ_end() const -> const_succ_iterator { - return const_succ_iterator(this, getNumSuccessors()); -} - -inline auto Block::getSuccessors() const - -> llvm::iterator_range { - return {succ_begin(), succ_end()}; -} - inline auto Block::succ_begin() -> succ_iterator { return succ_iterator(this, 0); } diff --git a/mlir/include/mlir/IR/BlockAndValueMapping.h b/mlir/include/mlir/IR/BlockAndValueMapping.h index cc0a6c064573..2bac95bc39db 100644 --- a/mlir/include/mlir/IR/BlockAndValueMapping.h +++ b/mlir/include/mlir/IR/BlockAndValueMapping.h @@ -36,7 +36,7 @@ class BlockAndValueMapping { public: /// Inserts a new mapping for 'from' to 'to'. If there is an existing mapping, /// it is overwritten. - void map(const Block *from, Block *to) { valueMap[from] = to; } + void map(Block *from, Block *to) { valueMap[from] = to; } void map(const Value *from, Value *to) { valueMap[from] = to; } /// Erases a mapping for 'from'. @@ -49,7 +49,7 @@ public: /// Lookup a mapped value within the map. If a mapping for the provided value /// does not exist then return nullptr. - Block *lookupOrNull(const Block *from) const { + Block *lookupOrNull(Block *from) const { return lookupOrValue(from, (Block *)nullptr); } Value *lookupOrNull(const Value *from) const { diff --git a/mlir/include/mlir/IR/FunctionGraphTraits.h b/mlir/include/mlir/IR/FunctionGraphTraits.h index a47b97ead8a2..adc3fef9ff66 100644 --- a/mlir/include/mlir/IR/FunctionGraphTraits.h +++ b/mlir/include/mlir/IR/FunctionGraphTraits.h @@ -41,19 +41,6 @@ template <> struct GraphTraits { static ChildIteratorType child_end(NodeRef node) { return node->succ_end(); } }; -template <> struct GraphTraits { - using ChildIteratorType = mlir::Block::const_succ_iterator; - using Node = const mlir::Block; - using NodeRef = Node *; - - static NodeRef getEntryNode(NodeRef bb) { return bb; } - - static ChildIteratorType child_begin(NodeRef node) { - return node->succ_begin(); - } - static ChildIteratorType child_end(NodeRef node) { return node->succ_end(); } -}; - template <> struct GraphTraits> { using ChildIteratorType = mlir::Block::pred_iterator; using Node = mlir::Block; @@ -69,22 +56,6 @@ template <> struct GraphTraits> { } }; -template <> struct GraphTraits> { - using ChildIteratorType = mlir::Block::const_pred_iterator; - using Node = const mlir::Block; - using NodeRef = Node *; - - static NodeRef getEntryNode(Inverse inverseGraph) { - return inverseGraph.Graph; - } - static inline ChildIteratorType child_begin(NodeRef node) { - return node->pred_begin(); - } - static inline ChildIteratorType child_end(NodeRef node) { - return node->pred_end(); - } -}; - template <> struct GraphTraits : public GraphTraits { using GraphType = mlir::Function *; diff --git a/mlir/include/mlir/IR/Instruction.h b/mlir/include/mlir/IR/Instruction.h index 7a12d4f543ad..6ffe916577c2 100644 --- a/mlir/include/mlir/IR/Instruction.h +++ b/mlir/include/mlir/IR/Instruction.h @@ -88,8 +88,7 @@ public: Instruction *clone(MLIRContext *context) const; /// Returns the instruction block that contains this instruction. - const Block *getBlock() const { return block; } - Block *getBlock() { return block; } + Block *getBlock() const { return block; } /// Return the context this operation is associated with. MLIRContext *getContext() const; @@ -337,13 +336,10 @@ public: return getTrailingObjects()[index]; } - Block *getSuccessor(unsigned index) { + Block *getSuccessor(unsigned index) const { assert(index < getNumSuccessors()); return getBlockOperands()[index].get(); } - const Block *getSuccessor(unsigned index) const { - return const_cast(this)->getSuccessor(index); - } void setSuccessor(Block *block, unsigned index); /// Erase a specific operand from the operand list of the successor at @@ -517,7 +513,7 @@ private: } // Provide a 'getParent' method for ilist_node_with_parent methods. - const Block *getParent() const { return getBlock(); } + Block *getParent() const { return getBlock(); } /// The instruction block that containts this instruction. Block *block = nullptr; diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index cda33d58b4f6..89d3bff7a814 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -787,10 +787,7 @@ public: return this->getInstruction()->getNumSuccessorOperands(index); } - const Block *getSuccessor(unsigned index) const { - return this->getInstruction()->getSuccessor(index); - } - Block *getSuccessor(unsigned index) { + Block *getSuccessor(unsigned index) const { return this->getInstruction()->getSuccessor(index); } diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h index c34474f17b29..fde405305961 100644 --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -116,8 +116,7 @@ public: /// Return the function that this argument is defined in. Function *getFunction() const; - Block *getOwner() { return owner; } - const Block *getOwner() const { return owner; } + Block *getOwner() const { return owner; } /// Returns the number of this argument. unsigned getArgNumber() const; diff --git a/mlir/include/mlir/StandardOps/Ops.h b/mlir/include/mlir/StandardOps/Ops.h index c546f1d77c73..df828fc0b9fb 100644 --- a/mlir/include/mlir/StandardOps/Ops.h +++ b/mlir/include/mlir/StandardOps/Ops.h @@ -113,9 +113,7 @@ public: /// Return the block this branch jumps to. Block *getDest(); - const Block *getDest() const { - return const_cast(this)->getDest(); - } + Block *getDest() const { return const_cast(this)->getDest(); } void setDest(Block *block); /// Erase the operand at 'index' from the operand list. @@ -322,13 +320,13 @@ public: /// Return the destination if the condition is true. Block *getTrueDest(); - const Block *getTrueDest() const { + Block *getTrueDest() const { return const_cast(this)->getTrueDest(); } /// Return the destination if the condition is false. Block *getFalseDest(); - const Block *getFalseDest() const { + Block *getFalseDest() const { return const_cast(this)->getFalseDest(); } diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index cd73916a21b5..4a77a3254363 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -576,7 +576,7 @@ bool AffineForOp::verify() const { // Check that the body defines as single block argument for the induction // variable. - const auto *body = getBody(); + auto *body = getBody(); if (body->getNumArguments() != 1 || !body->getArgument(0)->getType().isIndex()) return emitOpError("expected body to have a single index argument for the " @@ -1068,7 +1068,7 @@ bool AffineIfOp::verify() const { if (region.front().back().isKnownTerminator()) return emitOpError("expects region block to not have a terminator"); - for (const auto &b : region) + for (auto &b : region) if (b.getNumArguments() != 0) return emitOpError( "requires that child entry blocks have no arguments"); diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 8b845af7e125..c24a7688a4dc 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -532,10 +532,10 @@ static unsigned getNumCommonLoops(const FlatAffineConstraints &srcDomain, } // Returns Block common to 'srcAccess.opInst' and 'dstAccess.opInst'. -static const Block *getCommonBlock(const MemRefAccess &srcAccess, - const MemRefAccess &dstAccess, - const FlatAffineConstraints &srcDomain, - unsigned numCommonLoops) { +static Block *getCommonBlock(const MemRefAccess &srcAccess, + const MemRefAccess &dstAccess, + const FlatAffineConstraints &srcDomain, + unsigned numCommonLoops) { if (numCommonLoops == 0) { auto *block = srcAccess.opInst->getBlock(); while (block->getContainingInst()) { diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp index 8ccee5d4e29e..50fb2586f7d0 100644 --- a/mlir/lib/Analysis/Dominance.cpp +++ b/mlir/lib/Analysis/Dominance.cpp @@ -60,8 +60,7 @@ void DominanceInfoBase::recalculate(Function *function) { /// Return true if the specified block A properly dominates block B. template -bool DominanceInfoBase::properlyDominates(const Block *a, - const Block *b) { +bool DominanceInfoBase::properlyDominates(Block *a, Block *b) { // A block dominates itself but does not properly dominate itself. if (a == b) return false; diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index be476b50efb9..8918dd03f809 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -433,11 +433,12 @@ template LogicalResult mlir::boundCheckLoadOrStoreOp(OpPointer storeOp, // Block from the Block containing instruction, stopping at 'limitBlock'. static void findInstPosition(const Instruction *inst, Block *limitBlock, SmallVectorImpl *positions) { - const Block *block = inst->getBlock(); + Block *block = inst->getBlock(); while (block != limitBlock) { // FIXME: This algorithm is unnecessarily O(n) and should be improved to not // rely on linear scans. - int instPosInBlock = std::distance(block->begin(), inst->getIterator()); + int instPosInBlock = std::distance( + block->begin(), const_cast(inst)->getIterator()); positions->push_back(instPosInBlock); inst = block->getContainingInst(); block = inst->getBlock(); @@ -680,20 +681,15 @@ unsigned mlir::getNumCommonSurroundingLoops(const Instruction &A, return numCommonLoops; } -static Optional getMemoryFootprintBytes(const Block &block, - Block::const_iterator start, - Block::const_iterator end, +static Optional getMemoryFootprintBytes(Block &block, + Block::iterator start, + Block::iterator end, int memorySpace) { SmallDenseMap, 4> regions; - // Cast away constness since the walker uses non-const versions; but we - // guarantee that the visitor callback isn't mutating opInst. - auto *cStart = reinterpret_cast(&start); - auto *cEnd = reinterpret_cast(&end); - // Walk this 'for' instruction to gather all memory regions. bool error = false; - const_cast(&block)->walk(*cStart, *cEnd, [&](Instruction *opInst) { + const_cast(&block)->walk(start, end, [&](Instruction *opInst) { if (!opInst->isa() && !opInst->isa()) { // Neither load nor a store op. return; @@ -737,8 +733,8 @@ Optional mlir::getMemoryFootprintBytes(OpPointer forOp, int memorySpace) { auto *forInst = forOp->getInstruction(); return ::getMemoryFootprintBytes( - *forInst->getBlock(), Block::const_iterator(forInst), - std::next(Block::const_iterator(forInst)), memorySpace); + *forInst->getBlock(), Block::iterator(forInst), + std::next(Block::iterator(forInst)), memorySpace); } /// Returns in 'sequentialLoops' all sequential loops in loop nest rooted diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index faab0cd7b795..51651302f8d3 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -60,7 +60,7 @@ public: return fn.emitError(message); } - bool failure(const Twine &message, const Block &bb) { + bool failure(const Twine &message, Block &bb) { // Take the location information for the first instruction in the block. if (!bb.empty()) return failure(message, bb.front()); @@ -107,9 +107,9 @@ public: } bool verify(); - bool verifyBlock(const Block &block, bool isTopLevel); + bool verifyBlock(Block &block, bool isTopLevel); bool verifyOperation(const Instruction &op); - bool verifyDominance(const Block &block); + bool verifyDominance(Block &block); bool verifyInstDominance(const Instruction &inst); explicit FuncVerifier(Function &fn) @@ -221,7 +221,7 @@ bool FuncVerifier::verify() { } // Returns if the given block is allowed to have no terminator. -static bool canBlockHaveNoTerminator(const Block &block) { +static bool canBlockHaveNoTerminator(Block &block) { // Allow the first block of an operation region to have no terminator if it is // the only block in the region. auto *parentList = block.getParent(); @@ -229,7 +229,7 @@ static bool canBlockHaveNoTerminator(const Block &block) { std::next(parentList->begin()) == parentList->end(); } -bool FuncVerifier::verifyBlock(const Block &block, bool isTopLevel) { +bool FuncVerifier::verifyBlock(Block &block, bool isTopLevel) { for (auto *arg : block.getArguments()) { if (arg->getOwner() != &block) return failure("block argument not owned by block", block); @@ -262,7 +262,7 @@ bool FuncVerifier::verifyBlock(const Block &block, bool isTopLevel) { // Verify that this block is not branching to a block of a different // region. - for (const Block *successor : block.getSuccessors()) + for (Block *successor : block.getSuccessors()) if (successor->getParent() != block.getParent()) return failure("branching to block of a different region", block.back()); @@ -314,7 +314,7 @@ bool FuncVerifier::verifyOperation(const Instruction &op) { return false; } -bool FuncVerifier::verifyDominance(const Block &block) { +bool FuncVerifier::verifyDominance(Block &block) { // Verify the dominance of each of the held instructions. for (auto &inst : block) if (verifyInstDominance(inst)) diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 78f40a55670e..7d60b5819f3a 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1061,7 +1061,7 @@ public: // Methods to print instructions. void print(const Instruction *inst); - void print(const Block *block, bool printBlockArgs = true); + void print(Block *block, bool printBlockArgs = true); void printOperation(const Instruction *op); void printGenericOp(const Instruction *op); @@ -1094,7 +1094,7 @@ public: enum { nameSentinel = ~0U }; - void printBlockName(const Block *block) { + void printBlockName(Block *block) { auto id = getBlockID(block); if (id != ~0U) os << "^bb" << id; @@ -1102,7 +1102,7 @@ public: os << "^INVALIDBLOCK"; } - unsigned getBlockID(const Block *block) { + unsigned getBlockID(Block *block) { auto it = blockIDs.find(block); return it != blockIDs.end() ? it->second : ~0U; } @@ -1128,7 +1128,7 @@ public: protected: void numberValueID(const Value *value); - void numberValuesInBlock(const Block &block); + void numberValuesInBlock(Block &block); void printValueID(const Value *value, bool printResultNo = true) const; private: @@ -1140,7 +1140,7 @@ private: DenseMap valueNames; /// This is the block ID for each block in the current function. - DenseMap blockIDs; + DenseMap blockIDs; /// This keeps track of all of the non-numeric names that are in flight, /// allowing us to check for duplicates. @@ -1172,7 +1172,7 @@ FunctionPrinter::FunctionPrinter(Function *function, ModulePrinter &other) /// Number all of the SSA values in the specified block. Values get numbered /// continuously throughout regions. In particular, we traverse the regions /// held by operations and number values in depth-first pre-order. -void FunctionPrinter::numberValuesInBlock(const Block &block) { +void FunctionPrinter::numberValuesInBlock(Block &block) { // Each block gets a unique ID, and all of the instructions within it get // numbered as well. blockIDs[&block] = nextBlockID++; @@ -1186,7 +1186,7 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) { if (inst.getNumResults() != 0) numberValueID(inst.getResult(0)); for (auto ®ion : inst.getRegions()) - for (const auto &block : region) + for (auto &block : region) numberValuesInBlock(block); } } @@ -1337,7 +1337,7 @@ void FunctionPrinter::printFunctionSignature() { } } -void FunctionPrinter::print(const Block *block, bool printBlockArgs) { +void FunctionPrinter::print(Block *block, bool printBlockArgs) { // Print the block label and argument list if requested. if (printBlockArgs) { os.indent(currentIndent); @@ -1346,7 +1346,7 @@ void FunctionPrinter::print(const Block *block, bool printBlockArgs) { // Print the argument list if non-empty. if (!block->args_empty()) { os << '('; - interleaveComma(block->getArguments(), [&](const BlockArgument *arg) { + interleaveComma(block->getArguments(), [&](BlockArgument *arg) { printValueID(arg); os << ": "; printType(arg->getType()); @@ -1366,14 +1366,14 @@ void FunctionPrinter::print(const Block *block, bool printBlockArgs) { } else { // We want to print the predecessors in increasing numeric order, not in // whatever order the use-list is in, so gather and sort them. - SmallVector, 4> predIDs; + SmallVector, 4> predIDs; for (auto *pred : block->getPredecessors()) predIDs.push_back({getBlockID(pred), pred}); llvm::array_pod_sort(predIDs.begin(), predIDs.end()); os << "\t// " << predIDs.size() << " preds: "; - interleaveComma(predIDs, [&](std::pair pred) { + interleaveComma(predIDs, [&](std::pair pred) { printBlockName(pred.second); }); } @@ -1615,7 +1615,7 @@ void Instruction::dump() const { llvm::errs() << "\n"; } -void Block::print(raw_ostream &os) const { +void Block::print(raw_ostream &os) { auto *function = getFunction(); if (!function) { os << "<>\n"; @@ -1627,7 +1627,7 @@ void Block::print(raw_ostream &os) const { FunctionPrinter(function, modulePrinter).print(this); } -void Block::dump() const { print(llvm::errs()); } +void Block::dump() { print(llvm::errs()); } /// Print out the name of the block without printing its body. void Block::printAsOperand(raw_ostream &os, bool printType) { diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp index 1e3c79f491e8..0470eb5e13bd 100644 --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -49,8 +49,8 @@ Instruction *Block::getContainingInst() { return getParent() ? getParent()->getContainingInst() : nullptr; } -Function *Block::getFunction() const { - const Block *block = this; +Function *Block::getFunction() { + Block *block = this; while (auto *inst = block->getContainingInst()) { block = inst->getBlock(); if (!block) @@ -78,10 +78,10 @@ void Block::eraseFromFunction() { /// Returns 'inst' if 'inst' lies in this block, or otherwise finds the /// ancestor instruction of 'inst' that lies in this block. Returns nullptr if /// the latter fails. -Instruction *Block::findAncestorInstInBlock(Instruction *inst) { +Instruction *Block::findAncestorInstInBlock(const Instruction &inst) { // Traverse up the instruction hierarchy starting from the owner of operand to // find the ancestor instruction that resides in the block of 'forInst'. - auto *currInst = inst; + auto *currInst = const_cast(&inst); while (currInst->getBlock() != this) { currInst = currInst->getParentInst(); if (!currInst) @@ -100,7 +100,7 @@ void Block::dropAllReferences() { /// Verifies the current ordering of child instructions. Returns false if the /// order is valid, true otherwise. -bool Block::verifyInstOrder() const { +bool Block::verifyInstOrder() { // The order is already known to be invalid. if (!isInstOrderValid()) return false; @@ -131,6 +131,17 @@ void Block::recomputeInstOrder() { inst.orderIndex = orderIndex++; } +Block *PredecessorIterator::operator*() const { + // The use iterator points to an operand of a terminator. The predecessor + // we return is the block that the terminator is embedded into. + return bbUseIterator.getUser()->getBlock(); +} + +/// Get the successor number in the predecessor terminator. +unsigned PredecessorIterator::getSuccessorIndex() const { + return bbUseIterator->getOperandNumber(); +} + //===----------------------------------------------------------------------===// // Argument list management. //===----------------------------------------------------------------------===// @@ -179,10 +190,10 @@ Instruction *Block::getTerminator() { } /// Return true if this block has no predecessors. -bool Block::hasNoPredecessors() const { return pred_begin() == pred_end(); } +bool Block::hasNoPredecessors() { return pred_begin() == pred_end(); } // Indexed successor access. -unsigned Block::getNumSuccessors() const { +unsigned Block::getNumSuccessors() { return empty() ? 0 : back().getNumSuccessors(); } @@ -288,7 +299,7 @@ void Region::cloneInto(Region *dest, BlockAndValueMapping &mapper, return; iterator lastOldBlock = --dest->end(); - for (const Block &block : *this) { + for (Block &block : *this) { Block *newBlock = new Block(); mapper.map(&block, newBlock); diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 011e1d27cef6..566e5b16bff0 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -246,7 +246,7 @@ bool OpTrait::impl::verifySameOperandsAndResultType(const Instruction *op) { static bool verifyBBArguments( llvm::iterator_range operands, - const Block *destBB, const Instruction *op) { + Block *destBB, const Instruction *op) { unsigned operandCount = std::distance(operands.begin(), operands.end()); if (operandCount != destBB->getNumArguments()) return op->emitError("branch has " + Twine(operandCount) + @@ -276,7 +276,7 @@ static bool verifyTerminatorSuccessors(const Instruction *op) { } bool OpTrait::impl::verifyIsTerminator(const Instruction *op) { - const Block *block = op->getBlock(); + Block *block = op->getBlock(); // Verify that the operation is at the end of the respective parent block. if (!block || &block->back() != op) return op->emitOpError("must be the last instruction in the parent block"); diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp index f6bc644eb27f..7c74c2fb2f65 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp @@ -58,7 +58,7 @@ private: bool convertFunctions(); bool convertOneFunction(Function &func); void connectPHINodes(Function &func); - bool convertBlock(const Block &bb, bool ignoreArguments); + bool convertBlock(Block &bb, bool ignoreArguments); bool convertInstruction(const Instruction &inst, llvm::IRBuilder<> &builder); template @@ -74,7 +74,7 @@ private: // Mappings between original and translated values, used for lookups. llvm::DenseMap functionMapping; llvm::DenseMap valueMapping; - llvm::DenseMap blockMapping; + llvm::DenseMap blockMapping; }; } // end anonymous namespace @@ -257,7 +257,7 @@ bool ModuleTranslation::convertInstruction(const Instruction &inst, // Convert block to LLVM IR. Unless `ignoreArguments` is set, emit PHI nodes // to define values corresponding to the MLIR block arguments. These nodes // are not connected to the source basic blocks, which may not exist yet. -bool ModuleTranslation::convertBlock(const Block &bb, bool ignoreArguments) { +bool ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) { llvm::IRBuilder<> builder(blockMapping[&bb]); // Before traversing instructions, make block arguments available through @@ -294,7 +294,7 @@ bool ModuleTranslation::convertBlock(const Block &bb, bool ignoreArguments) { // Get the SSA value passed to the current block from the terminator instruction // of its predecessor. -static const Value *getPHISourceValue(const Block *current, const Block *pred, +static const Value *getPHISourceValue(Block *current, Block *pred, unsigned numArguments, unsigned index) { auto &terminator = *pred->getTerminator(); if (terminator.isa()) { @@ -320,7 +320,7 @@ void ModuleTranslation::connectPHINodes(Function &func) { // Skip the first block, it cannot be branched to and its arguments correspond // to the arguments of the LLVM function. for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) { - const Block *bb = &*it; + Block *bb = &*it; llvm::BasicBlock *llvmBB = blockMapping.lookup(bb); auto phis = llvmBB->phis(); auto numArguments = bb->getNumArguments(); @@ -328,7 +328,7 @@ void ModuleTranslation::connectPHINodes(Function &func) { for (auto &numberedPhiNode : llvm::enumerate(phis)) { auto &phiNode = numberedPhiNode.value(); unsigned index = numberedPhiNode.index(); - for (const auto *pred : bb->getPredecessors()) { + for (auto *pred : bb->getPredecessors()) { phiNode.addIncoming(valueMapping.lookup(getPHISourceValue( bb, pred, numArguments, index)), blockMapping.lookup(pred)); @@ -338,22 +338,21 @@ void ModuleTranslation::connectPHINodes(Function &func) { } // TODO(mlir-team): implement an iterative version -static void topologicalSortImpl(llvm::SetVector &blocks, - const Block *b) { +static void topologicalSortImpl(llvm::SetVector &blocks, Block *b) { blocks.insert(b); - for (const Block *bb : b->getSuccessors()) { + for (Block *bb : b->getSuccessors()) { if (blocks.count(bb) == 0) topologicalSortImpl(blocks, bb); } } // Sort function blocks topologically. -static llvm::SetVector topologicalSort(Function &f) { +static llvm::SetVector topologicalSort(Function &f) { // For each blocks that has not been visited yet (i.e. that has no // predecessors), add it to the list and traverse its successors in DFS // preorder. - llvm::SetVector blocks; - for (const Block &b : f.getBlocks()) { + llvm::SetVector blocks; + for (Block &b : f.getBlocks()) { if (blocks.count(&b) == 0) topologicalSortImpl(blocks, &b); } @@ -373,7 +372,7 @@ bool ModuleTranslation::convertOneFunction(Function &func) { unsigned int argIdx = 0; for (const auto &kvp : llvm::zip(func.getArguments(), llvmFunc->args())) { llvm::Argument &llvmArg = std::get<1>(kvp); - const BlockArgument *mlirArg = std::get<0>(kvp); + BlockArgument *mlirArg = std::get<0>(kvp); if (auto attr = func.getArgAttrOfType(argIdx, "llvm.noalias")) { // NB: Attribute already verified to be boolean, so check if we can indeed @@ -392,7 +391,7 @@ bool ModuleTranslation::convertOneFunction(Function &func) { // First, create all blocks so we can jump to them. llvm::LLVMContext &llvmContext = llvmFunc->getContext(); - for (const auto &bb : func) { + for (auto &bb : func) { auto *llvmBB = llvm::BasicBlock::Create(llvmContext); llvmBB->insertInto(llvmFunc); blockMapping[&bb] = llvmBB; @@ -402,7 +401,7 @@ bool ModuleTranslation::convertOneFunction(Function &func) { // converted before uses. auto blocks = topologicalSort(func); for (auto indexedBB : llvm::enumerate(blocks)) { - const auto *bb = indexedBB.value(); + auto *bb = indexedBB.value(); if (convertBlock(*bb, /*ignoreArguments=*/indexedBB.index() == 0)) return true; } diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 23e07fc3a89f..d97538734d1e 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -205,7 +205,7 @@ static bool getFullMemRefAsRegion(Instruction *opInst, unsigned numParamLoopIVs, return true; } -static void emitNoteForBlock(const Block &block, const Twine &message) { +static void emitNoteForBlock(Block &block, const Twine &message) { auto *inst = block.getContainingInst(); if (!inst) { block.getFunction()->emitNote(message); @@ -543,11 +543,12 @@ bool DmaGeneration::runOnBlock(Block *block) { /// in the block for placing incoming (read) and outgoing (write) DMAs /// respectively. The lowest depth depends on whether the region being accessed /// is invariant with respect to one or more immediately surrounding loops. -static void findHighestBlockForPlacement( - const MemRefRegion ®ion, const Block &block, - const Block::iterator &begin, const Block::iterator &end, - Block **dmaPlacementBlock, Block::iterator *dmaPlacementReadStart, - Block::iterator *dmaPlacementWriteStart) { +static void +findHighestBlockForPlacement(const MemRefRegion ®ion, Block &block, + Block::iterator &begin, Block::iterator &end, + Block **dmaPlacementBlock, + Block::iterator *dmaPlacementReadStart, + Block::iterator *dmaPlacementWriteStart) { const auto *cst = region.getConstraints(); SmallVector symbols; cst->getIdValues(cst->getNumDimIds(), cst->getNumDimAndSymbolIds(), &symbols); diff --git a/mlir/lib/Transforms/ViewFunctionGraph.cpp b/mlir/lib/Transforms/ViewFunctionGraph.cpp index f4e8e44cfdcc..834424951bfc 100644 --- a/mlir/lib/Transforms/ViewFunctionGraph.cpp +++ b/mlir/lib/Transforms/ViewFunctionGraph.cpp @@ -28,10 +28,10 @@ template <> struct llvm::DOTGraphTraits : public DefaultDOTGraphTraits { using DefaultDOTGraphTraits::DefaultDOTGraphTraits; - static std::string getNodeLabel(const Block *Block, Function *); + static std::string getNodeLabel(Block *Block, Function *); }; -std::string llvm::DOTGraphTraits::getNodeLabel(const Block *Block, +std::string llvm::DOTGraphTraits::getNodeLabel(Block *Block, Function *) { // Reuse the print output for the node labels. std::string outStreamStr;