diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index b84da8324055..c6a09f16fbcf 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -195,9 +195,11 @@ private: /// statement block. class MLFuncBuilder : public Builder { public: - MLFuncBuilder(MLFunction *function) : Builder(function->getContext()) {} - - MLFuncBuilder(StmtBlock *block) : MLFuncBuilder(block->getFunction()) { + /// Create ML function builder and set insertion point to the given + /// statement block, that is, given ML function, for statement or if statement + /// clause. + MLFuncBuilder(StmtBlock *block) + : Builder(block->findFunction()->getContext()) { setInsertionPoint(block); } @@ -209,6 +211,20 @@ public: insertPoint = StmtBlock::iterator(); } + /// Set the insertion point to the specified location. + /// Unlike CFGFuncBuilder, MLFuncBuilder allows to set insertion + /// point to a different function. + void setInsertionPoint(StmtBlock *block, StmtBlock::iterator insertPoint) { + // TODO: check that insertPoint is in this rather than some other block. + this->block = block; + this->insertPoint = insertPoint; + } + + /// Set the insertion point to the specified operation. + void setInsertionPoint(OperationStmt *stmt) { + setInsertionPoint(stmt->getBlock(), StmtBlock::iterator(stmt)); + } + /// Set the insertion point to the end of the specified block. void setInsertionPoint(StmtBlock *block) { this->block = block; @@ -230,8 +246,8 @@ public: AffineConstantExpr *step = nullptr); IfStmt *createIf() { - auto stmt = new IfStmt(); - block->getStatements().push_back(stmt); + auto *stmt = new IfStmt(); + block->getStatements().insert(insertPoint, stmt); return stmt; } diff --git a/mlir/include/mlir/IR/Statement.h b/mlir/include/mlir/IR/Statement.h index d2bfdb29c8f1..f43f0b6bd302 100644 --- a/mlir/include/mlir/IR/Statement.h +++ b/mlir/include/mlir/IR/Statement.h @@ -51,8 +51,13 @@ public: /// Returns the statement block that contains this statement. StmtBlock *getBlock() const { return block; } + /// Returns the closest surrounding statement that contains this statement + /// or nullptr if this is a top-level statement. + Statement *getParentStmt() const; + /// Returns the function that this statement is part of. - MLFunction *getFunction() const; + /// The function is determined by traversing the chain of parent statements. + MLFunction *findFunction() const; /// Returns true if there are no more loops nested under this stmt. bool isInnermost() const; diff --git a/mlir/include/mlir/IR/StmtBlock.h b/mlir/include/mlir/IR/StmtBlock.h index 8f4cc2046e1a..8b03bf120318 100644 --- a/mlir/include/mlir/IR/StmtBlock.h +++ b/mlir/include/mlir/IR/StmtBlock.h @@ -45,7 +45,8 @@ public: Statement *getParentStmt() const; /// Returns the function that this statement block is part of. - MLFunction *getFunction() const; + /// The function is determined by traversing the chain of parent statements. + MLFunction *findFunction() const; //===--------------------------------------------------------------------===// // Statement list management diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 99fb1df3a10f..275f8baf2ecf 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1065,9 +1065,10 @@ void BasicBlock::print(raw_ostream &os) const { void BasicBlock::dump() const { print(llvm::errs()); } void Statement::print(raw_ostream &os) const { - ModuleState state(getFunction()->getContext()); + MLFunction *function = findFunction(); + ModuleState state(function->getContext()); ModulePrinter modulePrinter(os, state); - MLFunctionPrinter(getFunction(), modulePrinter).print(this); + MLFunctionPrinter(function, modulePrinter).print(this); } void Statement::dump() const { print(llvm::errs()); } diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 1a094d9c2008..9ba75e49fa22 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -162,6 +162,6 @@ ForStmt *MLFuncBuilder::createFor(AffineConstantExpr *lowerBound, if (!step) step = getConstantExpr(1); auto *stmt = new ForStmt(lowerBound, upperBound, step, context); - block->getStatements().push_back(stmt); + block->getStatements().insert(insertPoint, stmt); return stmt; } diff --git a/mlir/lib/IR/Statement.cpp b/mlir/lib/IR/Statement.cpp index 3ac481f6d7bb..b55b5976572e 100644 --- a/mlir/lib/IR/Statement.cpp +++ b/mlir/lib/IR/Statement.cpp @@ -57,8 +57,10 @@ void Statement::destroy() { } } -MLFunction *Statement::getFunction() const { - return this->getBlock()->getFunction(); +Statement *Statement::getParentStmt() const { return block->getParentStmt(); } + +MLFunction *Statement::findFunction() const { + return this->getBlock()->findFunction(); } bool Statement::isInnermost() const { diff --git a/mlir/lib/IR/StmtBlock.cpp b/mlir/lib/IR/StmtBlock.cpp index 16ddb370a183..21b870f75722 100644 --- a/mlir/lib/IR/StmtBlock.cpp +++ b/mlir/lib/IR/StmtBlock.cpp @@ -35,7 +35,7 @@ Statement *StmtBlock::getParentStmt() const { } } -MLFunction *StmtBlock::getFunction() const { +MLFunction *StmtBlock::findFunction() const { StmtBlock *block = const_cast(this); while (block->getParentStmt() != nullptr) diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index aa08cb6a98f0..0f4a3a53cfbc 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -1335,14 +1335,14 @@ SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type *type) { return (emitError(useInfo.loc, "reference to invalid result number"), nullptr); + // Otherwise, this is a forward reference. If we are in ML function return + // an error. In CFG function, create a placeholder and remember + // that we did so. if (getKind() == Kind::MLFunc) return ( emitError(useInfo.loc, "use of undefined SSA value " + useInfo.name), nullptr); - // Otherwise, this is a forward reference. If we are in ML function return - // an error. In CFG function, create a placeholder and remember - // that we did so. auto *result = createForwardReferencePlaceholder(useInfo.loc, type); entries[useInfo.number].first = result; entries[useInfo.number].second = useInfo.loc; @@ -2102,7 +2102,7 @@ ParseResult MLFunctionParser::parseForStmt() { return emitError("expected SSA identifier for the loop variable"); auto loc = getToken().getLoc(); - StringRef inductionVariableName = getTokenSpelling().drop_front(); + StringRef inductionVariableName = getTokenSpelling(); consumeToken(Token::percent_identifier); if (parseToken(Token::equal, "expected =")) @@ -2143,6 +2143,8 @@ ParseResult MLFunctionParser::parseForStmt() { // Reset insertion point to the current block. builder.setInsertionPoint(forStmt->getBlock()); + // TODO: remove definition of the induction variable. + return ParseSuccess; } diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 489a98f8468b..160a463eb78c 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -101,9 +101,7 @@ void LoopUnroll::runOnForStmt(ForStmt *forStmt) { auto trip_count = (ub - lb + 1) / step; auto *block = forStmt->getBlock(); - - MLFuncBuilder builder(forStmt->Statement::getFunction()); - builder.setInsertionPoint(block); + MLFuncBuilder builder(block); for (int i = 0; i < trip_count; i++) { for (auto &stmt : forStmt->getStatements()) { diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index 6d1f15d84c8c..bc88db9d45e2 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -285,7 +285,7 @@ mlfunc @undef() { mlfunc @duplicate_induction_var() { for %i = 1 to 10 { // expected-error {{previously defined here}} - for %i = 1 to 10 { // expected-error {{redefinition of SSA value 'i'}} + for %i = 1 to 10 { // expected-error {{redefinition of SSA value '%i'}} } } return diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index ce101e56ce72..0ed0938347ea 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -145,7 +145,8 @@ mlfunc @loops() { mlfunc @complex_loops() { for %i1 = 1 to 100 { // CHECK: for %i0 = 1 to 100 { for %j1 = 1 to 100 { // CHECK: for %i1 = 1 to 100 { - "foo"() : () -> () // CHECK: "foo"() : () -> () + // CHECK: "foo"(%i0, %i1) : (affineint, affineint) -> () + "foo"(%i1, %j1) : (affineint,affineint) -> () } // CHECK: } "boo"() : () -> () // CHECK: "boo"() : () -> () for %j2 = 1 to 10 { // CHECK: for %i2 = 1 to 10 { @@ -157,6 +158,7 @@ mlfunc @complex_loops() { return // CHECK: return } // CHECK: } + // CHECK-LABEL: mlfunc @ifstmt() { mlfunc @ifstmt() { for %i = 1 to 10 { // CHECK for %i0 = 1 to 10 {