llvm-project/mlir/lib/IR/StmtBlock.cpp

239 lines
8.4 KiB
C++

//===- StmtBlock.cpp - MLIR Statement Instruction Classes -----------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/IR/StmtBlock.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
using namespace mlir;
/// Tears the block down: clear() is run first, and afterwards every block
/// argument is deleted. Arguments are allocated with `new` when they are added
/// to the block, so the block is responsible for freeing them here.
StmtBlock::~StmtBlock() {
  clear();
  llvm::DeleteContainerPointers(arguments);
}
/// Returns the statement that owns the block list this block belongs to.
/// Yields nullptr when the block is not attached to any list, or when the
/// owning list is not held by a statement (i.e. this is a top-level block).
Statement *StmtBlock::getContainingStmt() {
  if (!parent)
    return nullptr;
  return parent->getContainingStmt();
}
/// Returns the function this block is (transitively) nested in.
///
/// Starting from this block, repeatedly hops to the block holding the
/// enclosing statement until a block with no enclosing statement is reached,
/// then asks that block's parent list for its function. Returns nullptr if an
/// enclosing statement is not in a block, or if the outermost block has no
/// parent list.
MLFunction *StmtBlock::getFunction() {
  StmtBlock *outermost = this;
  while (auto *enclosingStmt = outermost->getContainingStmt()) {
    outermost = enclosingStmt->getBlock();
    if (!outermost)
      return nullptr;
  }

  auto *owningList = outermost->getParent();
  return owningList ? owningList->getContainingFunction() : nullptr;
}
/// Finds the statement directly contained in this block that is 'stmt' itself
/// or one of its ancestors. Returns nullptr when neither 'stmt' nor any of
/// its ancestors resides in this block.
const Statement *
StmtBlock::findAncestorStmtInBlock(const Statement &stmt) const {
  // Walk up the parent-statement chain, beginning with 'stmt' itself, and stop
  // at the first statement whose block is this one.
  for (const Statement *candidate = &stmt; candidate;
       candidate = candidate->getParentStmt()) {
    if (candidate->getBlock() == this)
      return candidate;
  }
  // Ran out of ancestors without ever entering this block.
  return nullptr;
}
//===----------------------------------------------------------------------===//
// Argument list management.
//===----------------------------------------------------------------------===//
/// Appends a new argument of type 'type' to this block's argument list and
/// returns it. The argument is heap-allocated here; the block deletes it when
/// the argument is erased or the block is destroyed.
BlockArgument *StmtBlock::addArgument(Type type) {
  arguments.push_back(new BlockArgument(type, this));
  return arguments.back();
}
/// Appends one argument per entry of 'types', in order, and returns the range
/// covering exactly the newly added arguments.
auto StmtBlock::addArguments(ArrayRef<Type> types)
    -> llvm::iterator_range<args_iterator> {
  // Remember where the new arguments will start, and grow the storage once up
  // front instead of on every append.
  auto firstNewIndex = arguments.size();
  arguments.reserve(firstNewIndex + types.size());

  for (auto type : types)
    addArgument(type);

  // Build the range from the final storage pointer, after all appends.
  auto *storage = arguments.data();
  return {storage + firstNewIndex, storage + arguments.size()};
}
void StmtBlock::eraseArgument(unsigned index) {
assert(index < arguments.size());
// Delete the argument.
delete arguments[index];
arguments.erase(arguments.begin() + index);
// Erase this argument from each of the predecessor's terminator.
for (auto predIt = pred_begin(), predE = pred_end(); predIt != predE;
++predIt) {
auto *predTerminator = (*predIt)->getTerminator();
predTerminator->eraseSuccessorOperand(predIt.getSuccessorIndex(), index);
}
}
//===----------------------------------------------------------------------===//
// Terminator management
//===----------------------------------------------------------------------===//
/// Returns the terminator of this block, i.e. its last statement provided that
/// statement is an OperationStmt reporting isTerminator(). Returns nullptr for
/// an empty block or when the last statement does not qualify.
OperationStmt *StmtBlock::getTerminator() {
  if (empty())
    return nullptr;

  // Only the final statement can be the terminator.
  auto *lastOp = dyn_cast<OperationStmt>(&statements.back());
  return (lastOp && lastOp->isTerminator()) ? lastOp : nullptr;
}
/// Reports whether the predecessor range of this block is empty.
bool StmtBlock::hasNoPredecessors() const {
  return pred_begin() == pred_end();
}
// Indexed successor access.
unsigned StmtBlock::getNumSuccessors() const {
return getTerminator()->getNumSuccessors();
}
/// Returns the i-th successor block of this block's terminator.
///
/// The block must end in a terminator and 'i' must be smaller than the
/// terminator's successor count; both preconditions are asserted so that
/// misuse is reported here rather than surfacing as a null dereference.
StmtBlock *StmtBlock::getSuccessor(unsigned i) {
  auto *terminator = getTerminator();
  assert(terminator && "block has no terminator, so it has no successors");
  assert(i < terminator->getNumSuccessors() && "successor index out of range");
  return terminator->getSuccessor(i);
}
/// Returns the predecessor of this block when the predecessor range holds
/// exactly one entry, and nullptr otherwise.
///
/// The range has one entry per incoming edge, so a block that branches here
/// more than once (e.g. a conditional branch whose true and false
/// destinations are both this block) does not count as a single predecessor.
StmtBlock *StmtBlock::getSinglePredecessor() {
  auto cursor = pred_begin();
  auto stop = pred_end();

  // No incoming edges at all.
  if (cursor == stop)
    return nullptr;

  auto *candidate = *cursor;
  ++cursor;

  // A second incoming edge disqualifies the candidate.
  if (cursor != stop)
    return nullptr;
  return candidate;
}
//===----------------------------------------------------------------------===//
// Other
//===----------------------------------------------------------------------===//
/// Removes this block from the block list of the function that contains it
/// by erasing it from that list. The block must currently belong to a
/// function.
void BasicBlock::eraseFromFunction() {
  assert(getFunction() && "BasicBlock has no parent");
  auto *owner = getFunction();
  owner->getBlocks().erase(this);
}
/// Splits this block in two at 'splitBefore' and returns the newly created
/// block.
///
/// Statements before 'splitBefore' remain in this block. Everything from
/// 'splitBefore' up to the end of this block (which includes the old
/// terminator, if any) is moved into a fresh block that is inserted right
/// after this one in the function's block list. An unconditional branch to the
/// new block is then created for this block.
///
/// The iterator passed in is invalidated by this call.
BasicBlock *BasicBlock::splitBasicBlock(iterator splitBefore) {
  // Choose the location for the new branch while the statements are still in
  // place: the split point's location, or the terminator's location when
  // splitting at the very end of the block.
  auto jumpLoc =
      splitBefore == end() ? getTerminator()->getLoc() : splitBefore->getLoc();

  // Create the tail block and place it immediately after this block in the
  // containing function.
  auto *tail = new BasicBlock();
  getFunction()->getBlocks().insert(++CFGFunction::iterator(this), tail);

  // Move the statements in [splitBefore, end()) of this block to the end of
  // the tail block.
  tail->getStatements().splice(tail->end(), getStatements(), splitBefore,
                               end());

  // The old terminator went along with the splice, so give this block an
  // unconditional branch to the tail block.
  CFGFuncBuilder(this).create<BranchOp>(jumpLoc, tail);
  return tail;
}
//===----------------------------------------------------------------------===//
// StmtBlockList
//===----------------------------------------------------------------------===//
/// Creates a block list whose container is the given function.
StmtBlockList::StmtBlockList(Function *container)
    : container(container) {
}
/// Creates a block list whose container is the given statement.
StmtBlockList::StmtBlockList(Statement *container)
    : container(container) {
}
/// Returns the containing function as a CFGFunction. Yields nullptr when this
/// list is not held by a function or the function is not a CFGFunction.
CFGFunction *StmtBlockList::getFunction() {
  Function *owner = getContainingFunction();
  return dyn_cast_or_null<CFGFunction>(owner);
}
/// Returns the statement holding this block list, or nullptr when the
/// container is not a statement.
Statement *StmtBlockList::getContainingStmt() {
  Statement *owner = container.dyn_cast<Statement *>();
  return owner;
}
/// Returns the function holding this block list, or nullptr when the
/// container is not a function.
Function *StmtBlockList::getContainingFunction() {
  Function *owner = container.dyn_cast<Function *>();
  return owner;
}
/// Recovers the StmtBlockList that embeds the block list these traits belong
/// to, by subtracting the list member's offset from the list's own address.
/// NOTE(review): this presumes the list is always a member of a
/// StmtBlockList; confirm no other owner instantiates it.
StmtBlockList *llvm::ilist_traits<::mlir::StmtBlock>::getContainingBlockList() {
// Byte offset of the block-list member within StmtBlockList: apply the
// pointer-to-member returned by getSublistAccess() to a null StmtBlockList
// pointer and take the address of the result.
size_t Offset(size_t(
&((StmtBlockList *)nullptr->*StmtBlockList::getSublistAccess(nullptr))));
// View these traits as the list object they are a base of.
iplist<StmtBlock> *Anchor(static_cast<iplist<StmtBlock> *>(this));
// Step back from the list member to the start of the enclosing
// StmtBlockList.
return reinterpret_cast<StmtBlockList *>(reinterpret_cast<char *>(Anchor) -
Offset);
}
/// Trait hook invoked when a block is inserted into a block list. Points the
/// block's parent at the list it now lives in; the block must not already
/// have a parent.
void llvm::ilist_traits<::mlir::StmtBlock>::addNodeToList(StmtBlock *block) {
  assert(!block->parent && "already in a function!");
  auto *owningList = getContainingBlockList();
  block->parent = owningList;
}
/// Trait hook invoked when a block is removed from a block list. Resets the
/// block's parent pointer; the block must currently have a parent.
void llvm::ilist_traits<::mlir::StmtBlock>::removeNodeFromList(
    StmtBlock *block) {
  assert(block->parent && "not already in a function!");
  block->parent = nullptr;
}
/// Trait hook invoked when the blocks in [first, last) are moved from
/// 'otherList' into this list. Re-parents each moved block to the block list
/// that contains this list.
void llvm::ilist_traits<::mlir::StmtBlock>::transferNodesFromList(
    ilist_traits<StmtBlock> &otherList, block_iterator first,
    block_iterator last) {
  auto *destList = getContainingBlockList();

  // Parent pointers only need rewriting when the source and destination lists
  // have different containing block lists.
  if (destList != otherList.getContainingBlockList()) {
    while (first != last) {
      first->parent = destList;
      ++first;
    }
  }
}