forked from OSchip/llvm-project
Clean up and extend MLFuncBuilder to allow creating statements in the middle of a statement block. Rename Statement::getFunction() and StmtBlock()::getFunction() to findFunction() to make it clear that this is not a constant time getter.
Fix b/112039912 - we were recording 'i' instead of '%i' for loop induction variables causing "use of undefined SSA value" error. PiperOrigin-RevId: 206884644
This commit is contained in:
parent
5228ec3146
commit
8189a12bce
|
@ -195,9 +195,11 @@ private:
|
||||||
/// statement block.
|
/// statement block.
|
||||||
class MLFuncBuilder : public Builder {
|
class MLFuncBuilder : public Builder {
|
||||||
public:
|
public:
|
||||||
MLFuncBuilder(MLFunction *function) : Builder(function->getContext()) {}
|
/// Create ML function builder and set insertion point to the given
|
||||||
|
/// statement block, that is, given ML function, for statement or if statement
|
||||||
MLFuncBuilder(StmtBlock *block) : MLFuncBuilder(block->getFunction()) {
|
/// clause.
|
||||||
|
MLFuncBuilder(StmtBlock *block)
|
||||||
|
: Builder(block->findFunction()->getContext()) {
|
||||||
setInsertionPoint(block);
|
setInsertionPoint(block);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -209,6 +211,20 @@ public:
|
||||||
insertPoint = StmtBlock::iterator();
|
insertPoint = StmtBlock::iterator();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Set the insertion point to the specified location.
|
||||||
|
/// Unlike CFGFuncBuilder, MLFuncBuilder allows to set insertion
|
||||||
|
/// point to a different function.
|
||||||
|
void setInsertionPoint(StmtBlock *block, StmtBlock::iterator insertPoint) {
|
||||||
|
// TODO: check that insertPoint is in this rather than some other block.
|
||||||
|
this->block = block;
|
||||||
|
this->insertPoint = insertPoint;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the insertion point to the specified operation.
|
||||||
|
void setInsertionPoint(OperationStmt *stmt) {
|
||||||
|
setInsertionPoint(stmt->getBlock(), StmtBlock::iterator(stmt));
|
||||||
|
}
|
||||||
|
|
||||||
/// Set the insertion point to the end of the specified block.
|
/// Set the insertion point to the end of the specified block.
|
||||||
void setInsertionPoint(StmtBlock *block) {
|
void setInsertionPoint(StmtBlock *block) {
|
||||||
this->block = block;
|
this->block = block;
|
||||||
|
@ -230,8 +246,8 @@ public:
|
||||||
AffineConstantExpr *step = nullptr);
|
AffineConstantExpr *step = nullptr);
|
||||||
|
|
||||||
IfStmt *createIf() {
|
IfStmt *createIf() {
|
||||||
auto stmt = new IfStmt();
|
auto *stmt = new IfStmt();
|
||||||
block->getStatements().push_back(stmt);
|
block->getStatements().insert(insertPoint, stmt);
|
||||||
return stmt;
|
return stmt;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -51,8 +51,13 @@ public:
|
||||||
/// Returns the statement block that contains this statement.
|
/// Returns the statement block that contains this statement.
|
||||||
StmtBlock *getBlock() const { return block; }
|
StmtBlock *getBlock() const { return block; }
|
||||||
|
|
||||||
|
/// Returns the closest surrounding statement that contains this statement
|
||||||
|
/// or nullptr if this is a top-level statement.
|
||||||
|
Statement *getParentStmt() const;
|
||||||
|
|
||||||
/// Returns the function that this statement is part of.
|
/// Returns the function that this statement is part of.
|
||||||
MLFunction *getFunction() const;
|
/// The function is determined by traversing the chain of parent statements.
|
||||||
|
MLFunction *findFunction() const;
|
||||||
|
|
||||||
/// Returns true if there are no more loops nested under this stmt.
|
/// Returns true if there are no more loops nested under this stmt.
|
||||||
bool isInnermost() const;
|
bool isInnermost() const;
|
||||||
|
|
|
@ -45,7 +45,8 @@ public:
|
||||||
Statement *getParentStmt() const;
|
Statement *getParentStmt() const;
|
||||||
|
|
||||||
/// Returns the function that this statement block is part of.
|
/// Returns the function that this statement block is part of.
|
||||||
MLFunction *getFunction() const;
|
/// The function is determined by traversing the chain of parent statements.
|
||||||
|
MLFunction *findFunction() const;
|
||||||
|
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
// Statement list management
|
// Statement list management
|
||||||
|
|
|
@ -1065,9 +1065,10 @@ void BasicBlock::print(raw_ostream &os) const {
|
||||||
void BasicBlock::dump() const { print(llvm::errs()); }
|
void BasicBlock::dump() const { print(llvm::errs()); }
|
||||||
|
|
||||||
void Statement::print(raw_ostream &os) const {
|
void Statement::print(raw_ostream &os) const {
|
||||||
ModuleState state(getFunction()->getContext());
|
MLFunction *function = findFunction();
|
||||||
|
ModuleState state(function->getContext());
|
||||||
ModulePrinter modulePrinter(os, state);
|
ModulePrinter modulePrinter(os, state);
|
||||||
MLFunctionPrinter(getFunction(), modulePrinter).print(this);
|
MLFunctionPrinter(function, modulePrinter).print(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Statement::dump() const { print(llvm::errs()); }
|
void Statement::dump() const { print(llvm::errs()); }
|
||||||
|
|
|
@ -162,6 +162,6 @@ ForStmt *MLFuncBuilder::createFor(AffineConstantExpr *lowerBound,
|
||||||
if (!step)
|
if (!step)
|
||||||
step = getConstantExpr(1);
|
step = getConstantExpr(1);
|
||||||
auto *stmt = new ForStmt(lowerBound, upperBound, step, context);
|
auto *stmt = new ForStmt(lowerBound, upperBound, step, context);
|
||||||
block->getStatements().push_back(stmt);
|
block->getStatements().insert(insertPoint, stmt);
|
||||||
return stmt;
|
return stmt;
|
||||||
}
|
}
|
||||||
|
|
|
@ -57,8 +57,10 @@ void Statement::destroy() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
MLFunction *Statement::getFunction() const {
|
Statement *Statement::getParentStmt() const { return block->getParentStmt(); }
|
||||||
return this->getBlock()->getFunction();
|
|
||||||
|
MLFunction *Statement::findFunction() const {
|
||||||
|
return this->getBlock()->findFunction();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Statement::isInnermost() const {
|
bool Statement::isInnermost() const {
|
||||||
|
|
|
@ -35,7 +35,7 @@ Statement *StmtBlock::getParentStmt() const {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
MLFunction *StmtBlock::getFunction() const {
|
MLFunction *StmtBlock::findFunction() const {
|
||||||
StmtBlock *block = const_cast<StmtBlock *>(this);
|
StmtBlock *block = const_cast<StmtBlock *>(this);
|
||||||
|
|
||||||
while (block->getParentStmt() != nullptr)
|
while (block->getParentStmt() != nullptr)
|
||||||
|
|
|
@ -1335,14 +1335,14 @@ SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type *type) {
|
||||||
return (emitError(useInfo.loc, "reference to invalid result number"),
|
return (emitError(useInfo.loc, "reference to invalid result number"),
|
||||||
nullptr);
|
nullptr);
|
||||||
|
|
||||||
|
// Otherwise, this is a forward reference. If we are in ML function return
|
||||||
|
// an error. In CFG function, create a placeholder and remember
|
||||||
|
// that we did so.
|
||||||
if (getKind() == Kind::MLFunc)
|
if (getKind() == Kind::MLFunc)
|
||||||
return (
|
return (
|
||||||
emitError(useInfo.loc, "use of undefined SSA value " + useInfo.name),
|
emitError(useInfo.loc, "use of undefined SSA value " + useInfo.name),
|
||||||
nullptr);
|
nullptr);
|
||||||
|
|
||||||
// Otherwise, this is a forward reference. If we are in ML function return
|
|
||||||
// an error. In CFG function, create a placeholder and remember
|
|
||||||
// that we did so.
|
|
||||||
auto *result = createForwardReferencePlaceholder(useInfo.loc, type);
|
auto *result = createForwardReferencePlaceholder(useInfo.loc, type);
|
||||||
entries[useInfo.number].first = result;
|
entries[useInfo.number].first = result;
|
||||||
entries[useInfo.number].second = useInfo.loc;
|
entries[useInfo.number].second = useInfo.loc;
|
||||||
|
@ -2102,7 +2102,7 @@ ParseResult MLFunctionParser::parseForStmt() {
|
||||||
return emitError("expected SSA identifier for the loop variable");
|
return emitError("expected SSA identifier for the loop variable");
|
||||||
|
|
||||||
auto loc = getToken().getLoc();
|
auto loc = getToken().getLoc();
|
||||||
StringRef inductionVariableName = getTokenSpelling().drop_front();
|
StringRef inductionVariableName = getTokenSpelling();
|
||||||
consumeToken(Token::percent_identifier);
|
consumeToken(Token::percent_identifier);
|
||||||
|
|
||||||
if (parseToken(Token::equal, "expected ="))
|
if (parseToken(Token::equal, "expected ="))
|
||||||
|
@ -2143,6 +2143,8 @@ ParseResult MLFunctionParser::parseForStmt() {
|
||||||
// Reset insertion point to the current block.
|
// Reset insertion point to the current block.
|
||||||
builder.setInsertionPoint(forStmt->getBlock());
|
builder.setInsertionPoint(forStmt->getBlock());
|
||||||
|
|
||||||
|
// TODO: remove definition of the induction variable.
|
||||||
|
|
||||||
return ParseSuccess;
|
return ParseSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -101,9 +101,7 @@ void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
|
||||||
auto trip_count = (ub - lb + 1) / step;
|
auto trip_count = (ub - lb + 1) / step;
|
||||||
|
|
||||||
auto *block = forStmt->getBlock();
|
auto *block = forStmt->getBlock();
|
||||||
|
MLFuncBuilder builder(block);
|
||||||
MLFuncBuilder builder(forStmt->Statement::getFunction());
|
|
||||||
builder.setInsertionPoint(block);
|
|
||||||
|
|
||||||
for (int i = 0; i < trip_count; i++) {
|
for (int i = 0; i < trip_count; i++) {
|
||||||
for (auto &stmt : forStmt->getStatements()) {
|
for (auto &stmt : forStmt->getStatements()) {
|
||||||
|
|
|
@ -285,7 +285,7 @@ mlfunc @undef() {
|
||||||
|
|
||||||
mlfunc @duplicate_induction_var() {
|
mlfunc @duplicate_induction_var() {
|
||||||
for %i = 1 to 10 { // expected-error {{previously defined here}}
|
for %i = 1 to 10 { // expected-error {{previously defined here}}
|
||||||
for %i = 1 to 10 { // expected-error {{redefinition of SSA value 'i'}}
|
for %i = 1 to 10 { // expected-error {{redefinition of SSA value '%i'}}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|
|
@ -145,7 +145,8 @@ mlfunc @loops() {
|
||||||
mlfunc @complex_loops() {
|
mlfunc @complex_loops() {
|
||||||
for %i1 = 1 to 100 { // CHECK: for %i0 = 1 to 100 {
|
for %i1 = 1 to 100 { // CHECK: for %i0 = 1 to 100 {
|
||||||
for %j1 = 1 to 100 { // CHECK: for %i1 = 1 to 100 {
|
for %j1 = 1 to 100 { // CHECK: for %i1 = 1 to 100 {
|
||||||
"foo"() : () -> () // CHECK: "foo"() : () -> ()
|
// CHECK: "foo"(%i0, %i1) : (affineint, affineint) -> ()
|
||||||
|
"foo"(%i1, %j1) : (affineint,affineint) -> ()
|
||||||
} // CHECK: }
|
} // CHECK: }
|
||||||
"boo"() : () -> () // CHECK: "boo"() : () -> ()
|
"boo"() : () -> () // CHECK: "boo"() : () -> ()
|
||||||
for %j2 = 1 to 10 { // CHECK: for %i2 = 1 to 10 {
|
for %j2 = 1 to 10 { // CHECK: for %i2 = 1 to 10 {
|
||||||
|
@ -157,6 +158,7 @@ mlfunc @complex_loops() {
|
||||||
return // CHECK: return
|
return // CHECK: return
|
||||||
} // CHECK: }
|
} // CHECK: }
|
||||||
|
|
||||||
|
|
||||||
// CHECK-LABEL: mlfunc @ifstmt() {
|
// CHECK-LABEL: mlfunc @ifstmt() {
|
||||||
mlfunc @ifstmt() {
|
mlfunc @ifstmt() {
|
||||||
for %i = 1 to 10 { // CHECK for %i0 = 1 to 10 {
|
for %i = 1 to 10 { // CHECK for %i0 = 1 to 10 {
|
||||||
|
|
Loading…
Reference in New Issue