forked from OSchip/llvm-project
Refactor MLFunction to contain a StmtBlock for its body instead of inheriting
from it. This is necessary progress to squaring away the parent relationship that a StmtBlock has with its enclosing if/for/fn, and makes room for functions to have more than one block in the future. This also removes IfClause and ForStmtBody. This is step 5/n towards merging instructions and statements, NFC. PiperOrigin-RevId: 226936541
This commit is contained in:
parent
9a4060d3f5
commit
d613f5ab65
|
@ -294,6 +294,12 @@ public:
|
|||
setInsertionPoint(stmt);
|
||||
}
|
||||
|
||||
MLFuncBuilder(StmtBlock *block)
|
||||
// TODO: Eliminate findFunction from this.
|
||||
: MLFuncBuilder(block->findFunction()) {
|
||||
setInsertionPoint(block, block->end());
|
||||
}
|
||||
|
||||
MLFuncBuilder(StmtBlock *block, StmtBlock::iterator insertPoint)
|
||||
// TODO: Eliminate findFunction from this.
|
||||
: MLFuncBuilder(block->findFunction()) {
|
||||
|
@ -304,7 +310,7 @@ public:
|
|||
/// the function.
|
||||
MLFuncBuilder(MLFunction *func)
|
||||
: Builder(func->getContext()), function(func) {
|
||||
setInsertionPoint(func, func->begin());
|
||||
setInsertionPoint(func->getBody(), func->getBody()->begin());
|
||||
}
|
||||
|
||||
/// Return the function this builder is referring to.
|
||||
|
|
|
@ -36,7 +36,6 @@ template <typename ObjectType, typename ElementType> class ArgumentIterator;
|
|||
// include nested affine for loops, conditionals and operations.
|
||||
class MLFunction final
|
||||
: public Function,
|
||||
public StmtBlock,
|
||||
private llvm::TrailingObjects<MLFunction, MLFuncArgument> {
|
||||
public:
|
||||
/// Creates a new MLFunction with the specific type.
|
||||
|
@ -44,6 +43,9 @@ public:
|
|||
FunctionType type,
|
||||
ArrayRef<NamedAttribute> attrs = {});
|
||||
|
||||
StmtBlock *getBody() { return &body; }
|
||||
const StmtBlock *getBody() const { return &body; }
|
||||
|
||||
/// Destroys this statement and its subclass data.
|
||||
void destroy();
|
||||
|
||||
|
@ -98,9 +100,6 @@ public:
|
|||
static bool classof(const Function *func) {
|
||||
return func->getKind() == Function::Kind::MLFunc;
|
||||
}
|
||||
static bool classof(const StmtBlock *block) {
|
||||
return block->getStmtBlockKind() == StmtBlockKind::MLFunc;
|
||||
}
|
||||
|
||||
private:
|
||||
MLFunction(Location location, StringRef name, FunctionType type,
|
||||
|
@ -119,6 +118,8 @@ private:
|
|||
MutableArrayRef<MLFuncArgument> getArgumentsInternal() {
|
||||
return {getTrailingObjects<MLFuncArgument>(), getNumArguments()};
|
||||
}
|
||||
|
||||
StmtBlock body;
|
||||
};
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
|
|
@ -274,29 +274,6 @@ private:
|
|||
size_t numTrailingObjects(OverloadToken<unsigned>) const { return numSuccs; }
|
||||
};
|
||||
|
||||
/// A ForStmtBody represents statements contained within a ForStmt.
|
||||
class ForStmtBody : public StmtBlock {
|
||||
public:
|
||||
explicit ForStmtBody(ForStmt *stmt)
|
||||
: StmtBlock(StmtBlockKind::ForBody), forStmt(stmt) {
|
||||
assert(stmt != nullptr && "ForStmtBody must have non-null parent");
|
||||
}
|
||||
|
||||
~ForStmtBody() {}
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast
|
||||
static bool classof(const StmtBlock *block) {
|
||||
return block->getStmtBlockKind() == StmtBlockKind::ForBody;
|
||||
}
|
||||
|
||||
/// Returns the 'for' statement that contains this body.
|
||||
ForStmt *getFor() { return forStmt; }
|
||||
const ForStmt *getFor() const { return forStmt; }
|
||||
|
||||
private:
|
||||
ForStmt *forStmt;
|
||||
};
|
||||
|
||||
/// For statement represents an affine loop nest.
|
||||
class ForStmt : public Statement, public MLValue {
|
||||
public:
|
||||
|
@ -324,10 +301,10 @@ public:
|
|||
using const_operand_range = llvm::iterator_range<const_operand_iterator>;
|
||||
|
||||
/// Get the body of the ForStmt.
|
||||
ForStmtBody *getBody() { return &body; }
|
||||
StmtBlock *getBody() { return &body; }
|
||||
|
||||
/// Get the body of the ForStmt.
|
||||
const ForStmtBody *getBody() const { return &body; }
|
||||
const StmtBlock *getBody() const { return &body; }
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Bounds and step
|
||||
|
@ -455,7 +432,7 @@ public:
|
|||
|
||||
private:
|
||||
// The StmtBlock for the body.
|
||||
ForStmtBody body;
|
||||
StmtBlock body;
|
||||
|
||||
// Affine map for the lower bound.
|
||||
AffineMap lbMap;
|
||||
|
@ -525,31 +502,6 @@ private:
|
|||
friend class ForStmt;
|
||||
};
|
||||
|
||||
/// An if clause represents statements contained within a then or an else clause
|
||||
/// of an if statement.
|
||||
class IfClause : public StmtBlock {
|
||||
public:
|
||||
explicit IfClause(IfStmt *stmt)
|
||||
: StmtBlock(StmtBlockKind::IfClause), ifStmt(stmt) {
|
||||
assert(stmt != nullptr && "If clause must have non-null parent");
|
||||
}
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast
|
||||
static bool classof(const StmtBlock *block) {
|
||||
return block->getStmtBlockKind() == StmtBlockKind::IfClause;
|
||||
}
|
||||
|
||||
~IfClause() {}
|
||||
|
||||
/// Returns the if statement that contains this clause.
|
||||
const IfStmt *getIf() const { return ifStmt; }
|
||||
|
||||
IfStmt *getIf() { return ifStmt; }
|
||||
|
||||
private:
|
||||
IfStmt *ifStmt;
|
||||
};
|
||||
|
||||
/// If statement restricts execution to a subset of the loop iteration space.
|
||||
class IfStmt : public Statement {
|
||||
public:
|
||||
|
@ -561,15 +513,15 @@ public:
|
|||
// Then, else, condition.
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
IfClause *getThen() { return &thenClause; }
|
||||
const IfClause *getThen() const { return &thenClause; }
|
||||
IfClause *getElse() { return elseClause; }
|
||||
const IfClause *getElse() const { return elseClause; }
|
||||
StmtBlock *getThen() { return &thenClause; }
|
||||
const StmtBlock *getThen() const { return &thenClause; }
|
||||
StmtBlock *getElse() { return elseClause; }
|
||||
const StmtBlock *getElse() const { return elseClause; }
|
||||
bool hasElse() const { return elseClause != nullptr; }
|
||||
|
||||
IfClause *createElse() {
|
||||
StmtBlock *createElse() {
|
||||
assert(elseClause == nullptr && "already has an else clause!");
|
||||
return (elseClause = new IfClause(this));
|
||||
return (elseClause = new StmtBlock(this));
|
||||
}
|
||||
|
||||
const AffineCondition getCondition() const;
|
||||
|
@ -634,9 +586,9 @@ public:
|
|||
|
||||
private:
|
||||
// it is always present.
|
||||
IfClause thenClause;
|
||||
StmtBlock thenClause;
|
||||
// 'else' clause of the if statement. 'nullptr' if there is no else clause.
|
||||
IfClause *elseClause;
|
||||
StmtBlock *elseClause;
|
||||
|
||||
// The integer set capturing the conditional guard.
|
||||
IntegerSet set;
|
||||
|
|
|
@ -39,12 +39,8 @@ template <typename BlockType> class StmtSuccessorIterator;
|
|||
/// children of a parent statement in the ML Function.
|
||||
class StmtBlock : public IRObjectWithUseList {
|
||||
public:
|
||||
enum class StmtBlockKind {
|
||||
MLFunc, // MLFunction
|
||||
ForBody, // ForStmtBody
|
||||
IfClause // IfClause
|
||||
};
|
||||
|
||||
explicit StmtBlock(MLFunction *parent);
|
||||
explicit StmtBlock(Statement *parent);
|
||||
~StmtBlock();
|
||||
|
||||
void clear() {
|
||||
|
@ -54,7 +50,9 @@ public:
|
|||
statements.pop_back();
|
||||
}
|
||||
|
||||
StmtBlockKind getStmtBlockKind() const { return kind; }
|
||||
llvm::PointerUnion<MLFunction *, Statement *> getParent() const {
|
||||
return parent;
|
||||
}
|
||||
|
||||
/// Returns the closest surrounding statement that contains this block or
|
||||
/// nullptr if this is a top-level statement block.
|
||||
|
@ -66,7 +64,10 @@ public:
|
|||
|
||||
/// Returns the function that this statement block is part of.
|
||||
/// The function is determined by traversing the chain of parent statements.
|
||||
MLFunction *findFunction() const;
|
||||
MLFunction *findFunction();
|
||||
const MLFunction *findFunction() const {
|
||||
return const_cast<StmtBlock *>(this)->findFunction();
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Block argument management
|
||||
|
@ -224,11 +225,10 @@ public:
|
|||
void printBlock(raw_ostream &os) const;
|
||||
void dumpBlock() const;
|
||||
|
||||
protected:
|
||||
StmtBlock(StmtBlockKind kind) : kind(kind) {}
|
||||
|
||||
private:
|
||||
StmtBlockKind kind;
|
||||
/// This is the parent function/IfStmt/ForStmt that owns this block.
|
||||
llvm::PointerUnion<MLFunction *, Statement *> parent;
|
||||
|
||||
/// This is the list of statements in the block.
|
||||
StmtListType statements;
|
||||
|
||||
|
|
|
@ -132,11 +132,13 @@ public:
|
|||
// Define walkers for MLFunction and all MLFunction statement kinds.
|
||||
void walk(MLFunction *f) {
|
||||
static_cast<SubClass *>(this)->visitMLFunction(f);
|
||||
static_cast<SubClass *>(this)->walk(f->begin(), f->end());
|
||||
static_cast<SubClass *>(this)->walk(f->getBody()->begin(),
|
||||
f->getBody()->end());
|
||||
}
|
||||
|
||||
void walkPostOrder(MLFunction *f) {
|
||||
static_cast<SubClass *>(this)->walkPostOrder(f->begin(), f->end());
|
||||
static_cast<SubClass *>(this)->walkPostOrder(f->getBody()->begin(),
|
||||
f->getBody()->end());
|
||||
static_cast<SubClass *>(this)->visitMLFunction(f);
|
||||
}
|
||||
|
||||
|
|
|
@ -288,8 +288,8 @@ bool MLFuncVerifier::verifyDominance() {
|
|||
HashTable::ScopeTy blockScope(liveValues);
|
||||
|
||||
// The induction variable of a for statement is live within its body.
|
||||
if (auto *forStmtBody = dyn_cast<ForStmtBody>(&block))
|
||||
liveValues.insert(forStmtBody->getFor(), true);
|
||||
if (auto *forStmt = dyn_cast_or_null<ForStmt>(block.getContainingStmt()))
|
||||
liveValues.insert(forStmt, true);
|
||||
|
||||
for (auto &stmt : block) {
|
||||
// Verify that each of the operands are live.
|
||||
|
@ -330,16 +330,16 @@ bool MLFuncVerifier::verifyDominance() {
|
|||
};
|
||||
|
||||
// Check the whole function out.
|
||||
return walkBlock(fn);
|
||||
return walkBlock(*fn.getBody());
|
||||
}
|
||||
|
||||
bool MLFuncVerifier::verifyReturn() {
|
||||
// TODO: fold return verification in the pass that verifies all statements.
|
||||
const char missingReturnMsg[] = "ML function must end with return statement";
|
||||
if (fn.getStatements().empty())
|
||||
if (fn.getBody()->getStatements().empty())
|
||||
return failure(missingReturnMsg, fn);
|
||||
|
||||
const auto &stmt = fn.getStatements().back();
|
||||
const auto &stmt = fn.getBody()->getStatements().back();
|
||||
if (const auto *op = dyn_cast<OperationStmt>(&stmt)) {
|
||||
if (!op->isReturn())
|
||||
return failure(missingReturnMsg, fn);
|
||||
|
|
|
@ -230,7 +230,7 @@ void ModuleState::visitStatement(const Statement *stmt) {
|
|||
|
||||
void ModuleState::visitMLFunction(const MLFunction *fn) {
|
||||
visitType(fn->getType());
|
||||
for (auto &stmt : *fn) {
|
||||
for (auto &stmt : *fn->getBody()) {
|
||||
ModuleState::visitStatement(&stmt);
|
||||
}
|
||||
}
|
||||
|
@ -1390,7 +1390,7 @@ void MLFunctionPrinter::print() {
|
|||
printFunctionSignature();
|
||||
printFunctionAttributes(getFunction());
|
||||
os << " {\n";
|
||||
print(function);
|
||||
print(function->getBody());
|
||||
os << "}\n\n";
|
||||
}
|
||||
|
||||
|
@ -1649,7 +1649,7 @@ void Statement::print(raw_ostream &os) const {
|
|||
void Statement::dump() const { print(llvm::errs()); }
|
||||
|
||||
void StmtBlock::printBlock(raw_ostream &os) const {
|
||||
MLFunction *function = findFunction();
|
||||
const MLFunction *function = findFunction();
|
||||
ModuleState state(function->getContext());
|
||||
ModulePrinter modulePrinter(os, state);
|
||||
MLFunctionPrinter(function, modulePrinter).print(this);
|
||||
|
|
|
@ -474,7 +474,7 @@ void ReturnOp::print(OpAsmPrinter *p) const {
|
|||
bool ReturnOp::verify() const {
|
||||
const Function *function;
|
||||
if (auto *stmt = dyn_cast<OperationStmt>(getOperation()))
|
||||
function = cast<MLFunction>(stmt->getBlock());
|
||||
function = stmt->getBlock()->findFunction();
|
||||
else
|
||||
function = cast<Instruction>(getOperation())->getFunction();
|
||||
|
||||
|
|
|
@ -202,14 +202,13 @@ MLFunction *MLFunction::create(Location location, StringRef name,
|
|||
|
||||
MLFunction::MLFunction(Location location, StringRef name, FunctionType type,
|
||||
ArrayRef<NamedAttribute> attrs)
|
||||
: Function(Kind::MLFunc, location, name, type, attrs),
|
||||
StmtBlock(StmtBlockKind::MLFunc) {}
|
||||
: Function(Kind::MLFunc, location, name, type, attrs), body(this) {}
|
||||
|
||||
MLFunction::~MLFunction() {
|
||||
// Explicitly erase statements instead of relying of 'StmtBlock' destructor
|
||||
// since child statements need to be destroyed before function arguments
|
||||
// are destroyed.
|
||||
clear();
|
||||
getBody()->clear();
|
||||
|
||||
// Explicitly run the destructors for the function arguments.
|
||||
for (auto &arg : getArgumentsInternal())
|
||||
|
@ -222,11 +221,11 @@ void MLFunction::destroy() {
|
|||
}
|
||||
|
||||
const OperationStmt *MLFunction::getReturnStmt() const {
|
||||
return cast<OperationStmt>(&back());
|
||||
return cast<OperationStmt>(&getBody()->back());
|
||||
}
|
||||
|
||||
OperationStmt *MLFunction::getReturnStmt() {
|
||||
return cast<OperationStmt>(&back());
|
||||
return cast<OperationStmt>(&getBody()->back());
|
||||
}
|
||||
|
||||
void MLFunction::walk(std::function<void(OperationStmt *)> callback) {
|
||||
|
|
|
@ -581,7 +581,7 @@ bool OpTrait::impl::verifyIsTerminator(const Operation *op) {
|
|||
// Verify that the operation is at the end of the respective parent block.
|
||||
if (auto *stmt = dyn_cast<OperationStmt>(op)) {
|
||||
StmtBlock *block = stmt->getBlock();
|
||||
if (!block || !isa<MLFunction>(block) || &block->back() != stmt)
|
||||
if (!block || block->getContainingStmt() || &block->back() != stmt)
|
||||
return op->emitOpError("must be the last statement in the ML function");
|
||||
} else {
|
||||
const Instruction *inst = cast<Instruction>(op);
|
||||
|
|
|
@ -20,33 +20,30 @@
|
|||
#include "mlir/IR/Statements.h"
|
||||
using namespace mlir;
|
||||
|
||||
StmtBlock::StmtBlock(MLFunction *parent) : parent(parent) {}
|
||||
|
||||
StmtBlock::StmtBlock(Statement *parent) : parent(parent) {}
|
||||
|
||||
StmtBlock::~StmtBlock() {
|
||||
clear();
|
||||
|
||||
llvm::DeleteContainerPointers(arguments);
|
||||
}
|
||||
|
||||
/// Returns the closest surrounding statement that contains this block or
|
||||
/// nullptr if this is a top-level statement block.
|
||||
Statement *StmtBlock::getContainingStmt() {
|
||||
switch (kind) {
|
||||
case StmtBlockKind::MLFunc:
|
||||
return nullptr;
|
||||
case StmtBlockKind::ForBody:
|
||||
return cast<ForStmtBody>(this)->getFor();
|
||||
case StmtBlockKind::IfClause:
|
||||
return cast<IfClause>(this)->getIf();
|
||||
}
|
||||
return parent.dyn_cast<Statement *>();
|
||||
}
|
||||
|
||||
MLFunction *StmtBlock::findFunction() const {
|
||||
// FIXME: const incorrect.
|
||||
StmtBlock *block = const_cast<StmtBlock *>(this);
|
||||
|
||||
while (block->getContainingStmt()) {
|
||||
block = block->getContainingStmt()->getBlock();
|
||||
MLFunction *StmtBlock::findFunction() {
|
||||
StmtBlock *block = this;
|
||||
while (auto *stmt = block->getContainingStmt()) {
|
||||
block = stmt->getBlock();
|
||||
if (!block)
|
||||
return nullptr;
|
||||
}
|
||||
return dyn_cast<MLFunction>(block);
|
||||
return block->getParent().get<MLFunction *>();
|
||||
}
|
||||
|
||||
/// Returns 'stmt' if 'stmt' lies in this block, or otherwise finds the ancestor
|
||||
|
|
|
@ -2777,7 +2777,7 @@ class MLFunctionParser : public FunctionParser {
|
|||
public:
|
||||
MLFunctionParser(ParserState &state, MLFunction *function)
|
||||
: FunctionParser(state, Kind::MLFunc), function(function),
|
||||
builder(function, function->end()) {}
|
||||
builder(function->getBody()) {}
|
||||
|
||||
ParseResult parseFunctionBody();
|
||||
|
||||
|
@ -2796,7 +2796,7 @@ private:
|
|||
ParseResult parseBound(SmallVectorImpl<MLValue *> &operands, AffineMap &map,
|
||||
bool isLower);
|
||||
ParseResult parseIfStmt();
|
||||
ParseResult parseElseClause(IfClause *elseClause);
|
||||
ParseResult parseElseClause(StmtBlock *elseClause);
|
||||
ParseResult parseStatements(StmtBlock *block);
|
||||
ParseResult parseStmtBlock(StmtBlock *block);
|
||||
|
||||
|
@ -2812,7 +2812,7 @@ ParseResult MLFunctionParser::parseFunctionBody() {
|
|||
auto braceLoc = getToken().getLoc();
|
||||
|
||||
// Parse statements in this function.
|
||||
if (parseStmtBlock(function))
|
||||
if (parseStmtBlock(function->getBody()))
|
||||
return ParseFailure;
|
||||
|
||||
return finalizeFunction(function, braceLoc);
|
||||
|
@ -3121,7 +3121,7 @@ ParseResult MLFunctionParser::parseIfStmt() {
|
|||
IfStmt *ifStmt =
|
||||
builder.createIf(getEncodedSourceLocation(loc), operands, set);
|
||||
|
||||
IfClause *thenClause = ifStmt->getThen();
|
||||
StmtBlock *thenClause = ifStmt->getThen();
|
||||
|
||||
// When parsing of an if statement body fails, the IR contains
|
||||
// the if statement with the portion of the body that has been
|
||||
|
@ -3141,7 +3141,7 @@ ParseResult MLFunctionParser::parseIfStmt() {
|
|||
return ParseSuccess;
|
||||
}
|
||||
|
||||
ParseResult MLFunctionParser::parseElseClause(IfClause *elseClause) {
|
||||
ParseResult MLFunctionParser::parseElseClause(StmtBlock *elseClause) {
|
||||
if (getToken().is(Token::kw_if)) {
|
||||
builder.setInsertionPointToEnd(elseClause);
|
||||
return parseIfStmt();
|
||||
|
|
|
@ -490,7 +490,7 @@ CFGFunction *FunctionConverter::convert(MLFunction *mlFunc) {
|
|||
}
|
||||
|
||||
// Convert statements in order.
|
||||
for (auto &stmt : *mlFunc) {
|
||||
for (auto &stmt : *mlFunc->getBody()) {
|
||||
visit(&stmt);
|
||||
}
|
||||
|
||||
|
|
|
@ -426,7 +426,7 @@ void DmaGeneration::runOnForStmt(ForStmt *forStmt) {
|
|||
}
|
||||
|
||||
PassResult DmaGeneration::runOnMLFunction(MLFunction *f) {
|
||||
for (auto &stmt : *f) {
|
||||
for (auto &stmt : *f->getBody()) {
|
||||
if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) {
|
||||
runOnForStmt(forStmt);
|
||||
}
|
||||
|
|
|
@ -348,7 +348,7 @@ public:
|
|||
bool MemRefDependenceGraph::init(MLFunction *f) {
|
||||
unsigned id = 0;
|
||||
DenseMap<MLValue *, SetVector<unsigned>> memrefAccesses;
|
||||
for (auto &stmt : *f) {
|
||||
for (auto &stmt : *f->getBody()) {
|
||||
if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) {
|
||||
// Create graph node 'id' to represent top-level 'forStmt' and record
|
||||
// all loads and store accesses it contains.
|
||||
|
|
|
@ -230,8 +230,8 @@ static void getTileableBands(MLFunction *f,
|
|||
bands->push_back(band);
|
||||
};
|
||||
|
||||
for (auto &stmt : *f) {
|
||||
ForStmt *forStmt = dyn_cast<ForStmt>(&stmt);
|
||||
for (auto &stmt : *f->getBody()) {
|
||||
auto *forStmt = dyn_cast<ForStmt>(&stmt);
|
||||
if (!forStmt)
|
||||
continue;
|
||||
getMaximalPerfectLoopNest(forStmt);
|
||||
|
|
|
@ -92,10 +92,10 @@ PassResult LoopUnrollAndJam::runOnMLFunction(MLFunction *f) {
|
|||
// Currently, just the outermost loop from the first loop nest is
|
||||
// unroll-and-jammed by this pass. However, runOnForStmt can be called on any
|
||||
// for Stmt.
|
||||
if (!isa<ForStmt>(f->begin()))
|
||||
auto *forStmt = dyn_cast<ForStmt>(f->getBody()->begin());
|
||||
if (!forStmt)
|
||||
return success();
|
||||
|
||||
auto *forStmt = cast<ForStmt>(f->begin());
|
||||
runOnForStmt(forStmt);
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -238,7 +238,7 @@ struct LowerVectorTransfersPass
|
|||
makeFuncWiseState(MLFunction *f) const override {
|
||||
auto state = llvm::make_unique<LowerVectorTransfersState>();
|
||||
auto builder = MLFuncBuilder(f);
|
||||
builder.setInsertionPointToStart(f);
|
||||
builder.setInsertionPointToStart(f->getBody());
|
||||
state->zero = builder.create<ConstantIndexOp>(builder.getUnknownLoc(), 0);
|
||||
return state;
|
||||
}
|
||||
|
|
|
@ -177,7 +177,8 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
|
|||
cast<Instruction>(op)->moveBefore(&entryBB, entryBB.begin());
|
||||
} else {
|
||||
auto *mlFunc = cast<MLFunction>(currentFunction);
|
||||
cast<OperationStmt>(op)->moveBefore(mlFunc, mlFunc->begin());
|
||||
cast<OperationStmt>(op)->moveBefore(mlFunc->getBody(),
|
||||
mlFunc->getBody()->begin());
|
||||
}
|
||||
|
||||
continue;
|
||||
|
|
|
@ -102,7 +102,7 @@ bool mlir::promoteIfSingleIteration(ForStmt *forStmt) {
|
|||
if (!forStmt->use_empty()) {
|
||||
if (forStmt->hasConstantLowerBound()) {
|
||||
auto *mlFunc = forStmt->findFunction();
|
||||
MLFuncBuilder topBuilder(&mlFunc->front());
|
||||
MLFuncBuilder topBuilder(&mlFunc->getBody()->front());
|
||||
auto constOp = topBuilder.create<ConstantIndexOp>(
|
||||
forStmt->getLoc(), forStmt->getConstantLowerBound());
|
||||
forStmt->replaceAllUsesWith(constOp);
|
||||
|
|
Loading…
Reference in New Issue