diff --git a/mlir/include/mlir/IR/BasicBlock.h b/mlir/include/mlir/IR/BasicBlock.h index 18585586e30a..5b42a68774b8 100644 --- a/mlir/include/mlir/IR/BasicBlock.h +++ b/mlir/include/mlir/IR/BasicBlock.h @@ -19,6 +19,7 @@ #define MLIR_IR_BASICBLOCK_H #include "mlir/IR/Instructions.h" +#include namespace mlir { @@ -27,9 +28,11 @@ namespace mlir { /// /// Basic blocks form a graph (the CFG) which can be traversed through /// predecessor and successor edges. -class BasicBlock { +class BasicBlock + : public llvm::ilist_node_with_parent { public: - explicit BasicBlock(CFGFunction *function); + explicit BasicBlock(); + ~BasicBlock(); /// Return the function that a BasicBlock is part of. CFGFunction *getFunction() const { @@ -38,23 +41,101 @@ public: // TODO: bb arguments - // TODO: Wrong representation. - std::vector instList; + /// Unlink this BasicBlock from its CFGFunction and delete it. + void eraseFromFunction(); - void setTerminator(TerminatorInst *inst) { - terminator = inst; + //===--------------------------------------------------------------------===// + // Operation list management + //===--------------------------------------------------------------------===// + + /// This is the list of operations in the block. + typedef llvm::iplist OperationListType; + OperationListType &getOperations() { return operations; } + const OperationListType &getOperations() const { return operations; } + + // Iteration over the operations in the block. 
+ using iterator = OperationListType::iterator; + using const_iterator = OperationListType::const_iterator; + using reverse_iterator = OperationListType::reverse_iterator; + using const_reverse_iterator = OperationListType::const_reverse_iterator; + + iterator begin() { return operations.begin(); } + iterator end() { return operations.end(); } + const_iterator begin() const { return operations.begin(); } + const_iterator end() const { return operations.end(); } + reverse_iterator rbegin() { return operations.rbegin(); } + reverse_iterator rend() { return operations.rend(); } + const_reverse_iterator rbegin() const { return operations.rbegin(); } + const_reverse_iterator rend() const { return operations.rend(); } + + bool empty() const { return operations.empty(); } + void push_back(OperationInst *inst) { operations.push_back(inst); } + void push_front(OperationInst *inst) { operations.push_front(inst); } + + OperationInst &back() { return operations.back(); } + const OperationInst &back() const { + return const_cast(this)->back(); } + + OperationInst &front() { return operations.front(); } + const OperationInst &front() const { + return const_cast(this)->front(); + } + + //===--------------------------------------------------------------------===// + // Terminator management + //===--------------------------------------------------------------------===// + + /// Change the terminator of this block to the specified instruction. + void setTerminator(TerminatorInst *inst); + TerminatorInst *getTerminator() const { return terminator; } void print(raw_ostream &os) const; void dump() const; + /// getSublistAccess() - Returns pointer to member of operation list + static OperationListType BasicBlock::*getSublistAccess(OperationInst*) { + return &BasicBlock::operations; + } + private: - CFGFunction *const function; - // FIXME: wrong representation and API, leaks memory etc. + CFGFunction *function = nullptr; + + /// This is the list of operations in the block. 
+ OperationListType operations; + + /// This is the owning reference to the terminator of the block. TerminatorInst *terminator = nullptr; + + BasicBlock(const BasicBlock&) = delete; + void operator=(const BasicBlock&) = delete; + + friend struct llvm::ilist_traits; }; } // end namespace mlir +//===----------------------------------------------------------------------===// +// ilist_traits for BasicBlock +//===----------------------------------------------------------------------===// + +namespace llvm { + +template <> +struct ilist_traits<::mlir::BasicBlock> + : public ilist_alloc_traits<::mlir::BasicBlock> { + using BasicBlock = ::mlir::BasicBlock; + using block_iterator = simple_ilist::iterator; + + void addNodeToList(BasicBlock *block); + void removeNodeFromList(BasicBlock *block); + void transferNodesFromList(ilist_traits &otherList, + block_iterator first, block_iterator last); +private: + mlir::CFGFunction *getContainingFunction(); +}; +} // end namespace llvm + + #endif // MLIR_IR_BASICBLOCK_H diff --git a/mlir/include/mlir/IR/CFGFunction.h b/mlir/include/mlir/IR/CFGFunction.h index 335eccc8a00c..4dcb39112f62 100644 --- a/mlir/include/mlir/IR/CFGFunction.h +++ b/mlir/include/mlir/IR/CFGFunction.h @@ -20,7 +20,6 @@ #include "mlir/IR/Function.h" #include "mlir/IR/BasicBlock.h" -#include namespace mlir { @@ -30,8 +29,52 @@ class CFGFunction : public Function { public: CFGFunction(StringRef name, FunctionType *type); - // FIXME: wrong representation and API, leaks memory etc. - std::vector blockList; + //===--------------------------------------------------------------------===// + // BasicBlock list management + //===--------------------------------------------------------------------===// + + /// This is the list of blocks in the function. + typedef llvm::iplist BasicBlockListType; + BasicBlockListType &getBlocks() { return blocks; } + const BasicBlockListType &getBlocks() const { return blocks; } + + // Iteration over the blocks in the function.
+ using iterator = BasicBlockListType::iterator; + using const_iterator = BasicBlockListType::const_iterator; + using reverse_iterator = BasicBlockListType::reverse_iterator; + using const_reverse_iterator = BasicBlockListType::const_reverse_iterator; + + iterator begin() { return blocks.begin(); } + iterator end() { return blocks.end(); } + const_iterator begin() const { return blocks.begin(); } + const_iterator end() const { return blocks.end(); } + reverse_iterator rbegin() { return blocks.rbegin(); } + reverse_iterator rend() { return blocks.rend(); } + const_reverse_iterator rbegin() const { return blocks.rbegin(); } + const_reverse_iterator rend() const { return blocks.rend(); } + + bool empty() const { return blocks.empty(); } + void push_back(BasicBlock *block) { blocks.push_back(block); } + void push_front(BasicBlock *block) { blocks.push_front(block); } + + BasicBlock &back() { return blocks.back(); } + const BasicBlock &back() const { + return const_cast(this)->back(); + } + + BasicBlock &front() { return blocks.front(); } + const BasicBlock &front() const { + return const_cast(this)->front(); + } + + /// getSublistAccess() - Returns pointer to member of block list + static BasicBlockListType CFGFunction::*getSublistAccess(BasicBlock*) { + return &CFGFunction::blocks; + } + + //===--------------------------------------------------------------------===// + // Other + //===--------------------------------------------------------------------===// /// Methods for support type inquiry through isa, cast, and dyn_cast. 
static bool classof(const Function *func) { @@ -39,6 +82,9 @@ public: } void print(raw_ostream &os) const; + +private: + BasicBlockListType blocks; }; diff --git a/mlir/include/mlir/IR/Instructions.h b/mlir/include/mlir/IR/Instructions.h index ed4088a0f9de..096ad5acc4aa 100644 --- a/mlir/include/mlir/IR/Instructions.h +++ b/mlir/include/mlir/IR/Instructions.h @@ -24,11 +24,16 @@ #include "mlir/Support/LLVM.h" #include "mlir/IR/Identifier.h" +#include "llvm/ADT/ilist.h" +#include "llvm/ADT/ilist_node.h" namespace mlir { + class OperationInst; class BasicBlock; class CFGFunction; +/// Instruction is the root of the hierarchy of operation and terminator +/// instructions. class Instruction { public: enum class Kind { @@ -47,26 +52,45 @@ public: /// Return the CFGFunction containing this instruction. CFGFunction *getFunction() const; + /// Destroy this instruction or one of its subclasses. + static void destroy(Instruction *inst); + void print(raw_ostream &os) const; void dump() const; protected: - Instruction(Kind kind, BasicBlock *block) : kind(kind), block(block) {} + Instruction(Kind kind) : kind(kind) {} + + // Instructions are deleted through the destroy() member because this class + // does not have a virtual destructor. A vtable would bloat the size of + // every instruction by a word, and is not necessary given the closed nature + // of instruction kinds. + ~Instruction(); private: Kind kind; - BasicBlock *block; + BasicBlock *block = nullptr; + + friend struct llvm::ilist_traits; + friend class BasicBlock; }; /// Operations are the main instruction kind in MLIR, which represent all of the /// arithmetic and other basic computation that occurs in a CFG function.
-class OperationInst : public Instruction { +class OperationInst + : public Instruction, + public llvm::ilist_node_with_parent { public: - explicit OperationInst(Identifier name, BasicBlock *block); + explicit OperationInst(Identifier name) + : Instruction(Kind::Operation), name(name) {} + ~OperationInst() {} Identifier getName() const { return name; } // TODO: Need to have results and operands. + /// Unlink this instruction from its BasicBlock and delete it. + void eraseFromBlock(); + /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool classof(const Instruction *inst) { return inst->getKind() == Kind::Operation; @@ -86,15 +110,22 @@ public: return inst->getKind() != Kind::Operation; } + /// Remove this terminator from its BasicBlock and delete it. + void eraseFromBlock(); + protected: - TerminatorInst(Kind kind, BasicBlock *block) : Instruction(kind, block) {} + TerminatorInst(Kind kind) : Instruction(kind) {} + ~TerminatorInst() {} }; /// The 'br' instruction is an unconditional from one basic block to another, /// and may pass basic block arguments to the successor. class BranchInst : public TerminatorInst { public: - explicit BranchInst(BasicBlock *dest, BasicBlock *parent); + explicit BranchInst(BasicBlock *dest) + : TerminatorInst(Kind::Branch), dest(dest) { + } + ~BranchInst() {} /// Return the block this branch jumps to. BasicBlock *getDest() const { @@ -118,7 +149,8 @@ private: /// required to align with the result list of the containing function's type. class ReturnInst : public TerminatorInst { public: - explicit ReturnInst(BasicBlock *parent); + explicit ReturnInst() : TerminatorInst(Kind::Return) {} + ~ReturnInst() {} // TODO: Needs to take an operand list. 
@@ -130,4 +162,30 @@ public: } // end namespace mlir + +//===----------------------------------------------------------------------===// +// ilist_traits for OperationInst +//===----------------------------------------------------------------------===// + +namespace llvm { + +template <> +struct ilist_traits<::mlir::OperationInst> { + using OperationInst = ::mlir::OperationInst; + using instr_iterator = simple_ilist::iterator; + + static void deleteNode(OperationInst *inst) { + OperationInst::destroy(inst); + } + + void addNodeToList(OperationInst *inst); + void removeNodeFromList(OperationInst *inst); + void transferNodesFromList(ilist_traits &otherList, + instr_iterator first, instr_iterator last); +private: + mlir::BasicBlock *getContainingBlock(); +}; + +} // end namespace llvm + #endif // MLIR_IR_INSTRUCTIONS_H diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 067697cdf064..c1c1d947454f 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -94,7 +94,7 @@ public: private: const CFGFunction *function; raw_ostream &os; - DenseMap basicBlockIDs; + DenseMap basicBlockIDs; }; } // end anonymous namespace @@ -103,8 +103,8 @@ CFGFunctionState::CFGFunctionState(const CFGFunction *function, raw_ostream &os) // Each basic block gets a unique ID per function. unsigned blockID = 0; - for (auto *block : function->blockList) - basicBlockIDs[block] = blockID++; + for (auto &block : *function) + basicBlockIDs[&block] = blockID++; } void CFGFunctionState::print() { @@ -112,8 +112,8 @@ void CFGFunctionState::print() { printFunctionSignature(this->getFunction(), os); os << " {\n"; - for (auto *block : function->blockList) - print(block); + for (auto &block : *function) + print(&block); os << "}\n\n"; } @@ -121,8 +121,8 @@ void CFGFunctionState::print(const BasicBlock *block) { os << "bb" << getBBID(block) << ":\n"; // TODO Print arguments. 
- for (auto inst : block->instList) - print(inst); + for (auto &inst : block->getOperations()) + print(&inst); print(block->getTerminator()); } diff --git a/mlir/lib/IR/BasicBlock.cpp b/mlir/lib/IR/BasicBlock.cpp index 4cfe1622d854..c2c865cba2cd 100644 --- a/mlir/lib/IR/BasicBlock.cpp +++ b/mlir/lib/IR/BasicBlock.cpp @@ -19,6 +19,68 @@ #include "mlir/IR/CFGFunction.h" using namespace mlir; -BasicBlock::BasicBlock(CFGFunction *function) : function(function) { - function->blockList.push_back(this); +BasicBlock::BasicBlock() { +} + +BasicBlock::~BasicBlock() { + if (terminator) + terminator->eraseFromBlock(); +} + +/// Unlink this BasicBlock from its CFGFunction and delete it. +void BasicBlock::eraseFromFunction() { + assert(getFunction() && "BasicBlock has no parent"); + getFunction()->getBlocks().erase(this); +} + +void BasicBlock::setTerminator(TerminatorInst *inst) { + // If we already had a terminator, abandon it. + if (terminator) + terminator->block = nullptr; + + // Reset our terminator to the new instruction. + terminator = inst; + if (inst) + inst->block = this; +} + +mlir::CFGFunction * +llvm::ilist_traits<::mlir::BasicBlock>::getContainingFunction() { + size_t Offset( + size_t(&((CFGFunction *)nullptr->*CFGFunction::getSublistAccess(nullptr)))); + iplist *Anchor(static_cast *>(this)); + return reinterpret_cast(reinterpret_cast(Anchor) - + Offset); +} + +/// This is a trait method invoked when a basic block is added to a function. +/// We keep the function pointer up to date. +void llvm::ilist_traits<::mlir::BasicBlock>:: +addNodeToList(BasicBlock *block) { + assert(!block->function && "already in a function!"); + block->function = getContainingFunction(); +} + +/// This is a trait method invoked when a basic block is removed from a +/// function. We keep the function pointer up to date.
+void llvm::ilist_traits<::mlir::BasicBlock>:: +removeNodeFromList(BasicBlock *block) { + assert(block->function && "not already in a function!"); + block->function = nullptr; +} + +/// This is a trait method invoked when a basic block is moved from one +/// function to another. We keep the function pointer up to date. +void llvm::ilist_traits<::mlir::BasicBlock>:: +transferNodesFromList(ilist_traits &otherList, + block_iterator first, block_iterator last) { + // If we are transferring basic blocks within the same function, the parent + // pointer doesn't need to be updated. + CFGFunction *curParent = getContainingFunction(); + if (curParent == otherList.getContainingFunction()) + return; + + // Update the 'function' member of each BasicBlock. + for (; first != last; ++first) + first->function = curParent; } diff --git a/mlir/lib/IR/Instructions.cpp b/mlir/lib/IR/Instructions.cpp index 2222a12c5d12..729936a6a8d7 100644 --- a/mlir/lib/IR/Instructions.cpp +++ b/mlir/lib/IR/Instructions.cpp @@ -23,6 +23,27 @@ using namespace mlir; // Instruction //===----------------------------------------------------------------------===// +// Instructions are deleted through the destroy() member because we don't have +// a virtual destructor. +Instruction::~Instruction() { + assert(block == nullptr && "instruction destroyed but still in a block"); +} + +/// Destroy this instruction or one of its subclasses.
+void Instruction::destroy(Instruction *inst) { + switch (inst->getKind()) { + case Kind::Operation: + delete cast(inst); + break; + case Kind::Branch: + delete cast(inst); + break; + case Kind::Return: + delete cast(inst); + break; + } +} + CFGFunction *Instruction::getFunction() const { return getBlock()->getFunction(); } @@ -31,21 +52,64 @@ CFGFunction *Instruction::getFunction() const { // OperationInst //===----------------------------------------------------------------------===// -OperationInst::OperationInst(Identifier name, BasicBlock *block) : - Instruction(Kind::Operation, block), name(name) { - getBlock()->instList.push_back(this); +mlir::BasicBlock * +llvm::ilist_traits<::mlir::OperationInst>::getContainingBlock() { + size_t Offset( + size_t(&((BasicBlock *)nullptr->*BasicBlock::getSublistAccess(nullptr)))); + iplist *Anchor(static_cast *>(this)); + return reinterpret_cast(reinterpret_cast(Anchor) - + Offset); } +/// This is a trait method invoked when an instruction is added to a block. We +/// keep the block pointer up to date. +void llvm::ilist_traits<::mlir::OperationInst>:: +addNodeToList(OperationInst *inst) { + assert(!inst->getBlock() && "already in a basic block!"); + inst->block = getContainingBlock(); +} + +/// This is a trait method invoked when an instruction is removed from a block. +/// We keep the block pointer up to date. +void llvm::ilist_traits<::mlir::OperationInst>:: +removeNodeFromList(OperationInst *inst) { + assert(inst->block && "not already in a basic block!"); + inst->block = nullptr; +} + +/// This is a trait method invoked when an instruction is moved from one block +/// to another. We keep the block pointer up to date. +void llvm::ilist_traits<::mlir::OperationInst>:: +transferNodesFromList(ilist_traits &otherList, + instr_iterator first, instr_iterator last) { + // If we are transferring instructions within the same basic block, the block + // pointer doesn't need to be updated. 
+ BasicBlock *curParent = getContainingBlock(); + if (curParent == otherList.getContainingBlock()) + return; + + // Update the 'block' member of each instruction. + for (; first != last; ++first) + first->block = curParent; +} + +/// Unlink this instruction from its BasicBlock and delete it. +void OperationInst::eraseFromBlock() { + assert(getBlock() && "Instruction has no parent"); + getBlock()->getOperations().erase(this); +} + + + //===----------------------------------------------------------------------===// // Terminators //===----------------------------------------------------------------------===// -ReturnInst::ReturnInst(BasicBlock *parent) - : TerminatorInst(Kind::Return, parent) { - getBlock()->setTerminator(this); +/// Remove this terminator from its BasicBlock and delete it. +void TerminatorInst::eraseFromBlock() { + assert(getBlock() && "Instruction has no parent"); + getBlock()->setTerminator(nullptr); + TerminatorInst::destroy(this); } -BranchInst::BranchInst(BasicBlock *dest, BasicBlock *parent) - : TerminatorInst(Kind::Branch, parent), dest(dest) { - getBlock()->setTerminator(this); -} + diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 1bfa331256dc..91f80e25582a 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -153,10 +153,8 @@ private: ParseResult parseBasicBlock(CFGFunctionParserState &functionState); MLStatement *parseMLStatement(MLFunction *currentFunction); - ParseResult parseCFGOperation(BasicBlock *currentBB, - CFGFunctionParserState &functionState); - ParseResult parseTerminator(BasicBlock *currentBB, - CFGFunctionParserState &functionState); + OperationInst *parseCFGOperation(CFGFunctionParserState &functionState); + TerminatorInst *parseTerminator(CFGFunctionParserState &functionState); }; } // end anonymous namespace @@ -792,7 +790,7 @@ class CFGFunctionParserState { BasicBlock *getBlockNamed(StringRef name, SMLoc loc) { auto &blockAndLoc = blocksByName[name]; if (!blockAndLoc.first) 
{ - blockAndLoc.first = new BasicBlock(function); + blockAndLoc.first = new BasicBlock(); blockAndLoc.second = loc; } return blockAndLoc.first; @@ -834,7 +832,7 @@ ParseResult Parser::parseCFGFunc() { // StringMap isn't determinstic, but this is good enough for our purposes. for (auto &elt : functionState.blocksByName) { auto *bb = elt.second.first; - if (!bb->getTerminator()) + if (!bb->getFunction()) return emitError(elt.second.second, "reference to an undefined basic block '" + elt.first() + "'"); @@ -861,20 +859,11 @@ ParseResult Parser::parseBasicBlock(CFGFunctionParserState &functionState) { // If this block has already been parsed, then this is a redefinition with the // same block name. - if (block->getTerminator()) + if (block->getFunction()) return emitError(nameLoc, "redefinition of block '" + name.str() + "'"); - // References to blocks can occur in any order, but we need to reassemble the - // function in the order that occurs in the source file. Do this by moving - // each block to the end of the list as it is defined. - // FIXME: This is inefficient for large functions given that blockList is a - // vector. blockList will eventually be an ilist, which will make this fast. - auto &blockList = functionState.function->blockList; - if (blockList.back() != block) { - auto it = std::find(blockList.begin(), blockList.end(), block); - assert(it != blockList.end() && "Block has to be in the blockList"); - std::swap(*it, blockList.back()); - } + // Add the block to the function. + functionState.function->push_back(block); // TODO: parse bb argument list. @@ -883,12 +872,17 @@ ParseResult Parser::parseBasicBlock(CFGFunctionParserState &functionState) { // Parse the list of operations that make up the body of the block. 
while (curToken.isNot(Token::kw_return, Token::kw_br)) { - if (parseCFGOperation(block, functionState)) + auto *inst = parseCFGOperation(functionState); + if (!inst) return ParseFailure; + + block->getOperations().push_back(inst); } - if (parseTerminator(block, functionState)) + auto *term = parseTerminator(functionState); + if (!term) return ParseFailure; + block->setTerminator(term); return ParseSuccess; } @@ -903,33 +897,29 @@ ParseResult Parser::parseBasicBlock(CFGFunctionParserState &functionState) { /// (ssa-id `=`)? string '(' ssa-use-list? ')' attribute-dict? /// `:` function-type /// -ParseResult Parser:: -parseCFGOperation(BasicBlock *currentBB, - CFGFunctionParserState &functionState) { +OperationInst *Parser:: +parseCFGOperation(CFGFunctionParserState &functionState) { // TODO: parse ssa-id. if (curToken.isNot(Token::string)) - return emitError("expected operation name in quotes"); + return (emitError("expected operation name in quotes"), nullptr); auto name = curToken.getStringValue(); if (name.empty()) - return emitError("empty operation name is invalid"); + return (emitError("empty operation name is invalid"), nullptr); consumeToken(Token::string); if (!consumeIf(Token::l_paren)) - return emitError("expected '(' in operation"); + return (emitError("expected '(' in operation"), nullptr); // TODO: Parse operands. if (!consumeIf(Token::r_paren)) - return emitError("expected '(' in operation"); + return (emitError("expected '(' in operation"), nullptr); auto nameId = Identifier::get(name, context); - new OperationInst(nameId, currentBB); - - // TODO: add instruction the per-function symbol table. - return ParseSuccess; + return new OperationInst(nameId); } @@ -941,25 +931,22 @@ parseCFGOperation(BasicBlock *currentBB, /// `cond_br` ssa-use `,` bb-id branch-use-list? `,` bb-id branch-use-list? /// terminator-stmt ::= `return` ssa-use-and-type-list? 
/// -ParseResult Parser::parseTerminator(BasicBlock *currentBB, - CFGFunctionParserState &functionState) { +TerminatorInst *Parser::parseTerminator(CFGFunctionParserState &functionState) { switch (curToken.getKind()) { default: - return emitError("expected terminator at end of basic block"); + return (emitError("expected terminator at end of basic block"), nullptr); case Token::kw_return: consumeToken(Token::kw_return); - new ReturnInst(currentBB); - return ParseSuccess; + return new ReturnInst(); case Token::kw_br: { consumeToken(Token::kw_br); auto destBB = functionState.getBlockNamed(curToken.getSpelling(), curToken.getLoc()); if (!consumeIf(Token::bare_identifier)) - return emitError("expected basic block name"); - new BranchInst(destBB, currentBB); - return ParseSuccess; + return (emitError("expected basic block name"), nullptr); + return new BranchInst(destBB); } } }