Store 'then' clause statements directly in the 'if' statement.

Also a few minor changes.

PiperOrigin-RevId: 213359024
This commit is contained in:
Tatiana Shpeisman 2018-09-17 16:24:43 -07:00 committed by jpienaar
parent 37a3f638ea
commit 52111cefc0
6 changed files with 21 additions and 21 deletions

View File

@ -437,8 +437,10 @@ public:
// Then, else, condition.
//===--------------------------------------------------------------------===//
IfClause *getThen() const { return thenClause; }
IfClause *getElse() const { return elseClause; }
IfClause *getThen() { return &thenClause; }
const IfClause *getThen() const { return &thenClause; }
IfClause *getElse() { return elseClause; }
const IfClause *getElse() const { return elseClause; }
bool hasElse() const { return elseClause != nullptr; }
IfClause *createElse() {
@ -503,9 +505,9 @@ public:
}
private:
// TODO: The 'If' always has an associated 'theClause', we should be able to
// store the IfClause object for it inline to save an extra allocation.
IfClause *thenClause;
// it is always present.
IfClause thenClause;
// 'else' clause of the if statement. 'nullptr' if there is no else clause.
IfClause *elseClause;
// The integer set capturing the conditional guard.

View File

@ -53,7 +53,7 @@ public:
/// Returns the closest surrounding statement that contains this block or
/// nullptr if this is a top-level statement block.
Statement *getParentStmt() const;
Statement *getContainingStmt() const;
/// Returns the function that this statement block is part of.
/// The function is determined by traversing the chain of parent statements.

View File

@ -167,8 +167,9 @@ public:
void walkIfStmtPostOrder(IfStmt *ifStmt) {
static_cast<SubClass *>(this)->walkPostOrder(ifStmt->getThen()->begin(),
ifStmt->getThen()->end());
static_cast<SubClass *>(this)->walkPostOrder(ifStmt->getElse()->begin(),
ifStmt->getElse()->end());
if (ifStmt->hasElse())
static_cast<SubClass *>(this)->walkPostOrder(ifStmt->getElse()->begin(),
ifStmt->getElse()->end());
static_cast<SubClass *>(this)->visitIfStmt(ifStmt);
}

View File

@ -72,14 +72,12 @@ MLIRContext *Statement::getContext() const {
case Kind::For:
return cast<ForStmt>(this)->getContext();
case Kind::If:
// TODO(shpeisman): When if statement has value operands, we can get a
// context from their type.
return findFunction()->getContext();
return cast<IfStmt>(this)->getContext();
}
}
Statement *Statement::getParentStmt() const {
return block ? block->getParentStmt() : nullptr;
return block ? block->getContainingStmt() : nullptr;
}
MLFunction *Statement::findFunction() const {
@ -368,14 +366,12 @@ void ForStmt::setConstantUpperBound(int64_t value) {
//===----------------------------------------------------------------------===//
IfStmt::IfStmt(Location *location, unsigned numOperands, IntegerSet *set)
: Statement(Kind::If, location), thenClause(new IfClause(this)),
elseClause(nullptr), set(set) {
: Statement(Kind::If, location), thenClause(this), elseClause(nullptr),
set(set) {
operands.reserve(numOperands);
}
IfStmt::~IfStmt() {
delete thenClause;
if (elseClause)
delete elseClause;

View File

@ -24,7 +24,7 @@ using namespace mlir;
// Statement block
//===----------------------------------------------------------------------===//
Statement *StmtBlock::getParentStmt() const {
Statement *StmtBlock::getContainingStmt() const {
switch (kind) {
case StmtBlockKind::MLFunc:
return nullptr;
@ -38,8 +38,8 @@ Statement *StmtBlock::getParentStmt() const {
MLFunction *StmtBlock::findFunction() const {
StmtBlock *block = const_cast<StmtBlock *>(this);
while (block->getParentStmt()) {
block = block->getParentStmt()->getBlock();
while (block->getContainingStmt()) {
block = block->getContainingStmt()->getBlock();
if (!block)
return nullptr;
}

View File

@ -101,8 +101,9 @@ PassResult LoopUnroll::runOnMLFunction(MLFunction *f) {
bool walkIfStmtPostOrder(IfStmt *ifStmt) {
bool hasInnerLoops =
walkPostOrder(ifStmt->getThen()->begin(), ifStmt->getThen()->end());
hasInnerLoops |=
walkPostOrder(ifStmt->getElse()->begin(), ifStmt->getElse()->end());
if (ifStmt->hasElse())
hasInnerLoops |=
walkPostOrder(ifStmt->getElse()->begin(), ifStmt->getElse()->end());
return hasInnerLoops;
}