From 8f60c4ad7325c47aadd49cc1961800c34c6c88e4 Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Thu, 26 Jul 2018 08:56:26 -0700 Subject: [PATCH] Implement the groundwork for predecessor/successor iterators on basic blocks. Give BasicBlock a use/def list, making references to them in TerminatorInst's into a type that maintains the list. PiperOrigin-RevId: 206166388 --- mlir/include/mlir/IR/BasicBlock.h | 3 +- mlir/include/mlir/IR/CFGValue.h | 2 +- mlir/include/mlir/IR/Instructions.h | 52 +++++++++++++++----- mlir/include/mlir/IR/SSAOperand.h | 69 ++++++++++++++------------- mlir/include/mlir/IR/SSAValue.h | 74 ++++++++++++++++++----------- mlir/lib/IR/Function.cpp | 1 + mlir/lib/IR/Instructions.cpp | 48 ++++++++++++++++++- 7 files changed, 169 insertions(+), 80 deletions(-) diff --git a/mlir/include/mlir/IR/BasicBlock.h b/mlir/include/mlir/IR/BasicBlock.h index c26203ca8fab..ce8df72049bf 100644 --- a/mlir/include/mlir/IR/BasicBlock.h +++ b/mlir/include/mlir/IR/BasicBlock.h @@ -30,7 +30,8 @@ class BBArgument; /// Basic blocks form a graph (the CFG) which can be traversed through /// predecessor and successor edges. class BasicBlock - : public llvm::ilist_node_with_parent { + : public IRObjectWithUseList, + public llvm::ilist_node_with_parent { public: explicit BasicBlock(); ~BasicBlock(); diff --git a/mlir/include/mlir/IR/CFGValue.h b/mlir/include/mlir/IR/CFGValue.h index 5dd070be0d9d..8dd635cfd2e7 100644 --- a/mlir/include/mlir/IR/CFGValue.h +++ b/mlir/include/mlir/IR/CFGValue.h @@ -39,7 +39,7 @@ enum class CFGValueKind { }; /// The operand of a CFG Instruction contains a CFGValue. -using InstOperand = SSAOperandImpl; +using InstOperand = IROperandImpl; /// CFGValue is the base class for SSA values in CFG functions. class CFGValue : public SSAValueImpl { diff --git a/mlir/include/mlir/IR/Instructions.h b/mlir/include/mlir/IR/Instructions.h index becc39a6834d..da41825a8a70 100644 --- a/mlir/include/mlir/IR/Instructions.h +++ b/mlir/include/mlir/IR/Instructions.h @@ -30,9 +30,13 @@ #include "llvm/Support/TrailingObjects.h" namespace mlir { -class OperationInst; class BasicBlock; class CFGFunction; +class OperationInst; +class TerminatorInst; + +/// The operand of a CFG Instruction contains a CFGValue. +using BBDestination = IROperandImpl; /// Instruction is the root of the operation and terminator instructions in the /// hierarchy. @@ -292,6 +296,21 @@ public: /// Remove this terminator from its BasicBlock and delete it. void eraseFromBlock(); + /// Return the list of destination entries that this terminator branches to. + MutableArrayRef getDestinations(); + + ArrayRef getDestinations() const { + return const_cast(this)->getDestinations(); + } + + unsigned getNumSuccessors() const { return getDestinations().size(); } + + const BasicBlock *getSuccessor(unsigned i) const { + return getDestinations()[i].get(); + } + + BasicBlock *getSuccessor(unsigned i) { return getDestinations()[i].get(); } + protected: TerminatorInst(Kind kind) : Instruction(kind) {} ~TerminatorInst() {} @@ -305,7 +324,8 @@ public: ~BranchInst() {} /// Return the block this branch jumps to. - BasicBlock *getDest() const { return dest; } + BasicBlock *getDest() const { return dest.get(); } + void setDest(BasicBlock *block); unsigned getNumOperands() const { return operands.size(); } @@ -321,16 +341,18 @@ public: /// Erase a specific argument from the arg list. // TODO: void eraseArgument(int Index); + MutableArrayRef getDestinations() { return dest; } + ArrayRef getDestinations() const { return dest; } + /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool classof(const Instruction *inst) { return inst->getKind() == Kind::Branch; } private: - explicit BranchInst(BasicBlock *dest) - : TerminatorInst(Kind::Branch), dest(dest) {} + explicit BranchInst(BasicBlock *dest); - BasicBlock *dest; + BBDestination dest; std::vector operands; }; @@ -338,6 +360,9 @@ private: /// condition to one of two possible successors. It may pass arguments to each /// successor. class CondBranchInst : public TerminatorInst { + // These are the indices into the dests list. + enum { trueIndex = 0, falseIndex = 1 }; + public: static CondBranchInst *create(CFGValue *condition, BasicBlock *trueDest, BasicBlock *falseDest) { @@ -350,10 +375,10 @@ public: const CFGValue *getCondition() const { return condition; } /// Return the destination if the condition is true. - BasicBlock *getTrueDest() const { return trueDest; } + BasicBlock *getTrueDest() const { return dests[trueIndex].get(); } /// Return the destination if the condition is false. - BasicBlock *getFalseDest() const { return falseDest; } + BasicBlock *getFalseDest() const { return dests[falseIndex].get(); } // Support non-const operand iteration. using operand_iterator = OperandIterator; @@ -476,20 +501,21 @@ public: /// Add a list of values to the operand list. void addFalseOperands(ArrayRef values); + MutableArrayRef getDestinations() { return dests; } + ArrayRef getDestinations() const { return dests; } + /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool classof(const Instruction *inst) { return inst->getKind() == Kind::CondBranch; } private: - explicit CondBranchInst(CFGValue *condition, BasicBlock *trueDest, - BasicBlock *falseDest) - : TerminatorInst(Kind::CondBranch), condition(condition), - trueDest(trueDest), falseDest(falseDest), numTrueOperands(0) {} + CondBranchInst(CFGValue *condition, BasicBlock *trueDest, + BasicBlock *falseDest); CFGValue *condition; - BasicBlock *trueDest; - BasicBlock *falseDest; + BBDestination dests[2]; // 0 is the true dest, 1 is the false dest. + // Operand list. The true operands are stored first, followed by the false // operands. std::vector operands; diff --git a/mlir/include/mlir/IR/SSAOperand.h b/mlir/include/mlir/IR/SSAOperand.h index 7685c310d8ce..5f47cfba7bc6 100644 --- a/mlir/include/mlir/IR/SSAOperand.h +++ b/mlir/include/mlir/IR/SSAOperand.h @@ -29,16 +29,16 @@ class SSAValue; /// A reference to a value, suitable for use as an operand of an instruction, /// statement, etc. -class SSAOperand { +class IROperand { public: - SSAOperand() {} - SSAOperand(SSAValue *value) : value(value) { insertIntoCurrent(); } + IROperand() {} + IROperand(IRObjectWithUseList *value) : value(value) { insertIntoCurrent(); } /// Return the current value being used by this operand. - SSAValue *get() const { return value; } + IRObjectWithUseList *get() const { return value; } /// Set the current value being used by this operand. - void set(SSAValue *newValue) { + void set(IRObjectWithUseList *newValue) { // It isn't worth optimizing for the case of switching operands on a single // value. removeFromCurrent(); @@ -54,16 +54,16 @@ public: back = nullptr; } - ~SSAOperand() { removeFromCurrent(); } + ~IROperand() { removeFromCurrent(); } /// Return the next operand on the use-list of the value we are referring to. /// This should generally only be used by the internal implementation details /// of the SSA machinery. - SSAOperand *getNextOperandUsingThisValue() { return nextUse; } + IROperand *getNextOperandUsingThisValue() { return nextUse; } - /// We support a move constructor so SSAOperands can be in vectors, but this + /// We support a move constructor so IROperand's can be in vectors, but this /// shouldn't be used by general clients. - SSAOperand(SSAOperand &&other) { + IROperand(IROperand &&other) { other.removeFromCurrent(); value = other.value; other.value = nullptr; @@ -75,17 +75,17 @@ public: private: /// The value used as this operand. This can be null when in a /// "dropAllUses" state. - SSAValue *value = nullptr; + IRObjectWithUseList *value = nullptr; /// The next operand in the use-chain. - SSAOperand *nextUse = nullptr; + IROperand *nextUse = nullptr; /// This points to the previous link in the use-chain. - SSAOperand **back = nullptr; + IROperand **back = nullptr; /// Operands are not copyable or assignable. - SSAOperand(const SSAOperand &use) = delete; - SSAOperand &operator=(const SSAOperand &use) = delete; + IROperand(const IROperand &use) = delete; + IROperand &operator=(const IROperand &use) = delete; void removeFromCurrent() { if (!back) @@ -105,52 +105,53 @@ private: }; /// A reference to a value, suitable for use as an operand of an instruction, -/// statement, etc. SSAValueTy is the root type to use for values this tracks, +/// statement, etc. IRValueTy is the root type to use for values this tracks, /// and SSAUserTy is the type that will contain operands. -template -class SSAOperandImpl : public SSAOperand { +template +class IROperandImpl : public IROperand { public: - SSAOperandImpl(SSAOwnerTy *owner) : owner(owner) {} - SSAOperandImpl(SSAOwnerTy *owner, SSAValueTy *value) - : SSAOperand(value), owner(owner) {} + IROperandImpl(IROwnerTy *owner) : owner(owner) {} + IROperandImpl(IROwnerTy *owner, IRValueTy *value) + : IROperand(value), owner(owner) {} /// Return the current value being used by this operand. - SSAValueTy *get() const { return (SSAValueTy *)SSAOperand::get(); } + IRValueTy *get() const { return (IRValueTy *)IROperand::get(); } /// Set the current value being used by this operand. - void set(SSAValueTy *newValue) { SSAOperand::set(newValue); } + void set(IRValueTy *newValue) { IROperand::set(newValue); } /// Return the user that owns this use. - SSAOwnerTy *getOwner() { return owner; } - const SSAOwnerTy *getOwner() const { return owner; } + IROwnerTy *getOwner() { return owner; } + const IROwnerTy *getOwner() const { return owner; } /// Return which operand this is in the operand list of the User. // TODO: unsigned getOperandNumber() const; - /// We support a move constructor so SSAOperands can be in vectors, but this + /// We support a move constructor so IROperand's can be in vectors, but this /// shouldn't be used by general clients. - SSAOperandImpl(SSAOperandImpl &&other) - : SSAOperand(std::move(other)), owner(other.owner) {} + IROperandImpl(IROperandImpl &&other) + : IROperand(std::move(other)), owner(other.owner) {} private: /// The owner of this operand. - SSAOwnerTy *const owner; + IROwnerTy *const owner; }; -inline auto SSAValue::use_begin() const -> use_iterator { - return SSAValue::use_iterator(firstUse); +inline auto IRObjectWithUseList::use_begin() const -> use_iterator { + return use_iterator(firstUse); } -inline auto SSAValue::use_end() const -> use_iterator { - return SSAValue::use_iterator(nullptr); +inline auto IRObjectWithUseList::use_end() const -> use_iterator { + return use_iterator(nullptr); } -inline auto SSAValue::getUses() const -> llvm::iterator_range { +inline auto IRObjectWithUseList::getUses() const + -> llvm::iterator_range { return {use_begin(), use_end()}; } /// Returns true if this value has exactly one use. -inline bool SSAValue::hasOneUse() const { +inline bool IRObjectWithUseList::hasOneUse() const { return firstUse && firstUse->getNextOperandUsingThisValue() == nullptr; } diff --git a/mlir/include/mlir/IR/SSAValue.h b/mlir/include/mlir/IR/SSAValue.h index 0b7648ca2c0e..f4fd833ebab3 100644 --- a/mlir/include/mlir/IR/SSAValue.h +++ b/mlir/include/mlir/IR/SSAValue.h @@ -29,9 +29,43 @@ namespace mlir { class OperationInst; -class SSAOperand; +class IROperand; template class SSAValueUseIterator; +class IRObjectWithUseList { +public: + ~IRObjectWithUseList() { + assert(use_empty() && "Cannot destroy a value that still has uses!"); + } + + /// Returns true if this value has no uses. + bool use_empty() const { return firstUse == nullptr; } + + /// Returns true if this value has exactly one use. + inline bool hasOneUse() const; + + using use_iterator = SSAValueUseIterator; + using use_range = llvm::iterator_range; + + inline use_iterator use_begin() const; + inline use_iterator use_end() const; + + /// Returns a range of all uses, which is useful for iterating over all uses. + inline use_range getUses() const; + + /// Replace all uses of 'this' value with the new value, updating anything in + /// the IR that uses 'this' to use the other value instead. When this returns + /// there are zero uses of 'this'. + void replaceAllUsesWith(IRObjectWithUseList *newValue); + +protected: + IRObjectWithUseList() {} + +private: + friend class IROperand; + IROperand *firstUse = nullptr; +}; + /// This enumerates all of the SSA value kinds in the MLIR system. enum class SSAValueKind { BBArgument, @@ -45,35 +79,20 @@ enum class SSAValueKind { /// This is the common base class for all values in the MLIR system, /// representing a computable value that has a type and a set of users. /// -class SSAValue { +class SSAValue : public IRObjectWithUseList { public: - ~SSAValue() { - assert(use_empty() && "Cannot destroy a value that still has uses!"); - } + ~SSAValue() {} SSAValueKind getKind() const { return typeAndKind.getInt(); } Type *getType() const { return typeAndKind.getPointer(); } - /// Returns true if this value has no uses. - bool use_empty() const { return firstUse == nullptr; } - - /// Returns true if this value has exactly one use. - inline bool hasOneUse() const; - - using use_iterator = SSAValueUseIterator; - using use_range = llvm::iterator_range; - - inline use_iterator use_begin() const; - inline use_iterator use_end() const; - - /// Returns a range of all uses, which is useful for iterating over all uses. - inline use_range getUses() const; - /// Replace all uses of 'this' value with the new value, updating anything in /// the IR that uses 'this' to use the other value instead. When this returns /// there are zero uses of 'this'. - void replaceAllUsesWith(SSAValue *newValue); + void replaceAllUsesWith(SSAValue *newValue) { + IRObjectWithUseList::replaceAllUsesWith(newValue); + } /// If this value is the result of an OperationInst, return the instruction /// that defines it. @@ -84,27 +103,24 @@ public: protected: SSAValue(SSAValueKind kind, Type *type) : typeAndKind(type, kind) {} - private: - friend class SSAOperand; const llvm::PointerIntPair typeAndKind; - SSAOperand *firstUse = nullptr; }; /// This template unifies the implementation logic for CFGValue and StmtValue /// while providing more type-specific APIs when walking use lists etc. /// -/// SSAOperandTy is the concrete instance of SSAOperand to use (including +/// IROperandTy is the concrete instance of IROperand to use (including /// substituted template arguments) and KindTy is the enum 'kind' discriminator /// that subclasses want to use. /// -template +template class SSAValueImpl : public SSAValue { public: // Provide more specific implementations of the base class functionality. KindTy getKind() const { return (KindTy)SSAValue::getKind(); } - // TODO: using use_iterator = SSAValueUseIterator; + // TODO: using use_iterator = SSAValueUseIterator; // TODO: using use_range = llvm::iterator_range; // TODO: inline use_iterator use_begin() const; @@ -122,10 +138,10 @@ protected: /// An iterator over all uses of a ValueBase. template class SSAValueUseIterator - : public std::iterator { + : public std::iterator { public: SSAValueUseIterator() = default; - explicit SSAValueUseIterator(SSAOperand *current) : current(current) {} + explicit SSAValueUseIterator(IROperand *current) : current(current) {} OperandType *operator->() const { return current; } OperandType &operator*() const { return current; } diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index 8476b0601a53..b06be4d6f856 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -108,6 +108,7 @@ CFGFunction::~CFGFunction() { for (auto &bb : *this) { for (auto &inst : bb) inst.dropAllReferences(); + bb.getTerminator()->dropAllReferences(); } } diff --git a/mlir/lib/IR/Instructions.cpp b/mlir/lib/IR/Instructions.cpp index a1cdcb3dc33e..7afaf6827520 100644 --- a/mlir/lib/IR/Instructions.cpp +++ b/mlir/lib/IR/Instructions.cpp @@ -22,7 +22,7 @@ using namespace mlir; /// Replace all uses of 'this' value with the new value, updating anything in /// the IR that uses 'this' to use the other value instead. When this returns /// there are zero uses of 'this'. -void SSAValue::replaceAllUsesWith(SSAValue *newValue) { +void IRObjectWithUseList::replaceAllUsesWith(IRObjectWithUseList *newValue) { assert(this != newValue && "cannot RAUW a value with itself"); while (!use_empty()) { use_begin()->set(newValue); @@ -105,6 +105,10 @@ MutableArrayRef Instruction::getInstOperands() { void Instruction::dropAllReferences() { for (auto &op : getInstOperands()) op.drop(); + + if (auto *term = dyn_cast(this)) + for (auto &dest : term->getDestinations()) + dest.drop(); } //===----------------------------------------------------------------------===// @@ -209,7 +213,7 @@ OperationInst *SSAValue::getDefiningInst() { } //===----------------------------------------------------------------------===// -// Terminators +// TerminatorInst //===----------------------------------------------------------------------===// /// Remove this terminator from its BasicBlock and delete it. @@ -219,6 +223,25 @@ void TerminatorInst::eraseFromBlock() { destroy(); } +/// Return the list of destination entries that this terminator branches to. +MutableArrayRef TerminatorInst::getDestinations() { + switch (getKind()) { + case Kind::Operation: + assert(0 && "not a terminator"); + case Kind::Branch: + return cast(this)->getDestinations(); + case Kind::CondBranch: + return cast(this)->getDestinations(); + case Kind::Return: + // Return has no basic block successors. + return {}; + } +} + +//===----------------------------------------------------------------------===// +// ReturnInst +//===----------------------------------------------------------------------===// + /// Create a new OperationInst with the specific fields. ReturnInst *ReturnInst::create(ArrayRef operands) { auto byteSize = totalSizeToAlloc(operands.size()); @@ -245,6 +268,15 @@ ReturnInst::~ReturnInst() { operand.~InstOperand(); } +//===----------------------------------------------------------------------===// +// BranchInst +//===----------------------------------------------------------------------===// + +BranchInst::BranchInst(BasicBlock *dest) + : TerminatorInst(Kind::Branch), dest(this, dest) {} + +void BranchInst::setDest(BasicBlock *block) { dest.set(block); } + /// Add one value to the operand list. void BranchInst::addOperand(CFGValue *value) { operands.emplace_back(InstOperand(this, value)); @@ -257,6 +289,18 @@ void BranchInst::addOperands(ArrayRef values) { addOperand(value); } +//===----------------------------------------------------------------------===// +// CondBranchInst +//===----------------------------------------------------------------------===// + +CondBranchInst::CondBranchInst(CFGValue *condition, BasicBlock *trueDest, + BasicBlock *falseDest) + : TerminatorInst(Kind::CondBranch), + condition(condition), dests{{this}, {this}}, numTrueOperands(0) { + dests[falseIndex].set(falseDest); + dests[trueIndex].set(trueDest); +} + /// Add one value to the true operand list. void CondBranchInst::addTrueOperand(CFGValue *value) { assert(getNumFalseOperands() == 0 &&