forked from OSchip/llvm-project
Rework the cloning infrastructure for statements to be able to take and update
an operand mapping, which simplifies it a bit. Implement cloning for IfStmt, rename getThenClause() to getThen() which is unambiguous and less repetitive in use cases. PiperOrigin-RevId: 207915990
This commit is contained in:
parent
01915ad0a0
commit
8159186f57
|
@ -216,12 +216,18 @@ private:
|
|||
/// statement block.
|
||||
class MLFuncBuilder : public Builder {
|
||||
public:
|
||||
/// Create ML function builder and set insertion point to the given
|
||||
/// statement block, that is, given ML function, for statement or if statement
|
||||
/// clause.
|
||||
MLFuncBuilder(StmtBlock *block)
|
||||
/// Create ML function builder and set insertion point to the given statement,
|
||||
/// which will cause subsequent insertions to go right before it.
|
||||
MLFuncBuilder(Statement *stmt)
|
||||
// TODO: Eliminate findFunction from this.
|
||||
: Builder(stmt->findFunction()->getContext()) {
|
||||
setInsertionPoint(stmt);
|
||||
}
|
||||
|
||||
MLFuncBuilder(StmtBlock *block, StmtBlock::iterator insertPoint)
|
||||
// TODO: Eliminate findFunction from this.
|
||||
: Builder(block->findFunction()->getContext()) {
|
||||
setInsertionPoint(block);
|
||||
setInsertionPoint(block, insertPoint);
|
||||
}
|
||||
|
||||
/// Reset the insertion point to no location. Creating an operation without a
|
||||
|
@ -242,22 +248,22 @@ public:
|
|||
}
|
||||
|
||||
/// Set the insertion point to the specified operation.
|
||||
void setInsertionPoint(OperationStmt *stmt) {
|
||||
void setInsertionPoint(Statement *stmt) {
|
||||
setInsertionPoint(stmt->getBlock(), StmtBlock::iterator(stmt));
|
||||
}
|
||||
|
||||
/// Set the insertion point to the end of the specified block.
|
||||
void setInsertionPoint(StmtBlock *block) {
|
||||
this->block = block;
|
||||
insertPoint = block->end();
|
||||
}
|
||||
|
||||
/// Set the insertion point at the beginning of the specified block.
|
||||
void setInsertionPointAtStart(StmtBlock *block) {
|
||||
/// Set the insertion point to the start of the specified block.
|
||||
void setInsertionPointToStart(StmtBlock *block) {
|
||||
this->block = block;
|
||||
insertPoint = block->begin();
|
||||
}
|
||||
|
||||
/// Set the insertion point to the end of the specified block.
|
||||
void setInsertionPointToEnd(StmtBlock *block) {
|
||||
this->block = block;
|
||||
insertPoint = block->end();
|
||||
}
|
||||
|
||||
/// Get the current insertion point of the builder.
|
||||
StmtBlock::iterator getInsertionPoint() const { return insertPoint; }
|
||||
|
||||
|
@ -273,8 +279,14 @@ public:
|
|||
return result;
|
||||
}
|
||||
|
||||
Statement *clone(const Statement &stmt) {
|
||||
Statement *cloneStmt = stmt.clone();
|
||||
/// Create a deep copy of the specified statement, remapping any operands that
|
||||
/// use values outside of the statement using the map that is provided (
|
||||
/// leaving them alone if no entry is present). Replaces references to cloned
|
||||
/// sub-statements to the corresponding statement that is copied, and adds
|
||||
/// those mappings to the map.
|
||||
Statement *clone(const Statement &stmt,
|
||||
OperationStmt::OperandMapTy &operandMapping) {
|
||||
Statement *cloneStmt = stmt.clone(operandMapping, getContext());
|
||||
block->getStatements().insert(insertPoint, cloneStmt);
|
||||
return cloneStmt;
|
||||
}
|
||||
|
|
|
@ -49,8 +49,17 @@ public:
|
|||
/// Remove this statement from its block and delete it.
|
||||
void eraseFromBlock();
|
||||
|
||||
/// Clone this statement, the cloning is deep.
|
||||
Statement *clone() const;
|
||||
// This is a verbose type used by the clone method below.
|
||||
using OperandMapTy =
|
||||
DenseMap<const MLValue *, MLValue *, llvm::DenseMapInfo<const MLValue *>,
|
||||
llvm::detail::DenseMapPair<const MLValue *, MLValue *>>;
|
||||
|
||||
/// Create a deep copy of this statement, remapping any operands that use
|
||||
/// values outside of the statement using the map that is provided (leaving
|
||||
/// them alone if no entry is present). Replaces references to cloned
|
||||
/// sub-statements to the corresponding statement that is copied, and adds
|
||||
/// those mappings to the map.
|
||||
Statement *clone(OperandMapTy &operandMapping, MLIRContext *context) const;
|
||||
|
||||
/// Returns the statement block that contains this statement.
|
||||
StmtBlock *getBlock() const { return block; }
|
||||
|
@ -73,9 +82,6 @@ public:
|
|||
void print(raw_ostream &os) const;
|
||||
void dump() const;
|
||||
|
||||
/// Replace all uses of 'oldVal' with 'newVal' in 'stmt'.
|
||||
void replaceUses(MLValue *oldVal, MLValue *newVal);
|
||||
|
||||
protected:
|
||||
Statement(Kind kind) : kind(kind) {}
|
||||
// Statements are deleted through the destroy() member because this class
|
||||
|
|
|
@ -48,8 +48,6 @@ public:
|
|||
/// Return the context this operation is associated with.
|
||||
MLIRContext *getContext() const;
|
||||
|
||||
OperationStmt *clone() const;
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Operands
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
@ -210,8 +208,8 @@ public:
|
|||
clear();
|
||||
}
|
||||
|
||||
/// Deep clone this for stmt.
|
||||
ForStmt *clone() const;
|
||||
/// Resolve base class ambiguity.
|
||||
using Statement::findFunction;
|
||||
|
||||
AffineConstantExpr *getLowerBound() const { return lowerBound; }
|
||||
AffineConstantExpr *getUpperBound() const { return upperBound; }
|
||||
|
@ -274,14 +272,14 @@ public:
|
|||
|
||||
~IfStmt();
|
||||
|
||||
/// Deep clone this IfStmt.
|
||||
IfStmt *clone() const;
|
||||
|
||||
IfClause *getThenClause() const { return thenClause; }
|
||||
IfClause *getElseClause() const { return elseClause; }
|
||||
IfClause *getThen() const { return thenClause; }
|
||||
IfClause *getElse() const { return elseClause; }
|
||||
IntegerSet *getCondition() const { return condition; }
|
||||
bool hasElseClause() const { return elseClause != nullptr; }
|
||||
IfClause *createElseClause() { return (elseClause = new IfClause(this)); }
|
||||
bool hasElse() const { return elseClause != nullptr; }
|
||||
IfClause *createElse() {
|
||||
assert(elseClause == nullptr && "already has an else clause!");
|
||||
return (elseClause = new IfClause(this));
|
||||
}
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const Statement *stmt) {
|
||||
|
@ -294,7 +292,6 @@ private:
|
|||
// The integer set capturing the conditional guard.
|
||||
IntegerSet *condition;
|
||||
// TODO: arguments to integer set
|
||||
ArrayRef<MLValue *> conditionArgs;
|
||||
};
|
||||
} // end namespace mlir
|
||||
|
||||
|
|
|
@ -102,8 +102,9 @@ public:
|
|||
return &StmtBlock::statements;
|
||||
}
|
||||
|
||||
void print(raw_ostream &os) const;
|
||||
void dump() const;
|
||||
/// These have unconventional names to avoid derive class ambiguities.
|
||||
void printBlock(raw_ostream &os) const;
|
||||
void dumpBlock() const;
|
||||
|
||||
protected:
|
||||
StmtBlock(StmtBlockKind kind) : kind(kind) {}
|
||||
|
|
|
@ -156,15 +156,13 @@ public:
|
|||
|
||||
void walkIfStmt(IfStmt *ifStmt) {
|
||||
static_cast<SubClass *>(this)->visitIfStmt(ifStmt);
|
||||
walk(ifStmt->getThenClause()->begin(), ifStmt->getThenClause()->end());
|
||||
walk(ifStmt->getElseClause()->begin(), ifStmt->getElseClause()->end());
|
||||
walk(ifStmt->getThen()->begin(), ifStmt->getThen()->end());
|
||||
walk(ifStmt->getElse()->begin(), ifStmt->getElse()->end());
|
||||
}
|
||||
|
||||
void walkIfStmtPostOrder(IfStmt *ifStmt) {
|
||||
walkPostOrder(ifStmt->getThenClause()->begin(),
|
||||
ifStmt->getThenClause()->end());
|
||||
walkPostOrder(ifStmt->getElseClause()->begin(),
|
||||
ifStmt->getElseClause()->end());
|
||||
walkPostOrder(ifStmt->getThen()->begin(), ifStmt->getThen()->end());
|
||||
walkPostOrder(ifStmt->getElse()->begin(), ifStmt->getElse()->end());
|
||||
static_cast<SubClass *>(this)->visitIfStmt(ifStmt);
|
||||
}
|
||||
|
||||
|
|
|
@ -174,10 +174,10 @@ void ModuleState::visitCFGFunction(const CFGFunction *fn) {
|
|||
|
||||
void ModuleState::visitIfStmt(const IfStmt *ifStmt) {
|
||||
recordIntegerSetReference(ifStmt->getCondition());
|
||||
for (auto &childStmt : *ifStmt->getThenClause())
|
||||
for (auto &childStmt : *ifStmt->getThen())
|
||||
visitStatement(&childStmt);
|
||||
if (ifStmt->hasElseClause())
|
||||
for (auto &childStmt : *ifStmt->getElseClause())
|
||||
if (ifStmt->hasElse())
|
||||
for (auto &childStmt : *ifStmt->getElse())
|
||||
visitStatement(&childStmt);
|
||||
}
|
||||
|
||||
|
@ -1270,11 +1270,11 @@ void MLFunctionPrinter::print(const IfStmt *stmt) {
|
|||
os.indent(numSpaces) << "if (";
|
||||
printIntegerSetReference(stmt->getCondition());
|
||||
os << ") {\n";
|
||||
print(stmt->getThenClause());
|
||||
print(stmt->getThen());
|
||||
os.indent(numSpaces) << "}";
|
||||
if (stmt->hasElseClause()) {
|
||||
if (stmt->hasElse()) {
|
||||
os << " else {\n";
|
||||
print(stmt->getElseClause());
|
||||
print(stmt->getElse());
|
||||
os.indent(numSpaces) << "}";
|
||||
}
|
||||
}
|
||||
|
@ -1393,14 +1393,14 @@ void Statement::print(raw_ostream &os) const {
|
|||
|
||||
void Statement::dump() const { print(llvm::errs()); }
|
||||
|
||||
void StmtBlock::print(raw_ostream &os) const {
|
||||
void StmtBlock::printBlock(raw_ostream &os) const {
|
||||
MLFunction *function = findFunction();
|
||||
ModuleState state(function->getContext());
|
||||
ModulePrinter modulePrinter(os, state);
|
||||
MLFunctionPrinter(function, modulePrinter).print(this);
|
||||
}
|
||||
|
||||
void StmtBlock::dump() const { print(llvm::errs()); }
|
||||
void StmtBlock::dumpBlock() const { printBlock(llvm::errs()); }
|
||||
|
||||
void Function::print(raw_ostream &os) const {
|
||||
ModuleState state(getContext());
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include "mlir/IR/Statements.h"
|
||||
#include "mlir/IR/StmtVisitor.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
using namespace mlir;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -77,43 +78,6 @@ bool Statement::isInnermost() const {
|
|||
return nlc.numNestedLoops == 1;
|
||||
}
|
||||
|
||||
Statement *Statement::clone() const {
|
||||
switch (kind) {
|
||||
case Kind::Operation:
|
||||
return cast<OperationStmt>(this)->clone();
|
||||
case Kind::If:
|
||||
llvm_unreachable("cloning for if's not implemented yet");
|
||||
return cast<IfStmt>(this)->clone();
|
||||
case Kind::For:
|
||||
return cast<ForStmt>(this)->clone();
|
||||
}
|
||||
}
|
||||
|
||||
/// Replaces all uses of oldVal with newVal.
|
||||
// TODO(bondhugula,clattner): do this more efficiently by walking those uses of
|
||||
// oldVal that fall within this statement.
|
||||
void Statement::replaceUses(MLValue *oldVal, MLValue *newVal) {
|
||||
struct ReplaceUseWalker : public StmtWalker<ReplaceUseWalker> {
|
||||
// Value to be replaced.
|
||||
MLValue *oldVal;
|
||||
// Value to be replaced with.
|
||||
MLValue *newVal;
|
||||
|
||||
ReplaceUseWalker(MLValue *oldVal, MLValue *newVal)
|
||||
: oldVal(oldVal), newVal(newVal){};
|
||||
|
||||
void visitOperationStmt(OperationStmt *os) {
|
||||
for (auto &operand : os->getStmtOperands()) {
|
||||
if (operand.get() == oldVal)
|
||||
operand.set(newVal);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
ReplaceUseWalker ri(oldVal, newVal);
|
||||
ri.walk(this);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ilist_traits for Statement
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -193,22 +157,6 @@ OperationStmt *OperationStmt::create(Identifier name,
|
|||
return stmt;
|
||||
}
|
||||
|
||||
/// Clone an existing OperationStmt.
|
||||
OperationStmt *OperationStmt::clone() const {
|
||||
SmallVector<MLValue *, 8> operands;
|
||||
SmallVector<Type *, 8> resultTypes;
|
||||
|
||||
// TODO(clattner): switch this to iterator logic.
|
||||
// Put together operands and results.
|
||||
for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
|
||||
operands.push_back(getStmtOperand(i).get());
|
||||
|
||||
for (unsigned i = 0, e = getNumResults(); i != e; ++i)
|
||||
resultTypes.push_back(getStmtResult(i).getType());
|
||||
|
||||
return create(getName(), operands, resultTypes, getAttrs(), getContext());
|
||||
}
|
||||
|
||||
OperationStmt::OperationStmt(Identifier name, unsigned numOperands,
|
||||
unsigned numResults,
|
||||
ArrayRef<NamedAttribute> attributes,
|
||||
|
@ -256,37 +204,6 @@ ForStmt::ForStmt(AffineConstantExpr *lowerBound, AffineConstantExpr *upperBound,
|
|||
StmtBlock(StmtBlockKind::For), lowerBound(lowerBound),
|
||||
upperBound(upperBound), step(step) {}
|
||||
|
||||
ForStmt *ForStmt::clone() const {
|
||||
auto *forStmt = new ForStmt(getLowerBound(), getUpperBound(), getStep(),
|
||||
Statement::findFunction()->getContext());
|
||||
|
||||
// Pairs of <old op stmt result whose uses need to be replaced,
|
||||
// new result generated by the corresponding cloned op stmt>.
|
||||
SmallVector<std::pair<MLValue *, MLValue *>, 8> oldNewResultPairs;
|
||||
for (auto &s : getStatements()) {
|
||||
auto *cloneStmt = s.clone();
|
||||
forStmt->getStatements().push_back(cloneStmt);
|
||||
if (auto *opStmt = dyn_cast<OperationStmt>(&s)) {
|
||||
auto *cloneOpStmt = cast<OperationStmt>(cloneStmt);
|
||||
for (unsigned i = 0, e = opStmt->getNumResults(); i < e; i++) {
|
||||
oldNewResultPairs.push_back(
|
||||
std::make_pair(const_cast<StmtResult *>(&opStmt->getStmtResult(i)),
|
||||
&cloneOpStmt->getStmtResult(i)));
|
||||
}
|
||||
}
|
||||
}
|
||||
// Replace uses of old op results' with the newly created ones.
|
||||
for (unsigned i = 0, e = oldNewResultPairs.size(); i < e; i++) {
|
||||
for (auto &stmt : *forStmt) {
|
||||
stmt.replaceUses(oldNewResultPairs[i].first, oldNewResultPairs[i].second);
|
||||
}
|
||||
}
|
||||
|
||||
// Replace uses of old loop IV with the new one.
|
||||
forStmt->Statement::replaceUses(const_cast<ForStmt *>(this), forStmt);
|
||||
return forStmt;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// IfStmt
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -299,7 +216,72 @@ IfStmt::~IfStmt() {
|
|||
// allocated through MLIRContext's bump pointer allocator.
|
||||
}
|
||||
|
||||
IfStmt *IfStmt::clone() const {
|
||||
llvm_unreachable("cloning for if's not implemented yet");
|
||||
return nullptr;
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Statement Cloning
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Create a deep copy of this statement, remapping any operands that use
|
||||
/// values outside of the statement using the map that is provided (leaving
|
||||
/// them alone if no entry is present). Replaces references to cloned
|
||||
/// sub-statements to the corresponding statement that is copied, and adds
|
||||
/// those mappings to the map.
|
||||
Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap,
|
||||
MLIRContext *context) const {
|
||||
// If the specified value is in operandMap, return the remapped value.
|
||||
// Otherwise return the value itself.
|
||||
auto remapOperand = [&](const MLValue *value) -> MLValue * {
|
||||
auto it = operandMap.find(value);
|
||||
return it != operandMap.end() ? it->second : const_cast<MLValue *>(value);
|
||||
};
|
||||
|
||||
if (auto *opStmt = dyn_cast<OperationStmt>(this)) {
|
||||
SmallVector<MLValue *, 8> operands;
|
||||
operands.reserve(opStmt->getNumOperands());
|
||||
for (auto *opValue : opStmt->getOperands())
|
||||
operands.push_back(remapOperand(opValue));
|
||||
|
||||
SmallVector<Type *, 8> resultTypes;
|
||||
resultTypes.reserve(opStmt->getNumResults());
|
||||
for (auto *result : opStmt->getResults())
|
||||
resultTypes.push_back(result->getType());
|
||||
auto *newOp = OperationStmt::create(
|
||||
opStmt->getName(), operands, resultTypes, opStmt->getAttrs(), context);
|
||||
// Remember the mapping of any results.
|
||||
for (unsigned i = 0, e = opStmt->getNumResults(); i != e; ++i)
|
||||
operandMap[opStmt->getResult(i)] = newOp->getResult(i);
|
||||
return newOp;
|
||||
}
|
||||
|
||||
if (auto *forStmt = dyn_cast<ForStmt>(this)) {
|
||||
auto *newFor =
|
||||
new ForStmt(forStmt->getLowerBound(), forStmt->getUpperBound(),
|
||||
forStmt->getStep(), context);
|
||||
// Remember the induction variable mapping.
|
||||
operandMap[forStmt] = newFor;
|
||||
|
||||
// TODO: remap operands in loop bounds when they are added.
|
||||
// Recursively clone the body of the for loop.
|
||||
for (auto &subStmt : *forStmt)
|
||||
newFor->push_back(subStmt.clone(operandMap, context));
|
||||
|
||||
return newFor;
|
||||
}
|
||||
|
||||
// Otherwise, we must have an If statement.
|
||||
auto *ifStmt = cast<IfStmt>(this);
|
||||
auto *newIf = new IfStmt(ifStmt->getCondition());
|
||||
|
||||
// TODO: remap operands with remapOperand when if statements have them.
|
||||
|
||||
auto *resultThen = newIf->getThen();
|
||||
for (auto &childStmt : *ifStmt->getThen())
|
||||
resultThen->push_back(childStmt.clone(operandMap, context));
|
||||
|
||||
if (ifStmt->hasElse()) {
|
||||
auto *resultElse = newIf->createElse();
|
||||
for (auto &childStmt : *ifStmt->getElse())
|
||||
resultElse->push_back(childStmt.clone(operandMap, context));
|
||||
}
|
||||
|
||||
return newIf;
|
||||
}
|
||||
|
|
|
@ -371,10 +371,10 @@ bool MLFuncVerifier::verifyDominance() {
|
|||
|
||||
// If this is an if or for, recursively walk the block they contain.
|
||||
if (auto *ifStmt = dyn_cast<IfStmt>(&stmt)) {
|
||||
if (walkBlock(*ifStmt->getThenClause()))
|
||||
if (walkBlock(*ifStmt->getThen()))
|
||||
return true;
|
||||
|
||||
if (auto *elseClause = ifStmt->getElseClause())
|
||||
if (auto *elseClause = ifStmt->getElse())
|
||||
if (walkBlock(*elseClause))
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -2094,7 +2094,7 @@ class MLFunctionParser : public FunctionParser {
|
|||
public:
|
||||
MLFunctionParser(ParserState &state, MLFunction *function)
|
||||
: FunctionParser(state, Kind::MLFunc), function(function),
|
||||
builder(function) {}
|
||||
builder(function, function->end()) {}
|
||||
|
||||
ParseResult parseFunctionBody();
|
||||
|
||||
|
@ -2191,7 +2191,7 @@ ParseResult MLFunctionParser::parseForStmt() {
|
|||
return ParseFailure;
|
||||
|
||||
// Reset insertion point to the current block.
|
||||
builder.setInsertionPoint(forStmt->getBlock());
|
||||
builder.setInsertionPointToEnd(forStmt->getBlock());
|
||||
|
||||
// TODO: remove definition of the induction variable.
|
||||
|
||||
|
@ -2348,7 +2348,7 @@ ParseResult MLFunctionParser::parseIfStmt() {
|
|||
return ParseFailure;
|
||||
|
||||
IfStmt *ifStmt = builder.createIf(condition);
|
||||
IfClause *thenClause = ifStmt->getThenClause();
|
||||
IfClause *thenClause = ifStmt->getThen();
|
||||
|
||||
// When parsing of an if statement body fails, the IR contains
|
||||
// the if statement with the portion of the body that has been
|
||||
|
@ -2357,20 +2357,20 @@ ParseResult MLFunctionParser::parseIfStmt() {
|
|||
return ParseFailure;
|
||||
|
||||
if (consumeIf(Token::kw_else)) {
|
||||
auto *elseClause = ifStmt->createElseClause();
|
||||
auto *elseClause = ifStmt->createElse();
|
||||
if (parseElseClause(elseClause))
|
||||
return ParseFailure;
|
||||
}
|
||||
|
||||
// Reset insertion point to the current block.
|
||||
builder.setInsertionPoint(ifStmt->getBlock());
|
||||
builder.setInsertionPointToEnd(ifStmt->getBlock());
|
||||
|
||||
return ParseSuccess;
|
||||
}
|
||||
|
||||
ParseResult MLFunctionParser::parseElseClause(IfClause *elseClause) {
|
||||
if (getToken().is(Token::kw_if)) {
|
||||
builder.setInsertionPoint(elseClause);
|
||||
builder.setInsertionPointToEnd(elseClause);
|
||||
return parseIfStmt();
|
||||
}
|
||||
|
||||
|
@ -2385,7 +2385,7 @@ ParseResult MLFunctionParser::parseStatements(StmtBlock *block) {
|
|||
return builder.createOperation(state);
|
||||
};
|
||||
|
||||
builder.setInsertionPoint(block);
|
||||
builder.setInsertionPointToEnd(block);
|
||||
|
||||
while (getToken().isNot(Token::kw_return, Token::r_brace)) {
|
||||
switch (getToken().getKind()) {
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
#include "mlir/IR/Statements.h"
|
||||
#include "mlir/IR/StmtVisitor.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -77,14 +78,15 @@ void LoopUnroll::runOnMLFunction(MLFunction *f) {
|
|||
bool hasInnerLoops = walkPostOrder(forStmt->begin(), forStmt->end());
|
||||
if (!hasInnerLoops)
|
||||
loops.push_back(forStmt);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool walkIfStmtPostOrder(IfStmt *ifStmt) {
|
||||
bool hasInnerLoops = walkPostOrder(ifStmt->getThenClause()->begin(),
|
||||
ifStmt->getThenClause()->end());
|
||||
hasInnerLoops |= walkPostOrder(ifStmt->getElseClause()->begin(),
|
||||
ifStmt->getElseClause()->end());
|
||||
bool hasInnerLoops =
|
||||
walkPostOrder(ifStmt->getThen()->begin(), ifStmt->getThen()->end());
|
||||
hasInnerLoops |=
|
||||
walkPostOrder(ifStmt->getElse()->begin(), ifStmt->getElse()->end());
|
||||
return hasInnerLoops;
|
||||
}
|
||||
|
||||
|
@ -130,106 +132,35 @@ void ShortLoopUnroll::runOnMLFunction(MLFunction *f) {
|
|||
runOnForStmt(forStmt);
|
||||
}
|
||||
|
||||
/// Replace all uses of oldVal with newVal from begin to end.
|
||||
static void replaceUses(StmtBlock::iterator begin, StmtBlock::iterator end,
|
||||
MLValue *oldVal, MLValue *newVal) {
|
||||
// TODO(bondhugula,clattner): do this more efficiently by walking those uses
|
||||
// of oldVal that fall within this list of statements (instead of iterating
|
||||
// through all statements / through all operands of operations found).
|
||||
for (auto it = begin; it != end; it++) {
|
||||
it->replaceUses(oldVal, newVal);
|
||||
}
|
||||
}
|
||||
|
||||
/// Replace all uses of oldVal with newVal.
|
||||
void replaceUses(StmtBlock *block, MLValue *oldVal, MLValue *newVal) {
|
||||
// TODO(bondhugula,clattner): do this more efficiently by walking those uses
|
||||
// of oldVal that fall within this StmtBlock (instead of iterating through
|
||||
// all statements / through all operands of operations found).
|
||||
for (auto it = block->begin(); it != block->end(); it++) {
|
||||
it->replaceUses(oldVal, newVal);
|
||||
}
|
||||
}
|
||||
|
||||
/// Clone the list of stmt's from 'block' and insert into the current
|
||||
/// position of the builder.
|
||||
// TODO(bondhugula,clattner): replace this with a parameterizable clone.
|
||||
void cloneStmtListFromBlock(MLFuncBuilder *builder, const StmtBlock &block) {
|
||||
// Pairs of <old op stmt result whose uses need to be replaced,
|
||||
// new result generated by the corresponding cloned op stmt>.
|
||||
SmallVector<std::pair<MLValue *, MLValue *>, 8> oldNewResultPairs;
|
||||
|
||||
// Iterator pointing to just before 'this' (i^th) unrolled iteration.
|
||||
StmtBlock::iterator beforeUnrolledBody = --builder->getInsertionPoint();
|
||||
|
||||
for (auto &stmt : block.getStatements()) {
|
||||
auto *cloneStmt = builder->clone(stmt);
|
||||
// Whenever we have an op stmt, we'll have a new ML Value defined: replace
|
||||
// uses of the old result with this one.
|
||||
if (auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
|
||||
if (opStmt->getNumResults()) {
|
||||
auto *cloneOpStmt = cast<OperationStmt>(cloneStmt);
|
||||
for (unsigned i = 0, e = opStmt->getNumResults(); i < e; i++) {
|
||||
// Store old/new result pairs.
|
||||
// TODO(bondhugula) *only* if needed later: storing of old/new
|
||||
// results can be avoided by cloning the statement list in the
|
||||
// reverse direction (and running the IR builder in the reverse
|
||||
// (iplist.insertAfter()). That way, a newly created result can be
|
||||
// immediately propagated to all its uses.
|
||||
oldNewResultPairs.push_back(std::make_pair(
|
||||
const_cast<StmtResult *>(&opStmt->getStmtResult(i)),
|
||||
&cloneOpStmt->getStmtResult(i)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Replace uses of old op results' with the new results.
|
||||
StmtBlock::iterator startOfUnrolledBody = ++beforeUnrolledBody;
|
||||
StmtBlock::iterator endOfUnrolledBody = builder->getInsertionPoint();
|
||||
|
||||
// Replace uses of old op results' with the newly created ones.
|
||||
for (unsigned i = 0; i < oldNewResultPairs.size(); i++) {
|
||||
replaceUses(startOfUnrolledBody, endOfUnrolledBody,
|
||||
oldNewResultPairs[i].first, oldNewResultPairs[i].second);
|
||||
}
|
||||
}
|
||||
|
||||
/// Unroll this 'for stmt' / loop completely.
|
||||
/// Unroll this For loop completely.
|
||||
void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
|
||||
auto lb = forStmt->getLowerBound()->getValue();
|
||||
auto ub = forStmt->getUpperBound()->getValue();
|
||||
auto step = forStmt->getStep()->getValue();
|
||||
|
||||
// Builder to add constants need for the unrolled iterator.
|
||||
auto *mlFunc = forStmt->Statement::findFunction();
|
||||
MLFuncBuilder funcTopBuilder(mlFunc);
|
||||
funcTopBuilder.setInsertionPointAtStart(mlFunc);
|
||||
auto *mlFunc = forStmt->findFunction();
|
||||
MLFuncBuilder funcTopBuilder(&mlFunc->front());
|
||||
|
||||
// Builder to insert the unrolled bodies.
|
||||
MLFuncBuilder builder(forStmt->getBlock());
|
||||
// Set insertion point to right after where the for stmt ends.
|
||||
builder.setInsertionPoint(forStmt->getBlock(),
|
||||
++StmtBlock::iterator(forStmt));
|
||||
// Builder to insert the unrolled bodies. We insert right after the
|
||||
/// ForStmt we're unrolling.
|
||||
MLFuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt));
|
||||
|
||||
// Unroll the contents of 'forStmt'.
|
||||
for (int64_t i = lb; i <= ub; i += step) {
|
||||
MLValue *ivConst = nullptr;
|
||||
DenseMap<const MLValue *, MLValue *> operandMapping;
|
||||
|
||||
// If the induction variable is used, create a constant for this unrolled
|
||||
// value and add an operand mapping for it.
|
||||
if (!forStmt->use_empty()) {
|
||||
auto constOp = funcTopBuilder.create<ConstantAffineIntOp>(i);
|
||||
ivConst = cast<OperationStmt>(constOp->getOperation())->getResult(0);
|
||||
auto *ivConst =
|
||||
funcTopBuilder.create<ConstantAffineIntOp>(i)->getResult();
|
||||
operandMapping[forStmt] = cast<MLValue>(ivConst);
|
||||
}
|
||||
StmtBlock::iterator beforeUnrolledBody = --builder.getInsertionPoint();
|
||||
|
||||
// Clone the loop body and insert it right after the loop - the latter will
|
||||
// be erased after all unrolling has been done.
|
||||
cloneStmtListFromBlock(&builder, *forStmt);
|
||||
|
||||
// Replace unrolled loop IV with the unrolled constant.
|
||||
if (ivConst) {
|
||||
StmtBlock::iterator startOfUnrolledBody = ++beforeUnrolledBody;
|
||||
StmtBlock::iterator endOfUnrolledBody = builder.getInsertionPoint();
|
||||
replaceUses(startOfUnrolledBody, endOfUnrolledBody, forStmt, ivConst);
|
||||
// Clone the body of the loop.
|
||||
for (auto &childStmt : *forStmt) {
|
||||
(void)builder.clone(childStmt, operandMapping);
|
||||
}
|
||||
}
|
||||
// Erase the original 'for' stmt from the block.
|
||||
|
|
Loading…
Reference in New Issue