Refactor MLFunction to contain a StmtBlock for its body instead of inheriting

from it.  This is necessary progress to squaring away the parent relationship
that a StmtBlock has with its enclosing if/for/fn, and makes room for functions
to have more than one block in the future.  This also removes IfClause and ForStmtBody.

This is step 5/n towards merging instructions and statements, NFC.

PiperOrigin-RevId: 226936541
This commit is contained in:
Chris Lattner 2018-12-26 11:21:53 -08:00 committed by jpienaar
parent 9a4060d3f5
commit d613f5ab65
20 changed files with 81 additions and 123 deletions

View File

@ -294,6 +294,12 @@ public:
setInsertionPoint(stmt);
}
MLFuncBuilder(StmtBlock *block)
// TODO: Eliminate findFunction from this.
: MLFuncBuilder(block->findFunction()) {
setInsertionPoint(block, block->end());
}
MLFuncBuilder(StmtBlock *block, StmtBlock::iterator insertPoint)
// TODO: Eliminate findFunction from this.
: MLFuncBuilder(block->findFunction()) {
@ -304,7 +310,7 @@ public:
/// the function.
MLFuncBuilder(MLFunction *func)
: Builder(func->getContext()), function(func) {
setInsertionPoint(func, func->begin());
setInsertionPoint(func->getBody(), func->getBody()->begin());
}
/// Return the function this builder is referring to.

View File

@ -36,7 +36,6 @@ template <typename ObjectType, typename ElementType> class ArgumentIterator;
// include nested affine for loops, conditionals and operations.
class MLFunction final
: public Function,
public StmtBlock,
private llvm::TrailingObjects<MLFunction, MLFuncArgument> {
public:
/// Creates a new MLFunction with the specific type.
@ -44,6 +43,9 @@ public:
FunctionType type,
ArrayRef<NamedAttribute> attrs = {});
StmtBlock *getBody() { return &body; }
const StmtBlock *getBody() const { return &body; }
/// Destroys this statement and its subclass data.
void destroy();
@ -98,9 +100,6 @@ public:
static bool classof(const Function *func) {
return func->getKind() == Function::Kind::MLFunc;
}
static bool classof(const StmtBlock *block) {
return block->getStmtBlockKind() == StmtBlockKind::MLFunc;
}
private:
MLFunction(Location location, StringRef name, FunctionType type,
@ -119,6 +118,8 @@ private:
MutableArrayRef<MLFuncArgument> getArgumentsInternal() {
return {getTrailingObjects<MLFuncArgument>(), getNumArguments()};
}
StmtBlock body;
};
//===--------------------------------------------------------------------===//

View File

@ -274,29 +274,6 @@ private:
size_t numTrailingObjects(OverloadToken<unsigned>) const { return numSuccs; }
};
/// A ForStmtBody represents statements contained within a ForStmt.
class ForStmtBody : public StmtBlock {
public:
explicit ForStmtBody(ForStmt *stmt)
: StmtBlock(StmtBlockKind::ForBody), forStmt(stmt) {
assert(stmt != nullptr && "ForStmtBody must have non-null parent");
}
~ForStmtBody() {}
/// Methods for support type inquiry through isa, cast, and dyn_cast
static bool classof(const StmtBlock *block) {
return block->getStmtBlockKind() == StmtBlockKind::ForBody;
}
/// Returns the 'for' statement that contains this body.
ForStmt *getFor() { return forStmt; }
const ForStmt *getFor() const { return forStmt; }
private:
ForStmt *forStmt;
};
/// For statement represents an affine loop nest.
class ForStmt : public Statement, public MLValue {
public:
@ -324,10 +301,10 @@ public:
using const_operand_range = llvm::iterator_range<const_operand_iterator>;
/// Get the body of the ForStmt.
ForStmtBody *getBody() { return &body; }
StmtBlock *getBody() { return &body; }
/// Get the body of the ForStmt.
const ForStmtBody *getBody() const { return &body; }
const StmtBlock *getBody() const { return &body; }
//===--------------------------------------------------------------------===//
// Bounds and step
@ -455,7 +432,7 @@ public:
private:
// The StmtBlock for the body.
ForStmtBody body;
StmtBlock body;
// Affine map for the lower bound.
AffineMap lbMap;
@ -525,31 +502,6 @@ private:
friend class ForStmt;
};
/// An if clause represents statements contained within a then or an else clause
/// of an if statement.
class IfClause : public StmtBlock {
public:
explicit IfClause(IfStmt *stmt)
: StmtBlock(StmtBlockKind::IfClause), ifStmt(stmt) {
assert(stmt != nullptr && "If clause must have non-null parent");
}
/// Methods for support type inquiry through isa, cast, and dyn_cast
static bool classof(const StmtBlock *block) {
return block->getStmtBlockKind() == StmtBlockKind::IfClause;
}
~IfClause() {}
/// Returns the if statement that contains this clause.
const IfStmt *getIf() const { return ifStmt; }
IfStmt *getIf() { return ifStmt; }
private:
IfStmt *ifStmt;
};
/// If statement restricts execution to a subset of the loop iteration space.
class IfStmt : public Statement {
public:
@ -561,15 +513,15 @@ public:
// Then, else, condition.
//===--------------------------------------------------------------------===//
IfClause *getThen() { return &thenClause; }
const IfClause *getThen() const { return &thenClause; }
IfClause *getElse() { return elseClause; }
const IfClause *getElse() const { return elseClause; }
StmtBlock *getThen() { return &thenClause; }
const StmtBlock *getThen() const { return &thenClause; }
StmtBlock *getElse() { return elseClause; }
const StmtBlock *getElse() const { return elseClause; }
bool hasElse() const { return elseClause != nullptr; }
IfClause *createElse() {
StmtBlock *createElse() {
assert(elseClause == nullptr && "already has an else clause!");
return (elseClause = new IfClause(this));
return (elseClause = new StmtBlock(this));
}
const AffineCondition getCondition() const;
@ -634,9 +586,9 @@ public:
private:
// it is always present.
IfClause thenClause;
StmtBlock thenClause;
// 'else' clause of the if statement. 'nullptr' if there is no else clause.
IfClause *elseClause;
StmtBlock *elseClause;
// The integer set capturing the conditional guard.
IntegerSet set;

View File

@ -39,12 +39,8 @@ template <typename BlockType> class StmtSuccessorIterator;
/// children of a parent statement in the ML Function.
class StmtBlock : public IRObjectWithUseList {
public:
enum class StmtBlockKind {
MLFunc, // MLFunction
ForBody, // ForStmtBody
IfClause // IfClause
};
explicit StmtBlock(MLFunction *parent);
explicit StmtBlock(Statement *parent);
~StmtBlock();
void clear() {
@ -54,7 +50,9 @@ public:
statements.pop_back();
}
StmtBlockKind getStmtBlockKind() const { return kind; }
llvm::PointerUnion<MLFunction *, Statement *> getParent() const {
return parent;
}
/// Returns the closest surrounding statement that contains this block or
/// nullptr if this is a top-level statement block.
@ -66,7 +64,10 @@ public:
/// Returns the function that this statement block is part of.
/// The function is determined by traversing the chain of parent statements.
MLFunction *findFunction() const;
MLFunction *findFunction();
const MLFunction *findFunction() const {
return const_cast<StmtBlock *>(this)->findFunction();
}
//===--------------------------------------------------------------------===//
// Block argument management
@ -224,11 +225,10 @@ public:
void printBlock(raw_ostream &os) const;
void dumpBlock() const;
protected:
StmtBlock(StmtBlockKind kind) : kind(kind) {}
private:
StmtBlockKind kind;
/// This is the parent function/IfStmt/ForStmt that owns this block.
llvm::PointerUnion<MLFunction *, Statement *> parent;
/// This is the list of statements in the block.
StmtListType statements;

View File

@ -132,11 +132,13 @@ public:
// Define walkers for MLFunction and all MLFunction statement kinds.
void walk(MLFunction *f) {
static_cast<SubClass *>(this)->visitMLFunction(f);
static_cast<SubClass *>(this)->walk(f->begin(), f->end());
static_cast<SubClass *>(this)->walk(f->getBody()->begin(),
f->getBody()->end());
}
void walkPostOrder(MLFunction *f) {
static_cast<SubClass *>(this)->walkPostOrder(f->begin(), f->end());
static_cast<SubClass *>(this)->walkPostOrder(f->getBody()->begin(),
f->getBody()->end());
static_cast<SubClass *>(this)->visitMLFunction(f);
}

View File

@ -288,8 +288,8 @@ bool MLFuncVerifier::verifyDominance() {
HashTable::ScopeTy blockScope(liveValues);
// The induction variable of a for statement is live within its body.
if (auto *forStmtBody = dyn_cast<ForStmtBody>(&block))
liveValues.insert(forStmtBody->getFor(), true);
if (auto *forStmt = dyn_cast_or_null<ForStmt>(block.getContainingStmt()))
liveValues.insert(forStmt, true);
for (auto &stmt : block) {
// Verify that each of the operands are live.
@ -330,16 +330,16 @@ bool MLFuncVerifier::verifyDominance() {
};
// Check the whole function out.
return walkBlock(fn);
return walkBlock(*fn.getBody());
}
bool MLFuncVerifier::verifyReturn() {
// TODO: fold return verification in the pass that verifies all statements.
const char missingReturnMsg[] = "ML function must end with return statement";
if (fn.getStatements().empty())
if (fn.getBody()->getStatements().empty())
return failure(missingReturnMsg, fn);
const auto &stmt = fn.getStatements().back();
const auto &stmt = fn.getBody()->getStatements().back();
if (const auto *op = dyn_cast<OperationStmt>(&stmt)) {
if (!op->isReturn())
return failure(missingReturnMsg, fn);

View File

@ -230,7 +230,7 @@ void ModuleState::visitStatement(const Statement *stmt) {
void ModuleState::visitMLFunction(const MLFunction *fn) {
visitType(fn->getType());
for (auto &stmt : *fn) {
for (auto &stmt : *fn->getBody()) {
ModuleState::visitStatement(&stmt);
}
}
@ -1390,7 +1390,7 @@ void MLFunctionPrinter::print() {
printFunctionSignature();
printFunctionAttributes(getFunction());
os << " {\n";
print(function);
print(function->getBody());
os << "}\n\n";
}
@ -1649,7 +1649,7 @@ void Statement::print(raw_ostream &os) const {
void Statement::dump() const { print(llvm::errs()); }
void StmtBlock::printBlock(raw_ostream &os) const {
MLFunction *function = findFunction();
const MLFunction *function = findFunction();
ModuleState state(function->getContext());
ModulePrinter modulePrinter(os, state);
MLFunctionPrinter(function, modulePrinter).print(this);

View File

@ -474,7 +474,7 @@ void ReturnOp::print(OpAsmPrinter *p) const {
bool ReturnOp::verify() const {
const Function *function;
if (auto *stmt = dyn_cast<OperationStmt>(getOperation()))
function = cast<MLFunction>(stmt->getBlock());
function = stmt->getBlock()->findFunction();
else
function = cast<Instruction>(getOperation())->getFunction();

View File

@ -202,14 +202,13 @@ MLFunction *MLFunction::create(Location location, StringRef name,
MLFunction::MLFunction(Location location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs)
: Function(Kind::MLFunc, location, name, type, attrs),
StmtBlock(StmtBlockKind::MLFunc) {}
: Function(Kind::MLFunc, location, name, type, attrs), body(this) {}
MLFunction::~MLFunction() {
// Explicitly erase statements instead of relying of 'StmtBlock' destructor
// since child statements need to be destroyed before function arguments
// are destroyed.
clear();
getBody()->clear();
// Explicitly run the destructors for the function arguments.
for (auto &arg : getArgumentsInternal())
@ -222,11 +221,11 @@ void MLFunction::destroy() {
}
const OperationStmt *MLFunction::getReturnStmt() const {
return cast<OperationStmt>(&back());
return cast<OperationStmt>(&getBody()->back());
}
OperationStmt *MLFunction::getReturnStmt() {
return cast<OperationStmt>(&back());
return cast<OperationStmt>(&getBody()->back());
}
void MLFunction::walk(std::function<void(OperationStmt *)> callback) {

View File

@ -581,7 +581,7 @@ bool OpTrait::impl::verifyIsTerminator(const Operation *op) {
// Verify that the operation is at the end of the respective parent block.
if (auto *stmt = dyn_cast<OperationStmt>(op)) {
StmtBlock *block = stmt->getBlock();
if (!block || !isa<MLFunction>(block) || &block->back() != stmt)
if (!block || block->getContainingStmt() || &block->back() != stmt)
return op->emitOpError("must be the last statement in the ML function");
} else {
const Instruction *inst = cast<Instruction>(op);

View File

@ -20,33 +20,30 @@
#include "mlir/IR/Statements.h"
using namespace mlir;
StmtBlock::StmtBlock(MLFunction *parent) : parent(parent) {}
StmtBlock::StmtBlock(Statement *parent) : parent(parent) {}
StmtBlock::~StmtBlock() {
clear();
llvm::DeleteContainerPointers(arguments);
}
/// Returns the closest surrounding statement that contains this block or
/// nullptr if this is a top-level statement block.
Statement *StmtBlock::getContainingStmt() {
switch (kind) {
case StmtBlockKind::MLFunc:
return nullptr;
case StmtBlockKind::ForBody:
return cast<ForStmtBody>(this)->getFor();
case StmtBlockKind::IfClause:
return cast<IfClause>(this)->getIf();
}
return parent.dyn_cast<Statement *>();
}
MLFunction *StmtBlock::findFunction() const {
// FIXME: const incorrect.
StmtBlock *block = const_cast<StmtBlock *>(this);
while (block->getContainingStmt()) {
block = block->getContainingStmt()->getBlock();
MLFunction *StmtBlock::findFunction() {
StmtBlock *block = this;
while (auto *stmt = block->getContainingStmt()) {
block = stmt->getBlock();
if (!block)
return nullptr;
}
return dyn_cast<MLFunction>(block);
return block->getParent().get<MLFunction *>();
}
/// Returns 'stmt' if 'stmt' lies in this block, or otherwise finds the ancestor

View File

@ -2777,7 +2777,7 @@ class MLFunctionParser : public FunctionParser {
public:
MLFunctionParser(ParserState &state, MLFunction *function)
: FunctionParser(state, Kind::MLFunc), function(function),
builder(function, function->end()) {}
builder(function->getBody()) {}
ParseResult parseFunctionBody();
@ -2796,7 +2796,7 @@ private:
ParseResult parseBound(SmallVectorImpl<MLValue *> &operands, AffineMap &map,
bool isLower);
ParseResult parseIfStmt();
ParseResult parseElseClause(IfClause *elseClause);
ParseResult parseElseClause(StmtBlock *elseClause);
ParseResult parseStatements(StmtBlock *block);
ParseResult parseStmtBlock(StmtBlock *block);
@ -2812,7 +2812,7 @@ ParseResult MLFunctionParser::parseFunctionBody() {
auto braceLoc = getToken().getLoc();
// Parse statements in this function.
if (parseStmtBlock(function))
if (parseStmtBlock(function->getBody()))
return ParseFailure;
return finalizeFunction(function, braceLoc);
@ -3121,7 +3121,7 @@ ParseResult MLFunctionParser::parseIfStmt() {
IfStmt *ifStmt =
builder.createIf(getEncodedSourceLocation(loc), operands, set);
IfClause *thenClause = ifStmt->getThen();
StmtBlock *thenClause = ifStmt->getThen();
// When parsing of an if statement body fails, the IR contains
// the if statement with the portion of the body that has been
@ -3141,7 +3141,7 @@ ParseResult MLFunctionParser::parseIfStmt() {
return ParseSuccess;
}
ParseResult MLFunctionParser::parseElseClause(IfClause *elseClause) {
ParseResult MLFunctionParser::parseElseClause(StmtBlock *elseClause) {
if (getToken().is(Token::kw_if)) {
builder.setInsertionPointToEnd(elseClause);
return parseIfStmt();

View File

@ -490,7 +490,7 @@ CFGFunction *FunctionConverter::convert(MLFunction *mlFunc) {
}
// Convert statements in order.
for (auto &stmt : *mlFunc) {
for (auto &stmt : *mlFunc->getBody()) {
visit(&stmt);
}

View File

@ -426,7 +426,7 @@ void DmaGeneration::runOnForStmt(ForStmt *forStmt) {
}
PassResult DmaGeneration::runOnMLFunction(MLFunction *f) {
for (auto &stmt : *f) {
for (auto &stmt : *f->getBody()) {
if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) {
runOnForStmt(forStmt);
}

View File

@ -348,7 +348,7 @@ public:
bool MemRefDependenceGraph::init(MLFunction *f) {
unsigned id = 0;
DenseMap<MLValue *, SetVector<unsigned>> memrefAccesses;
for (auto &stmt : *f) {
for (auto &stmt : *f->getBody()) {
if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) {
// Create graph node 'id' to represent top-level 'forStmt' and record
// all loads and store accesses it contains.

View File

@ -230,8 +230,8 @@ static void getTileableBands(MLFunction *f,
bands->push_back(band);
};
for (auto &stmt : *f) {
ForStmt *forStmt = dyn_cast<ForStmt>(&stmt);
for (auto &stmt : *f->getBody()) {
auto *forStmt = dyn_cast<ForStmt>(&stmt);
if (!forStmt)
continue;
getMaximalPerfectLoopNest(forStmt);

View File

@ -92,10 +92,10 @@ PassResult LoopUnrollAndJam::runOnMLFunction(MLFunction *f) {
// Currently, just the outermost loop from the first loop nest is
// unroll-and-jammed by this pass. However, runOnForStmt can be called on any
// for Stmt.
if (!isa<ForStmt>(f->begin()))
auto *forStmt = dyn_cast<ForStmt>(f->getBody()->begin());
if (!forStmt)
return success();
auto *forStmt = cast<ForStmt>(f->begin());
runOnForStmt(forStmt);
return success();
}

View File

@ -238,7 +238,7 @@ struct LowerVectorTransfersPass
makeFuncWiseState(MLFunction *f) const override {
auto state = llvm::make_unique<LowerVectorTransfersState>();
auto builder = MLFuncBuilder(f);
builder.setInsertionPointToStart(f);
builder.setInsertionPointToStart(f->getBody());
state->zero = builder.create<ConstantIndexOp>(builder.getUnknownLoc(), 0);
return state;
}

View File

@ -177,7 +177,8 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
cast<Instruction>(op)->moveBefore(&entryBB, entryBB.begin());
} else {
auto *mlFunc = cast<MLFunction>(currentFunction);
cast<OperationStmt>(op)->moveBefore(mlFunc, mlFunc->begin());
cast<OperationStmt>(op)->moveBefore(mlFunc->getBody(),
mlFunc->getBody()->begin());
}
continue;

View File

@ -102,7 +102,7 @@ bool mlir::promoteIfSingleIteration(ForStmt *forStmt) {
if (!forStmt->use_empty()) {
if (forStmt->hasConstantLowerBound()) {
auto *mlFunc = forStmt->findFunction();
MLFuncBuilder topBuilder(&mlFunc->front());
MLFuncBuilder topBuilder(&mlFunc->getBody()->front());
auto constOp = topBuilder.create<ConstantIndexOp>(
forStmt->getLoc(), forStmt->getConstantLowerBound());
forStmt->replaceAllUsesWith(constOp);