Implement induction variables. Pretty print induction variable operands as %i<ssa value number>. Add support for future pretty printing of ML function arguments as %arg<ssa value number>.

Induction variables are implemented by inheriting ForStmt from MLValue. ForStmt provides APIs that make this design decision invisible to the ForStmt users.

This CL in combination with cl/206253643 resolves  http://b/111769060.

PiperOrigin-RevId: 206655937
This commit is contained in:
Tatiana Shpeisman 2018-07-30 15:18:10 -07:00 committed by jpienaar
parent fe7356c43b
commit c8b0273f19
8 changed files with 109 additions and 27 deletions

View File

@ -187,19 +187,17 @@ private:
};
/// For statement represents an affine loop nest.
class ForStmt : public Statement, public StmtBlock {
class ForStmt : public Statement, public StmtBlock, private MLValue {
public:
// TODO: lower and upper bounds should be affine maps with
// dimension and symbol use lists.
explicit ForStmt(AffineConstantExpr *lowerBound,
AffineConstantExpr *upperBound, AffineConstantExpr *step)
: Statement(Kind::For), StmtBlock(StmtBlockKind::For),
lowerBound(lowerBound), upperBound(upperBound), step(step) {}
AffineConstantExpr *upperBound, AffineConstantExpr *step,
MLIRContext *context);
// Loop bounds and step are immortal objects and don't need to be deleted.
~ForStmt() {}
// TODO: represent induction variable
AffineConstantExpr *getLowerBound() const { return lowerBound; }
AffineConstantExpr *getUpperBound() const { return upperBound; }
AffineConstantExpr *getStep() const { return step; }
@ -213,6 +211,16 @@ public:
return block->getStmtBlockKind() == StmtBlockKind::For;
}
// For statement represents induction variable by inheriting
// from MLValue. This design is hidden behind interfaces.
static bool classof(const SSAValue *value) {
return value->getKind() == SSAValueKind::InductionVar;
}
/// MLValue methods
MLValue *getInductionVar() { return this; }
const MLValue *getInductionVar() const { return this; }
private:
AffineConstantExpr *lowerBound;
AffineConstantExpr *upperBound;

View File

@ -542,22 +542,40 @@ public:
protected:
void numberValueID(const SSAValue *value) {
assert(!valueIDs.count(value) && "Value numbered multiple times");
valueIDs[value] = nextValueID++;
unsigned id;
switch (value->getKind()) {
case SSAValueKind::BBArgument:
case SSAValueKind::InstResult:
case SSAValueKind::StmtResult:
id = nextValueID++;
break;
case SSAValueKind::FnArgument:
id = nextFnArgumentID++;
break;
case SSAValueKind::InductionVar:
id = nextInductionVarID++;
break;
}
valueIDs[value] = id;
}
void printValueID(const SSAValue *value,
bool dontPrintResultNo = false) const {
void printValueID(const SSAValue *value, bool printResultNo = true) const {
int resultNo = -1;
auto lookupValue = value;
// If this is a reference to the result of a multi-result instruction, print
// out the # identifier and make sure to map our lookup to the first result
// of the instruction.
// If this is a reference to the result of a multi-result instruction or
// statement, print out the # identifier and make sure to map our lookup
// to the first result of the instruction.
if (auto *result = dyn_cast<InstResult>(value)) {
if (result->getOwner()->getNumResults() != 1) {
resultNo = result->getResultNumber();
lookupValue = result->getOwner()->getResult(0);
}
} else if (auto *result = dyn_cast<StmtResult>(value)) {
if (result->getOwner()->getNumResults() != 1) {
resultNo = result->getResultNumber();
lookupValue = result->getOwner()->getResult(0);
}
}
auto it = valueIDs.find(lookupValue);
@ -566,8 +584,14 @@ protected:
return;
}
os << '%' << it->getSecond();
if (resultNo != -1 && !dontPrintResultNo)
os << '%';
if (isa<ForStmt>(value))
os << 'i';
else if (isa<FnArgument>(value))
os << "arg";
os << it->getSecond();
if (resultNo != -1 && printResultNo)
os << '#' << resultNo;
}
@ -575,12 +599,14 @@ private:
/// This is the value ID for each SSA value in the current function.
DenseMap<const SSAValue *, unsigned> valueIDs;
unsigned nextValueID = 0;
unsigned nextInductionVarID = 0;
unsigned nextFnArgumentID = 0;
};
} // end anonymous namespace
void FunctionPrinter::printOperation(const Operation *op) {
if (op->getNumResults()) {
printValueID(op->getResult(0), /*dontPrintResultNo*/ true);
printValueID(op->getResult(0), /*printResultNo=*/false);
os << " = ";
}
@ -874,6 +900,9 @@ void MLFunctionPrinter::numberValues() {
if (stmt->getNumResults() != 0)
printer->numberValueID(stmt->getResult(0));
}
void visitForStmt(ForStmt *stmt) {
printer->numberValueID(stmt->getInductionVar());
}
MLFunctionPrinter *printer;
};
@ -918,7 +947,9 @@ void MLFunctionPrinter::print(const OperationStmt *stmt) {
}
void MLFunctionPrinter::print(const ForStmt *stmt) {
os.indent(numSpaces) << "for x = " << *stmt->getLowerBound();
os.indent(numSpaces) << "for ";
printOperand(stmt->getInductionVar());
os << " = " << *stmt->getLowerBound();
os << " to " << *stmt->getUpperBound();
if (stmt->getStep()->getValue() != 1)
os << " step " << *stmt->getStep();

View File

@ -161,7 +161,7 @@ ForStmt *MLFuncBuilder::createFor(AffineConstantExpr *lowerBound,
AffineConstantExpr *step) {
if (!step)
step = getConstantExpr(1);
auto *stmt = new ForStmt(lowerBound, upperBound, step);
auto *stmt = new ForStmt(lowerBound, upperBound, step, context);
block->getStatements().push_back(stmt);
return stmt;
}

View File

@ -18,6 +18,7 @@
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/Types.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
@ -190,6 +191,16 @@ OperationStmt *SSAValue::getDefiningStmt() {
return nullptr;
}
//===----------------------------------------------------------------------===//
// ForStmt
//===----------------------------------------------------------------------===//
ForStmt::ForStmt(AffineConstantExpr *lowerBound, AffineConstantExpr *upperBound,
AffineConstantExpr *step, MLIRContext *context)
: Statement(Kind::For), StmtBlock(StmtBlockKind::For),
MLValue(MLValueKind::InductionVar, Type::getAffineInt(context)),
lowerBound(lowerBound), upperBound(upperBound), step(step) {}
//===----------------------------------------------------------------------===//
// IfStmt
//===----------------------------------------------------------------------===//

View File

@ -2086,10 +2086,8 @@ ParseResult MLFunctionParser::parseForStmt() {
if (getToken().isNot(Token::percent_identifier))
return emitError("expected SSA identifier for the loop variable");
// TODO: create SSA value definition from name
StringRef name = getTokenSpelling().drop_front();
(void)name;
auto loc = getToken().getLoc();
StringRef inductionVariableName = getTokenSpelling().drop_front();
consumeToken(Token::percent_identifier);
if (parseToken(Token::equal, "expected ="))
@ -2116,14 +2114,20 @@ ParseResult MLFunctionParser::parseForStmt() {
}
// Create for statement.
ForStmt *stmt = builder.createFor(lowerBound, upperBound, step);
ForStmt *forStmt = builder.createFor(lowerBound, upperBound, step);
// Create SSA value definition for the induction variable.
addDefinition({inductionVariableName, 0, loc}, forStmt->getInductionVar());
// If parsing of the for statement body fails,
// MLIR contains for statement with those nested statements that have been
// successfully parsed.
if (parseStmtBlock(static_cast<StmtBlock *>(stmt)))
if (parseStmtBlock(forStmt))
return ParseFailure;
// Reset insertion point to the current block.
builder.setInsertionPoint(forStmt->getBlock());
return ParseSuccess;
}
@ -2174,6 +2178,9 @@ ParseResult MLFunctionParser::parseIfStmt() {
return ParseFailure;
}
// Reset insertion point to the current block.
builder.setInsertionPoint(ifStmt->getBlock());
return ParseSuccess;
}

View File

@ -281,3 +281,12 @@ mlfunc @undef() {
return
}
// -----
mlfunc @duplicate_induction_var() {
for %i = 1 to 10 { // expected-error {{previously defined here}}
for %i = 1 to 10 { // expected-error {{redefinition of SSA value 'i'}}
}
}
return
}

View File

@ -116,7 +116,7 @@ mlfunc @mlfunc_with_args(%a : f16) {
mlfunc @mlfunc_with_ops() {
// CHECK: %0 = "foo"() : () -> i64
%a = "foo"() : ()->i64
// CHECK: for x = 1 to 10 {
// CHECK: for %i0 = 1 to 10 {
for %i = 1 to 10 {
// CHECK: %1 = "doo"() : () -> f32
%b = "doo"() : ()->f32
@ -132,18 +132,34 @@ mlfunc @mlfunc_with_ops() {
// CHECK-LABEL: mlfunc @loops() {
mlfunc @loops() {
// CHECK: for x = 1 to 100 step 2 {
// CHECK: for %i0 = 1 to 100 step 2 {
for %i = 1 to 100 step 2 {
// CHECK: for x = 1 to 200 {
// CHECK: for %i1 = 1 to 200 {
for %j = 1 to 200 {
} // CHECK: }
} // CHECK: }
return // CHECK: return
} // CHECK: }
// CHECK-LABEL: mlfunc @complex_loops() {
mlfunc @complex_loops() {
for %i1 = 1 to 100 { // CHECK: for %i0 = 1 to 100 {
for %j1 = 1 to 100 { // CHECK: for %i1 = 1 to 100 {
"foo"() : () -> () // CHECK: "foo"() : () -> ()
} // CHECK: }
"boo"() : () -> () // CHECK: "boo"() : () -> ()
for %j2 = 1 to 10 { // CHECK: for %i2 = 1 to 10 {
for %k2 = 1 to 10 { // CHECK: for %i3 = 1 to 10 {
"goo"() : () -> () // CHECK: "goo"() : () -> ()
} // CHECK: }
} // CHECK: }
} // CHECK: }
return // CHECK: return
} // CHECK: }
// CHECK-LABEL: mlfunc @ifstmt() {
mlfunc @ifstmt() {
for %i = 1 to 10 { // CHECK for x = 1 to 10 {
for %i = 1 to 10 { // CHECK for %i0 = 1 to 10 {
if () { // CHECK if () {
} else if () { // CHECK } else if () {
} else { // CHECK } else {

View File

@ -2,7 +2,7 @@
// CHECK-LABEL: mlfunc @loops() {
mlfunc @loops() {
// CHECK: for x = 1 to 100 step 2 {
// CHECK: for %i0 = 1 to 100 step 2 {
for %i = 1 to 100 step 2 {
// CHECK: "custom"(){value: 1} : () -> ()
// CHECK-NEXT: "custom"(){value: 1} : () -> ()