forked from OSchip/llvm-project
Introduce a new StmtBlockList type to hold a list of StmtBlocks. Use it in
MLFunction, IfStmt, ForStmt even though they currently only contain exactly one block in that list. This is step 6/n towards merging instructions and statements, NFC. PiperOrigin-RevId: 226960278
This commit is contained in:
parent
63068da4d9
commit
5ff0001dc7
|
@ -145,7 +145,7 @@ struct ilist_traits<::mlir::Function>
|
|||
using Function = ::mlir::Function;
|
||||
using function_iterator = simple_ilist<Function>::iterator;
|
||||
|
||||
static void deleteNode(Function *inst) { inst->destroy(); }
|
||||
static void deleteNode(Function *function) { function->destroy(); }
|
||||
|
||||
void addNodeToList(Function *function);
|
||||
void removeNodeFromList(Function *function);
|
||||
|
|
|
@ -43,8 +43,11 @@ public:
|
|||
FunctionType type,
|
||||
ArrayRef<NamedAttribute> attrs = {});
|
||||
|
||||
StmtBlock *getBody() { return &body; }
|
||||
const StmtBlock *getBody() const { return &body; }
|
||||
StmtBlockList &getStatementList() { return body; }
|
||||
const StmtBlockList &getStatementList() const { return body; }
|
||||
|
||||
StmtBlock *getBody() { return &body.front(); }
|
||||
const StmtBlock *getBody() const { return &body.front(); }
|
||||
|
||||
/// Destroys this statement and its subclass data.
|
||||
void destroy();
|
||||
|
@ -119,7 +122,7 @@ private:
|
|||
return {getTrailingObjects<MLFuncArgument>(), getNumArguments()};
|
||||
}
|
||||
|
||||
StmtBlock body;
|
||||
StmtBlockList body;
|
||||
};
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
|
|
@ -301,10 +301,10 @@ public:
|
|||
using const_operand_range = llvm::iterator_range<const_operand_iterator>;
|
||||
|
||||
/// Get the body of the ForStmt.
|
||||
StmtBlock *getBody() { return &body; }
|
||||
StmtBlock *getBody() { return &body.front(); }
|
||||
|
||||
/// Get the body of the ForStmt.
|
||||
const StmtBlock *getBody() const { return &body; }
|
||||
const StmtBlock *getBody() const { return &body.front(); }
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Bounds and step
|
||||
|
@ -432,7 +432,7 @@ public:
|
|||
|
||||
private:
|
||||
// The StmtBlock for the body.
|
||||
StmtBlock body;
|
||||
StmtBlockList body;
|
||||
|
||||
// Affine map for the lower bound.
|
||||
AffineMap lbMap;
|
||||
|
@ -513,15 +513,19 @@ public:
|
|||
// Then, else, condition.
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
StmtBlock *getThen() { return &thenClause; }
|
||||
const StmtBlock *getThen() const { return &thenClause; }
|
||||
StmtBlock *getElse() { return elseClause; }
|
||||
const StmtBlock *getElse() const { return elseClause; }
|
||||
StmtBlock *getThen() { return &thenClause.front(); }
|
||||
const StmtBlock *getThen() const { return &thenClause.front(); }
|
||||
StmtBlock *getElse() { return elseClause ? &elseClause->front() : nullptr; }
|
||||
const StmtBlock *getElse() const {
|
||||
return elseClause ? &elseClause->front() : nullptr;
|
||||
}
|
||||
bool hasElse() const { return elseClause != nullptr; }
|
||||
|
||||
StmtBlock *createElse() {
|
||||
assert(elseClause == nullptr && "already has an else clause!");
|
||||
return (elseClause = new StmtBlock(this));
|
||||
elseClause = new StmtBlockList(this);
|
||||
elseClause->push_back(new StmtBlock());
|
||||
return &elseClause->front();
|
||||
}
|
||||
|
||||
const AffineCondition getCondition() const;
|
||||
|
@ -586,9 +590,9 @@ public:
|
|||
|
||||
private:
|
||||
// it is always present.
|
||||
StmtBlock thenClause;
|
||||
StmtBlockList thenClause;
|
||||
// 'else' clause of the if statement. 'nullptr' if there is no else clause.
|
||||
StmtBlock *elseClause;
|
||||
StmtBlockList *elseClause;
|
||||
|
||||
// The integer set capturing the conditional guard.
|
||||
IntegerSet set;
|
||||
|
|
|
@ -28,6 +28,7 @@ namespace mlir {
|
|||
class MLFunction;
|
||||
class IfStmt;
|
||||
class MLValue;
|
||||
class StmtBlockList;
|
||||
|
||||
// TODO(clattner): drop the Stmt prefixes on these once BasicBlock's versions of
|
||||
// these go away.
|
||||
|
@ -37,10 +38,11 @@ template <typename BlockType> class StmtSuccessorIterator;
|
|||
/// Statement block represents an ordered list of statements, with the order
|
||||
/// being the contiguous lexical order in which the statements appear as
|
||||
/// children of a parent statement in the ML Function.
|
||||
class StmtBlock : public IRObjectWithUseList {
|
||||
class StmtBlock
|
||||
: public IRObjectWithUseList,
|
||||
public llvm::ilist_node_with_parent<StmtBlock, StmtBlockList> {
|
||||
public:
|
||||
explicit StmtBlock(MLFunction *parent);
|
||||
explicit StmtBlock(Statement *parent);
|
||||
explicit StmtBlock() {}
|
||||
~StmtBlock();
|
||||
|
||||
void clear() {
|
||||
|
@ -50,9 +52,7 @@ public:
|
|||
statements.pop_back();
|
||||
}
|
||||
|
||||
llvm::PointerUnion<MLFunction *, Statement *> getParent() const {
|
||||
return parent;
|
||||
}
|
||||
StmtBlockList *getParent() const { return parent; }
|
||||
|
||||
/// Returns the closest surrounding statement that contains this block or
|
||||
/// nullptr if this is a top-level statement block.
|
||||
|
@ -227,7 +227,7 @@ public:
|
|||
|
||||
private:
|
||||
/// This is the parent function/IfStmt/ForStmt that owns this block.
|
||||
llvm::PointerUnion<MLFunction *, Statement *> parent;
|
||||
StmtBlockList *parent = nullptr;
|
||||
|
||||
/// This is the list of statements in the block.
|
||||
StmtListType statements;
|
||||
|
@ -237,6 +237,100 @@ private:
|
|||
|
||||
StmtBlock(const StmtBlock &) = delete;
|
||||
void operator=(const StmtBlock &) = delete;
|
||||
|
||||
friend struct llvm::ilist_traits<StmtBlock>;
|
||||
};
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ilist_traits for StmtBlock
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace llvm {
|
||||
|
||||
template <>
|
||||
struct ilist_traits<::mlir::StmtBlock>
|
||||
: public ilist_alloc_traits<::mlir::StmtBlock> {
|
||||
using StmtBlock = ::mlir::StmtBlock;
|
||||
using block_iterator = simple_ilist<::mlir::StmtBlock>::iterator;
|
||||
|
||||
void addNodeToList(StmtBlock *block);
|
||||
void removeNodeFromList(StmtBlock *block);
|
||||
void transferNodesFromList(ilist_traits<StmtBlock> &otherList,
|
||||
block_iterator first, block_iterator last);
|
||||
|
||||
private:
|
||||
mlir::StmtBlockList *getContainingBlockList();
|
||||
};
|
||||
} // end namespace llvm
|
||||
|
||||
namespace mlir {
|
||||
|
||||
/// This class contains a list of basic blocks and has a notion of the object it
|
||||
/// is part of - an MLFunction or IfStmt or ForStmt.
|
||||
class StmtBlockList {
|
||||
public:
|
||||
explicit StmtBlockList(MLFunction *container);
|
||||
explicit StmtBlockList(Statement *container);
|
||||
|
||||
using BlockListType = llvm::iplist<StmtBlock>;
|
||||
BlockListType &getBlocks() { return blocks; }
|
||||
const BlockListType &getBlocks() const { return blocks; }
|
||||
|
||||
// Iteration over the block in the function.
|
||||
using iterator = BlockListType::iterator;
|
||||
using const_iterator = BlockListType::const_iterator;
|
||||
using reverse_iterator = BlockListType::reverse_iterator;
|
||||
using const_reverse_iterator = BlockListType::const_reverse_iterator;
|
||||
|
||||
iterator begin() { return blocks.begin(); }
|
||||
iterator end() { return blocks.end(); }
|
||||
const_iterator begin() const { return blocks.begin(); }
|
||||
const_iterator end() const { return blocks.end(); }
|
||||
reverse_iterator rbegin() { return blocks.rbegin(); }
|
||||
reverse_iterator rend() { return blocks.rend(); }
|
||||
const_reverse_iterator rbegin() const { return blocks.rbegin(); }
|
||||
const_reverse_iterator rend() const { return blocks.rend(); }
|
||||
|
||||
bool empty() const { return blocks.empty(); }
|
||||
void push_back(StmtBlock *block) { blocks.push_back(block); }
|
||||
void push_front(StmtBlock *block) { blocks.push_front(block); }
|
||||
|
||||
StmtBlock &back() { return blocks.back(); }
|
||||
const StmtBlock &back() const {
|
||||
return const_cast<StmtBlockList *>(this)->back();
|
||||
}
|
||||
|
||||
StmtBlock &front() { return blocks.front(); }
|
||||
const StmtBlock &front() const {
|
||||
return const_cast<StmtBlockList *>(this)->front();
|
||||
}
|
||||
|
||||
/// getSublistAccess() - Returns pointer to member of block list.
|
||||
static BlockListType StmtBlockList::*getSublistAccess(StmtBlock *) {
|
||||
return &StmtBlockList::blocks;
|
||||
}
|
||||
|
||||
/// A StmtBlockList is part of a MLFunction or and IfStmt/ForStmt. If it is
|
||||
/// part of an IfStmt/ForStmt, then return it, otherwise return null.
|
||||
Statement *getContainingStmt();
|
||||
const Statement *getContainingStmt() const {
|
||||
return const_cast<StmtBlockList *>(this)->getContainingStmt();
|
||||
}
|
||||
|
||||
/// A StmtBlockList is part of a MLFunction or and IfStmt/ForStmt. If it is
|
||||
/// part of an MLFunction, then return it, otherwise return null.
|
||||
MLFunction *getContainingFunction();
|
||||
const MLFunction *getContainingFunction() const {
|
||||
return const_cast<StmtBlockList *>(this)->getContainingFunction();
|
||||
}
|
||||
|
||||
private:
|
||||
BlockListType blocks;
|
||||
|
||||
/// This is the object we are part of.
|
||||
llvm::PointerUnion<MLFunction *, Statement *> container;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -369,5 +463,5 @@ inline auto StmtBlock::getSuccessors() -> llvm::iterator_range<succ_iterator> {
|
|||
return {succ_begin(), succ_end()};
|
||||
}
|
||||
|
||||
} //end namespace mlir
|
||||
} // end namespace mlir
|
||||
#endif // MLIR_IR_STMTBLOCK_H
|
||||
|
|
|
@ -202,7 +202,10 @@ MLFunction *MLFunction::create(Location location, StringRef name,
|
|||
|
||||
MLFunction::MLFunction(Location location, StringRef name, FunctionType type,
|
||||
ArrayRef<NamedAttribute> attrs)
|
||||
: Function(Kind::MLFunc, location, name, type, attrs), body(this) {}
|
||||
: Function(Kind::MLFunc, location, name, type, attrs), body(this) {
|
||||
// The body of an MLFunction always has one block.
|
||||
body.push_back(new StmtBlock());
|
||||
}
|
||||
|
||||
MLFunction::~MLFunction() {
|
||||
// Explicitly erase statements instead of relying of 'StmtBlock' destructor
|
||||
|
|
|
@ -405,6 +405,9 @@ ForStmt::ForStmt(Location location, unsigned numOperands, AffineMap lbMap,
|
|||
MLValue(MLValueKind::ForStmt,
|
||||
Type::getIndex(lbMap.getResult(0).getContext())),
|
||||
body(this), lbMap(lbMap), ubMap(ubMap), step(step) {
|
||||
|
||||
// The body of a for stmt always has one block.
|
||||
body.push_back(new StmtBlock());
|
||||
operands.reserve(numOperands);
|
||||
}
|
||||
|
||||
|
@ -522,6 +525,9 @@ IfStmt::IfStmt(Location location, unsigned numOperands, IntegerSet set)
|
|||
: Statement(Kind::If, location), thenClause(this), elseClause(nullptr),
|
||||
set(set) {
|
||||
operands.reserve(numOperands);
|
||||
|
||||
// The then of an 'if' stmt always has one block.
|
||||
thenClause.push_back(new StmtBlock());
|
||||
}
|
||||
|
||||
IfStmt::~IfStmt() {
|
||||
|
|
|
@ -20,10 +20,6 @@
|
|||
#include "mlir/IR/Statements.h"
|
||||
using namespace mlir;
|
||||
|
||||
StmtBlock::StmtBlock(MLFunction *parent) : parent(parent) {}
|
||||
|
||||
StmtBlock::StmtBlock(Statement *parent) : parent(parent) {}
|
||||
|
||||
StmtBlock::~StmtBlock() {
|
||||
clear();
|
||||
|
||||
|
@ -33,7 +29,7 @@ StmtBlock::~StmtBlock() {
|
|||
/// Returns the closest surrounding statement that contains this block or
|
||||
/// nullptr if this is a top-level statement block.
|
||||
Statement *StmtBlock::getContainingStmt() {
|
||||
return parent.dyn_cast<Statement *>();
|
||||
return parent ? parent->getContainingStmt() : nullptr;
|
||||
}
|
||||
|
||||
MLFunction *StmtBlock::findFunction() {
|
||||
|
@ -43,7 +39,9 @@ MLFunction *StmtBlock::findFunction() {
|
|||
if (!block)
|
||||
return nullptr;
|
||||
}
|
||||
return block->getParent().get<MLFunction *>();
|
||||
if (auto *list = block->getParent())
|
||||
return list->getContainingFunction();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// Returns 'stmt' if 'stmt' lies in this block, or otherwise finds the ancestor
|
||||
|
@ -140,3 +138,58 @@ StmtBlock *StmtBlock::getSinglePredecessor() {
|
|||
++it;
|
||||
return it == pred_end() ? firstPred : nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// StmtBlockList
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
StmtBlockList::StmtBlockList(MLFunction *container) : container(container) {}
|
||||
|
||||
StmtBlockList::StmtBlockList(Statement *container) : container(container) {}
|
||||
|
||||
Statement *StmtBlockList::getContainingStmt() {
|
||||
return container.dyn_cast<Statement *>();
|
||||
}
|
||||
|
||||
MLFunction *StmtBlockList::getContainingFunction() {
|
||||
return container.dyn_cast<MLFunction *>();
|
||||
}
|
||||
|
||||
StmtBlockList *llvm::ilist_traits<::mlir::StmtBlock>::getContainingBlockList() {
|
||||
size_t Offset(size_t(
|
||||
&((StmtBlockList *)nullptr->*StmtBlockList::getSublistAccess(nullptr))));
|
||||
iplist<StmtBlock> *Anchor(static_cast<iplist<StmtBlock> *>(this));
|
||||
return reinterpret_cast<StmtBlockList *>(reinterpret_cast<char *>(Anchor) -
|
||||
Offset);
|
||||
}
|
||||
|
||||
/// This is a trait method invoked when a basic block is added to a function.
|
||||
/// We keep the function pointer up to date.
|
||||
void llvm::ilist_traits<::mlir::StmtBlock>::addNodeToList(StmtBlock *block) {
|
||||
assert(!block->parent && "already in a function!");
|
||||
block->parent = getContainingBlockList();
|
||||
}
|
||||
|
||||
/// This is a trait method invoked when an instruction is removed from a
|
||||
/// function. We keep the function pointer up to date.
|
||||
void llvm::ilist_traits<::mlir::StmtBlock>::removeNodeFromList(
|
||||
StmtBlock *block) {
|
||||
assert(block->parent && "not already in a function!");
|
||||
block->parent = nullptr;
|
||||
}
|
||||
|
||||
/// This is a trait method invoked when an instruction is moved from one block
|
||||
/// to another. We keep the block pointer up to date.
|
||||
void llvm::ilist_traits<::mlir::StmtBlock>::transferNodesFromList(
|
||||
ilist_traits<StmtBlock> &otherList, block_iterator first,
|
||||
block_iterator last) {
|
||||
// If we are transferring instructions within the same function, the parent
|
||||
// pointer doesn't need to be updated.
|
||||
auto *curParent = getContainingBlockList();
|
||||
if (curParent == otherList.getContainingBlockList())
|
||||
return;
|
||||
|
||||
// Update the 'parent' member of each StmtBlock.
|
||||
for (; first != last; ++first)
|
||||
first->parent = curParent;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue