Introduce a new StmtBlockList type to hold a list of StmtBlocks. Use it in

MLFunction, IfStmt, ForStmt even though they currently only contain exactly one
block in that list.

This is step 6/n towards merging instructions and statements, NFC.

PiperOrigin-RevId: 226960278
This commit is contained in:
Chris Lattner 2018-12-26 15:31:54 -08:00 committed by jpienaar
parent 63068da4d9
commit 5ff0001dc7
7 changed files with 192 additions and 29 deletions

View File

@ -145,7 +145,7 @@ struct ilist_traits<::mlir::Function>
using Function = ::mlir::Function;
using function_iterator = simple_ilist<Function>::iterator;
static void deleteNode(Function *inst) { inst->destroy(); }
static void deleteNode(Function *function) { function->destroy(); }
void addNodeToList(Function *function);
void removeNodeFromList(Function *function);

View File

@ -43,8 +43,11 @@ public:
FunctionType type,
ArrayRef<NamedAttribute> attrs = {});
StmtBlock *getBody() { return &body; }
const StmtBlock *getBody() const { return &body; }
StmtBlockList &getStatementList() { return body; }
const StmtBlockList &getStatementList() const { return body; }
StmtBlock *getBody() { return &body.front(); }
const StmtBlock *getBody() const { return &body.front(); }
/// Destroys this statement and its subclass data.
void destroy();
@ -119,7 +122,7 @@ private:
return {getTrailingObjects<MLFuncArgument>(), getNumArguments()};
}
StmtBlock body;
StmtBlockList body;
};
//===--------------------------------------------------------------------===//

View File

@ -301,10 +301,10 @@ public:
using const_operand_range = llvm::iterator_range<const_operand_iterator>;
/// Get the body of the ForStmt.
StmtBlock *getBody() { return &body; }
StmtBlock *getBody() { return &body.front(); }
/// Get the body of the ForStmt.
const StmtBlock *getBody() const { return &body; }
const StmtBlock *getBody() const { return &body.front(); }
//===--------------------------------------------------------------------===//
// Bounds and step
@ -432,7 +432,7 @@ public:
private:
// The StmtBlock for the body.
StmtBlock body;
StmtBlockList body;
// Affine map for the lower bound.
AffineMap lbMap;
@ -513,15 +513,19 @@ public:
// Then, else, condition.
//===--------------------------------------------------------------------===//
StmtBlock *getThen() { return &thenClause; }
const StmtBlock *getThen() const { return &thenClause; }
StmtBlock *getElse() { return elseClause; }
const StmtBlock *getElse() const { return elseClause; }
StmtBlock *getThen() { return &thenClause.front(); }
const StmtBlock *getThen() const { return &thenClause.front(); }
StmtBlock *getElse() { return elseClause ? &elseClause->front() : nullptr; }
const StmtBlock *getElse() const {
return elseClause ? &elseClause->front() : nullptr;
}
bool hasElse() const { return elseClause != nullptr; }
StmtBlock *createElse() {
assert(elseClause == nullptr && "already has an else clause!");
return (elseClause = new StmtBlock(this));
elseClause = new StmtBlockList(this);
elseClause->push_back(new StmtBlock());
return &elseClause->front();
}
const AffineCondition getCondition() const;
@ -586,9 +590,9 @@ public:
private:
// it is always present.
StmtBlock thenClause;
StmtBlockList thenClause;
// 'else' clause of the if statement. 'nullptr' if there is no else clause.
StmtBlock *elseClause;
StmtBlockList *elseClause;
// The integer set capturing the conditional guard.
IntegerSet set;

View File

@ -28,6 +28,7 @@ namespace mlir {
class MLFunction;
class IfStmt;
class MLValue;
class StmtBlockList;
// TODO(clattner): drop the Stmt prefixes on these once BasicBlock's versions of
// these go away.
@ -37,10 +38,11 @@ template <typename BlockType> class StmtSuccessorIterator;
/// Statement block represents an ordered list of statements, with the order
/// being the contiguous lexical order in which the statements appear as
/// children of a parent statement in the ML Function.
class StmtBlock : public IRObjectWithUseList {
class StmtBlock
: public IRObjectWithUseList,
public llvm::ilist_node_with_parent<StmtBlock, StmtBlockList> {
public:
explicit StmtBlock(MLFunction *parent);
explicit StmtBlock(Statement *parent);
explicit StmtBlock() {}
~StmtBlock();
void clear() {
@ -50,9 +52,7 @@ public:
statements.pop_back();
}
llvm::PointerUnion<MLFunction *, Statement *> getParent() const {
return parent;
}
StmtBlockList *getParent() const { return parent; }
/// Returns the closest surrounding statement that contains this block or
/// nullptr if this is a top-level statement block.
@ -227,7 +227,7 @@ public:
private:
/// This is the parent function/IfStmt/ForStmt that owns this block.
llvm::PointerUnion<MLFunction *, Statement *> parent;
StmtBlockList *parent = nullptr;
/// This is the list of statements in the block.
StmtListType statements;
@ -237,6 +237,100 @@ private:
StmtBlock(const StmtBlock &) = delete;
void operator=(const StmtBlock &) = delete;
friend struct llvm::ilist_traits<StmtBlock>;
};
} // end namespace mlir
//===----------------------------------------------------------------------===//
// ilist_traits for StmtBlock
//===----------------------------------------------------------------------===//
namespace llvm {
template <>
struct ilist_traits<::mlir::StmtBlock>
: public ilist_alloc_traits<::mlir::StmtBlock> {
using StmtBlock = ::mlir::StmtBlock;
using block_iterator = simple_ilist<::mlir::StmtBlock>::iterator;
void addNodeToList(StmtBlock *block);
void removeNodeFromList(StmtBlock *block);
void transferNodesFromList(ilist_traits<StmtBlock> &otherList,
block_iterator first, block_iterator last);
private:
mlir::StmtBlockList *getContainingBlockList();
};
} // end namespace llvm
namespace mlir {
/// This class contains a list of basic blocks and has a notion of the object it
/// is part of - an MLFunction or IfStmt or ForStmt.
class StmtBlockList {
public:
explicit StmtBlockList(MLFunction *container);
explicit StmtBlockList(Statement *container);
using BlockListType = llvm::iplist<StmtBlock>;
BlockListType &getBlocks() { return blocks; }
const BlockListType &getBlocks() const { return blocks; }
// Iteration over the block in the function.
using iterator = BlockListType::iterator;
using const_iterator = BlockListType::const_iterator;
using reverse_iterator = BlockListType::reverse_iterator;
using const_reverse_iterator = BlockListType::const_reverse_iterator;
iterator begin() { return blocks.begin(); }
iterator end() { return blocks.end(); }
const_iterator begin() const { return blocks.begin(); }
const_iterator end() const { return blocks.end(); }
reverse_iterator rbegin() { return blocks.rbegin(); }
reverse_iterator rend() { return blocks.rend(); }
const_reverse_iterator rbegin() const { return blocks.rbegin(); }
const_reverse_iterator rend() const { return blocks.rend(); }
bool empty() const { return blocks.empty(); }
void push_back(StmtBlock *block) { blocks.push_back(block); }
void push_front(StmtBlock *block) { blocks.push_front(block); }
StmtBlock &back() { return blocks.back(); }
const StmtBlock &back() const {
return const_cast<StmtBlockList *>(this)->back();
}
StmtBlock &front() { return blocks.front(); }
const StmtBlock &front() const {
return const_cast<StmtBlockList *>(this)->front();
}
/// getSublistAccess() - Returns pointer to member of block list.
static BlockListType StmtBlockList::*getSublistAccess(StmtBlock *) {
return &StmtBlockList::blocks;
}
/// A StmtBlockList is part of a MLFunction or and IfStmt/ForStmt. If it is
/// part of an IfStmt/ForStmt, then return it, otherwise return null.
Statement *getContainingStmt();
const Statement *getContainingStmt() const {
return const_cast<StmtBlockList *>(this)->getContainingStmt();
}
/// A StmtBlockList is part of a MLFunction or and IfStmt/ForStmt. If it is
/// part of an MLFunction, then return it, otherwise return null.
MLFunction *getContainingFunction();
const MLFunction *getContainingFunction() const {
return const_cast<StmtBlockList *>(this)->getContainingFunction();
}
private:
BlockListType blocks;
/// This is the object we are part of.
llvm::PointerUnion<MLFunction *, Statement *> container;
};
//===----------------------------------------------------------------------===//
@ -369,5 +463,5 @@ inline auto StmtBlock::getSuccessors() -> llvm::iterator_range<succ_iterator> {
return {succ_begin(), succ_end()};
}
} //end namespace mlir
} // end namespace mlir
#endif // MLIR_IR_STMTBLOCK_H

View File

@ -202,7 +202,10 @@ MLFunction *MLFunction::create(Location location, StringRef name,
MLFunction::MLFunction(Location location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs)
: Function(Kind::MLFunc, location, name, type, attrs), body(this) {}
: Function(Kind::MLFunc, location, name, type, attrs), body(this) {
// The body of an MLFunction always has one block.
body.push_back(new StmtBlock());
}
MLFunction::~MLFunction() {
// Explicitly erase statements instead of relying of 'StmtBlock' destructor

View File

@ -405,6 +405,9 @@ ForStmt::ForStmt(Location location, unsigned numOperands, AffineMap lbMap,
MLValue(MLValueKind::ForStmt,
Type::getIndex(lbMap.getResult(0).getContext())),
body(this), lbMap(lbMap), ubMap(ubMap), step(step) {
// The body of a for stmt always has one block.
body.push_back(new StmtBlock());
operands.reserve(numOperands);
}
@ -522,6 +525,9 @@ IfStmt::IfStmt(Location location, unsigned numOperands, IntegerSet set)
: Statement(Kind::If, location), thenClause(this), elseClause(nullptr),
set(set) {
operands.reserve(numOperands);
// The then of an 'if' stmt always has one block.
thenClause.push_back(new StmtBlock());
}
IfStmt::~IfStmt() {

View File

@ -20,10 +20,6 @@
#include "mlir/IR/Statements.h"
using namespace mlir;
StmtBlock::StmtBlock(MLFunction *parent) : parent(parent) {}
StmtBlock::StmtBlock(Statement *parent) : parent(parent) {}
StmtBlock::~StmtBlock() {
clear();
@ -33,7 +29,7 @@ StmtBlock::~StmtBlock() {
/// Returns the closest surrounding statement that contains this block or
/// nullptr if this is a top-level statement block.
Statement *StmtBlock::getContainingStmt() {
return parent.dyn_cast<Statement *>();
return parent ? parent->getContainingStmt() : nullptr;
}
MLFunction *StmtBlock::findFunction() {
@ -43,7 +39,9 @@ MLFunction *StmtBlock::findFunction() {
if (!block)
return nullptr;
}
return block->getParent().get<MLFunction *>();
if (auto *list = block->getParent())
return list->getContainingFunction();
return nullptr;
}
/// Returns 'stmt' if 'stmt' lies in this block, or otherwise finds the ancestor
@ -140,3 +138,58 @@ StmtBlock *StmtBlock::getSinglePredecessor() {
++it;
return it == pred_end() ? firstPred : nullptr;
}
//===----------------------------------------------------------------------===//
// StmtBlockList
//===----------------------------------------------------------------------===//
StmtBlockList::StmtBlockList(MLFunction *container) : container(container) {}
StmtBlockList::StmtBlockList(Statement *container) : container(container) {}
Statement *StmtBlockList::getContainingStmt() {
return container.dyn_cast<Statement *>();
}
MLFunction *StmtBlockList::getContainingFunction() {
return container.dyn_cast<MLFunction *>();
}
StmtBlockList *llvm::ilist_traits<::mlir::StmtBlock>::getContainingBlockList() {
size_t Offset(size_t(
&((StmtBlockList *)nullptr->*StmtBlockList::getSublistAccess(nullptr))));
iplist<StmtBlock> *Anchor(static_cast<iplist<StmtBlock> *>(this));
return reinterpret_cast<StmtBlockList *>(reinterpret_cast<char *>(Anchor) -
Offset);
}
/// This is a trait method invoked when a basic block is added to a function.
/// We keep the function pointer up to date.
void llvm::ilist_traits<::mlir::StmtBlock>::addNodeToList(StmtBlock *block) {
assert(!block->parent && "already in a function!");
block->parent = getContainingBlockList();
}
/// This is a trait method invoked when an instruction is removed from a
/// function. We keep the function pointer up to date.
void llvm::ilist_traits<::mlir::StmtBlock>::removeNodeFromList(
StmtBlock *block) {
assert(block->parent && "not already in a function!");
block->parent = nullptr;
}
/// This is a trait method invoked when an instruction is moved from one block
/// to another. We keep the block pointer up to date.
void llvm::ilist_traits<::mlir::StmtBlock>::transferNodesFromList(
ilist_traits<StmtBlock> &otherList, block_iterator first,
block_iterator last) {
// If we are transferring instructions within the same function, the parent
// pointer doesn't need to be updated.
auto *curParent = getContainingBlockList();
if (curParent == otherList.getContainingBlockList())
return;
// Update the 'parent' member of each StmtBlock.
for (; first != last; ++first)
first->parent = curParent;
}