From c8b0273f19ae182caa0f355d89661983ca93470f Mon Sep 17 00:00:00 2001 From: Tatiana Shpeisman Date: Mon, 30 Jul 2018 15:18:10 -0700 Subject: [PATCH] Implement induction variables. Pretty print induction variable operands as %i. Add support for future pretty printing of ML function arguments as %arg. Induction variables are implemented by inheriting ForStmt from MLValue. ForStmt provides APIs that make this design decision invisible to the ForStmt users. This CL in combination with cl/206253643 resolves http://b/111769060. PiperOrigin-RevId: 206655937 --- mlir/include/mlir/IR/Statements.h | 18 ++++++++--- mlir/lib/IR/AsmPrinter.cpp | 51 +++++++++++++++++++++++++------ mlir/lib/IR/Builders.cpp | 2 +- mlir/lib/IR/Statement.cpp | 11 +++++++ mlir/lib/Parser/Parser.cpp | 19 ++++++++---- mlir/test/IR/invalid.mlir | 9 ++++++ mlir/test/IR/parser.mlir | 24 ++++++++++++--- mlir/test/Transforms/unroll.mlir | 2 +- 8 files changed, 109 insertions(+), 27 deletions(-) diff --git a/mlir/include/mlir/IR/Statements.h b/mlir/include/mlir/IR/Statements.h index cc31c120c53a..f77f1d7ea3b2 100644 --- a/mlir/include/mlir/IR/Statements.h +++ b/mlir/include/mlir/IR/Statements.h @@ -187,19 +187,17 @@ private: }; /// For statement represents an affine loop nest. -class ForStmt : public Statement, public StmtBlock { +class ForStmt : public Statement, public StmtBlock, private MLValue { public: // TODO: lower and upper bounds should be affine maps with // dimension and symbol use lists. explicit ForStmt(AffineConstantExpr *lowerBound, - AffineConstantExpr *upperBound, AffineConstantExpr *step) - : Statement(Kind::For), StmtBlock(StmtBlockKind::For), - lowerBound(lowerBound), upperBound(upperBound), step(step) {} + AffineConstantExpr *upperBound, AffineConstantExpr *step, + MLIRContext *context); // Loop bounds and step are immortal objects and don't need to be deleted. ~ForStmt() {} - // TODO: represent induction variable AffineConstantExpr *getLowerBound() const { return lowerBound; } AffineConstantExpr *getUpperBound() const { return upperBound; } AffineConstantExpr *getStep() const { return step; } @@ -213,6 +211,16 @@ public: return block->getStmtBlockKind() == StmtBlockKind::For; } + // For statement represents induction variable by inheriting + // from MLValue. This design is hidden behind interfaces. + static bool classof(const SSAValue *value) { + return value->getKind() == SSAValueKind::InductionVar; + } + + /// MLValue methods + MLValue *getInductionVar() { return this; } + const MLValue *getInductionVar() const { return this; } + private: AffineConstantExpr *lowerBound; AffineConstantExpr *upperBound; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index c928b34d54a7..2e09850d19dc 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -542,22 +542,40 @@ public: protected: void numberValueID(const SSAValue *value) { assert(!valueIDs.count(value) && "Value numbered multiple times"); - valueIDs[value] = nextValueID++; + unsigned id; + switch (value->getKind()) { + case SSAValueKind::BBArgument: + case SSAValueKind::InstResult: + case SSAValueKind::StmtResult: + id = nextValueID++; + break; + case SSAValueKind::FnArgument: + id = nextFnArgumentID++; + break; + case SSAValueKind::InductionVar: + id = nextInductionVarID++; + break; + } + valueIDs[value] = id; } - void printValueID(const SSAValue *value, - bool dontPrintResultNo = false) const { + void printValueID(const SSAValue *value, bool printResultNo = true) const { int resultNo = -1; auto lookupValue = value; - // If this is a reference to the result of a multi-result instruction, print - // out the # identifier and make sure to map our lookup to the first result - // of the instruction. + // If this is a reference to the result of a multi-result instruction or + // statement, print out the # identifier and make sure to map our lookup + // to the first result of the instruction. if (auto *result = dyn_cast(value)) { if (result->getOwner()->getNumResults() != 1) { resultNo = result->getResultNumber(); lookupValue = result->getOwner()->getResult(0); } + } else if (auto *result = dyn_cast(value)) { + if (result->getOwner()->getNumResults() != 1) { + resultNo = result->getResultNumber(); + lookupValue = result->getOwner()->getResult(0); + } } auto it = valueIDs.find(lookupValue); @@ -566,8 +584,14 @@ protected: return; } - os << '%' << it->getSecond(); - if (resultNo != -1 && !dontPrintResultNo) + os << '%'; + if (isa(value)) + + os << 'i'; + else if (isa(value)) + os << "arg"; + os << it->getSecond(); + if (resultNo != -1 && printResultNo) os << '#' << resultNo; } @@ -575,12 +599,14 @@ private: /// This is the value ID for each SSA value in the current function. DenseMap valueIDs; unsigned nextValueID = 0; + unsigned nextInductionVarID = 0; + unsigned nextFnArgumentID = 0; }; } // end anonymous namespace void FunctionPrinter::printOperation(const Operation *op) { if (op->getNumResults()) { - printValueID(op->getResult(0), /*dontPrintResultNo*/ true); + printValueID(op->getResult(0), /*printResultNo=*/false); os << " = "; } @@ -874,6 +900,9 @@ void MLFunctionPrinter::numberValues() { if (stmt->getNumResults() != 0) printer->numberValueID(stmt->getResult(0)); } + void visitForStmt(ForStmt *stmt) { + printer->numberValueID(stmt->getInductionVar()); + } MLFunctionPrinter *printer; }; @@ -918,7 +947,9 @@ void MLFunctionPrinter::print(const OperationStmt *stmt) { } void MLFunctionPrinter::print(const ForStmt *stmt) { - os.indent(numSpaces) << "for x = " << *stmt->getLowerBound(); + os.indent(numSpaces) << "for "; + printOperand(stmt->getInductionVar()); + os << " = " << *stmt->getLowerBound(); os << " to " << *stmt->getUpperBound(); if (stmt->getStep()->getValue() != 1) os << " step " << *stmt->getStep(); diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index e7cb9cb5ee50..1a094d9c2008 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -161,7 +161,7 @@ ForStmt *MLFuncBuilder::createFor(AffineConstantExpr *lowerBound, AffineConstantExpr *step) { if (!step) step = getConstantExpr(1); - auto *stmt = new ForStmt(lowerBound, upperBound, step); + auto *stmt = new ForStmt(lowerBound, upperBound, step, context); block->getStatements().push_back(stmt); return stmt; } diff --git a/mlir/lib/IR/Statement.cpp b/mlir/lib/IR/Statement.cpp index 0731e546c73a..6ace8ffcda29 100644 --- a/mlir/lib/IR/Statement.cpp +++ b/mlir/lib/IR/Statement.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/MLFunction.h" #include "mlir/IR/Statements.h" #include "mlir/IR/StmtVisitor.h" +#include "mlir/IR/Types.h" using namespace mlir; //===----------------------------------------------------------------------===// @@ -190,6 +191,16 @@ OperationStmt *SSAValue::getDefiningStmt() { return nullptr; } +//===----------------------------------------------------------------------===// +// ForStmt +//===----------------------------------------------------------------------===// + +ForStmt::ForStmt(AffineConstantExpr *lowerBound, AffineConstantExpr *upperBound, + AffineConstantExpr *step, MLIRContext *context) + : Statement(Kind::For), StmtBlock(StmtBlockKind::For), + MLValue(MLValueKind::InductionVar, Type::getAffineInt(context)), + lowerBound(lowerBound), upperBound(upperBound), step(step) {} + //===----------------------------------------------------------------------===// // IfStmt //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 654db988dfc5..3bcec3777af7 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -2086,10 +2086,8 @@ ParseResult MLFunctionParser::parseForStmt() { if (getToken().isNot(Token::percent_identifier)) return emitError("expected SSA identifier for the loop variable"); - // TODO: create SSA value definition from name - StringRef name = getTokenSpelling().drop_front(); - (void)name; - + auto loc = getToken().getLoc(); + StringRef inductionVariableName = getTokenSpelling().drop_front(); consumeToken(Token::percent_identifier); if (parseToken(Token::equal, "expected =")) @@ -2116,14 +2114,20 @@ ParseResult MLFunctionParser::parseForStmt() { } // Create for statement. - ForStmt *stmt = builder.createFor(lowerBound, upperBound, step); + ForStmt *forStmt = builder.createFor(lowerBound, upperBound, step); + + // Create SSA value definition for the induction variable. + addDefinition({inductionVariableName, 0, loc}, forStmt->getInductionVar()); // If parsing of the for statement body fails, // MLIR contains for statement with those nested statements that have been // successfully parsed. - if (parseStmtBlock(static_cast(stmt))) + if (parseStmtBlock(forStmt)) return ParseFailure; + // Reset insertion point to the current block. + builder.setInsertionPoint(forStmt->getBlock()); + return ParseSuccess; } @@ -2174,6 +2178,9 @@ ParseResult MLFunctionParser::parseIfStmt() { return ParseFailure; } + // Reset insertion point to the current block. + builder.setInsertionPoint(ifStmt->getBlock()); + return ParseSuccess; } diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index 38a79f8d83a0..6d1f15d84c8c 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -281,3 +281,12 @@ mlfunc @undef() { return } +// ----- + +mlfunc @duplicate_induction_var() { + for %i = 1 to 10 { // expected-error {{previously defined here}} + for %i = 1 to 10 { // expected-error {{redefinition of SSA value 'i'}} + } + } + return +} diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 61604fb7f1d2..fe8223832449 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -116,7 +116,7 @@ mlfunc @mlfunc_with_args(%a : f16) { mlfunc @mlfunc_with_ops() { // CHECK: %0 = "foo"() : () -> i64 %a = "foo"() : ()->i64 - // CHECK: for x = 1 to 10 { + // CHECK: for %i0 = 1 to 10 { for %i = 1 to 10 { // CHECK: %1 = "doo"() : () -> f32 %b = "doo"() : ()->f32 @@ -132,18 +132,34 @@ mlfunc @mlfunc_with_ops() { // CHECK-LABEL: mlfunc @loops() { mlfunc @loops() { - // CHECK: for x = 1 to 100 step 2 { + // CHECK: for %i0 = 1 to 100 step 2 { for %i = 1 to 100 step 2 { - // CHECK: for x = 1 to 200 { + // CHECK: for %i1 = 1 to 200 { for %j = 1 to 200 { } // CHECK: } } // CHECK: } return // CHECK: return } // CHECK: } +// CHECK-LABEL: mlfunc @complex_loops() { +mlfunc @complex_loops() { + for %i1 = 1 to 100 { // CHECK: for %i0 = 1 to 100 { + for %j1 = 1 to 100 { // CHECK: for %i1 = 1 to 100 { + "foo"() : () -> () // CHECK: "foo"() : () -> () + } // CHECK: } + "boo"() : () -> () // CHECK: "boo"() : () -> () + for %j2 = 1 to 10 { // CHECK: for %i2 = 1 to 10 { + for %k2 = 1 to 10 { // CHECK: for %i3 = 1 to 10 { + "goo"() : () -> () // CHECK: "goo"() : () -> () + } // CHECK: } + } // CHECK: } + } // CHECK: } + return // CHECK: return +} // CHECK: } + // CHECK-LABEL: mlfunc @ifstmt() { mlfunc @ifstmt() { - for %i = 1 to 10 { // CHECK for x = 1 to 10 { + for %i = 1 to 10 { // CHECK for %i0 = 1 to 10 { if () { // CHECK if () { } else if () { // CHECK } else if () { } else { // CHECK } else { diff --git a/mlir/test/Transforms/unroll.mlir b/mlir/test/Transforms/unroll.mlir index 3c421424b49c..2d42e4a2c432 100644 --- a/mlir/test/Transforms/unroll.mlir +++ b/mlir/test/Transforms/unroll.mlir @@ -2,7 +2,7 @@ // CHECK-LABEL: mlfunc @loops() { mlfunc @loops() { - // CHECK: for x = 1 to 100 step 2 { + // CHECK: for %i0 = 1 to 100 step 2 { for %i = 1 to 100 step 2 { // CHECK: "custom"(){value: 1} : () -> () // CHECK-NEXT: "custom"(){value: 1} : () -> ()