forked from OSchip/llvm-project
Implement induction variables. Pretty print induction variable operands as %i<ssa value number>. Add support for future pretty printing of ML function arguments as %arg<ssa value number>.
Induction variables are implemented by inheriting ForStmt from MLValue. ForStmt provides APIs that make this design decision invisible to the ForStmt users. This CL in combination with cl/206253643 resolves http://b/111769060. PiperOrigin-RevId: 206655937
This commit is contained in:
parent
fe7356c43b
commit
c8b0273f19
|
@ -187,19 +187,17 @@ private:
|
|||
};
|
||||
|
||||
/// For statement represents an affine loop nest.
|
||||
class ForStmt : public Statement, public StmtBlock {
|
||||
class ForStmt : public Statement, public StmtBlock, private MLValue {
|
||||
public:
|
||||
// TODO: lower and upper bounds should be affine maps with
|
||||
// dimension and symbol use lists.
|
||||
explicit ForStmt(AffineConstantExpr *lowerBound,
|
||||
AffineConstantExpr *upperBound, AffineConstantExpr *step)
|
||||
: Statement(Kind::For), StmtBlock(StmtBlockKind::For),
|
||||
lowerBound(lowerBound), upperBound(upperBound), step(step) {}
|
||||
AffineConstantExpr *upperBound, AffineConstantExpr *step,
|
||||
MLIRContext *context);
|
||||
|
||||
// Loop bounds and step are immortal objects and don't need to be deleted.
|
||||
~ForStmt() {}
|
||||
|
||||
// TODO: represent induction variable
|
||||
AffineConstantExpr *getLowerBound() const { return lowerBound; }
|
||||
AffineConstantExpr *getUpperBound() const { return upperBound; }
|
||||
AffineConstantExpr *getStep() const { return step; }
|
||||
|
@ -213,6 +211,16 @@ public:
|
|||
return block->getStmtBlockKind() == StmtBlockKind::For;
|
||||
}
|
||||
|
||||
// For statement represents induction variable by inheriting
|
||||
// from MLValue. This design is hidden behind interfaces.
|
||||
static bool classof(const SSAValue *value) {
|
||||
return value->getKind() == SSAValueKind::InductionVar;
|
||||
}
|
||||
|
||||
/// MLValue methods
|
||||
MLValue *getInductionVar() { return this; }
|
||||
const MLValue *getInductionVar() const { return this; }
|
||||
|
||||
private:
|
||||
AffineConstantExpr *lowerBound;
|
||||
AffineConstantExpr *upperBound;
|
||||
|
|
|
@ -542,22 +542,40 @@ public:
|
|||
protected:
|
||||
void numberValueID(const SSAValue *value) {
|
||||
assert(!valueIDs.count(value) && "Value numbered multiple times");
|
||||
valueIDs[value] = nextValueID++;
|
||||
unsigned id;
|
||||
switch (value->getKind()) {
|
||||
case SSAValueKind::BBArgument:
|
||||
case SSAValueKind::InstResult:
|
||||
case SSAValueKind::StmtResult:
|
||||
id = nextValueID++;
|
||||
break;
|
||||
case SSAValueKind::FnArgument:
|
||||
id = nextFnArgumentID++;
|
||||
break;
|
||||
case SSAValueKind::InductionVar:
|
||||
id = nextInductionVarID++;
|
||||
break;
|
||||
}
|
||||
valueIDs[value] = id;
|
||||
}
|
||||
|
||||
void printValueID(const SSAValue *value,
|
||||
bool dontPrintResultNo = false) const {
|
||||
void printValueID(const SSAValue *value, bool printResultNo = true) const {
|
||||
int resultNo = -1;
|
||||
auto lookupValue = value;
|
||||
|
||||
// If this is a reference to the result of a multi-result instruction, print
|
||||
// out the # identifier and make sure to map our lookup to the first result
|
||||
// of the instruction.
|
||||
// If this is a reference to the result of a multi-result instruction or
|
||||
// statement, print out the # identifier and make sure to map our lookup
|
||||
// to the first result of the instruction.
|
||||
if (auto *result = dyn_cast<InstResult>(value)) {
|
||||
if (result->getOwner()->getNumResults() != 1) {
|
||||
resultNo = result->getResultNumber();
|
||||
lookupValue = result->getOwner()->getResult(0);
|
||||
}
|
||||
} else if (auto *result = dyn_cast<StmtResult>(value)) {
|
||||
if (result->getOwner()->getNumResults() != 1) {
|
||||
resultNo = result->getResultNumber();
|
||||
lookupValue = result->getOwner()->getResult(0);
|
||||
}
|
||||
}
|
||||
|
||||
auto it = valueIDs.find(lookupValue);
|
||||
|
@ -566,8 +584,14 @@ protected:
|
|||
return;
|
||||
}
|
||||
|
||||
os << '%' << it->getSecond();
|
||||
if (resultNo != -1 && !dontPrintResultNo)
|
||||
os << '%';
|
||||
if (isa<ForStmt>(value))
|
||||
|
||||
os << 'i';
|
||||
else if (isa<FnArgument>(value))
|
||||
os << "arg";
|
||||
os << it->getSecond();
|
||||
if (resultNo != -1 && printResultNo)
|
||||
os << '#' << resultNo;
|
||||
}
|
||||
|
||||
|
@ -575,12 +599,14 @@ private:
|
|||
/// This is the value ID for each SSA value in the current function.
|
||||
DenseMap<const SSAValue *, unsigned> valueIDs;
|
||||
unsigned nextValueID = 0;
|
||||
unsigned nextInductionVarID = 0;
|
||||
unsigned nextFnArgumentID = 0;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
void FunctionPrinter::printOperation(const Operation *op) {
|
||||
if (op->getNumResults()) {
|
||||
printValueID(op->getResult(0), /*dontPrintResultNo*/ true);
|
||||
printValueID(op->getResult(0), /*printResultNo=*/false);
|
||||
os << " = ";
|
||||
}
|
||||
|
||||
|
@ -874,6 +900,9 @@ void MLFunctionPrinter::numberValues() {
|
|||
if (stmt->getNumResults() != 0)
|
||||
printer->numberValueID(stmt->getResult(0));
|
||||
}
|
||||
void visitForStmt(ForStmt *stmt) {
|
||||
printer->numberValueID(stmt->getInductionVar());
|
||||
}
|
||||
MLFunctionPrinter *printer;
|
||||
};
|
||||
|
||||
|
@ -918,7 +947,9 @@ void MLFunctionPrinter::print(const OperationStmt *stmt) {
|
|||
}
|
||||
|
||||
void MLFunctionPrinter::print(const ForStmt *stmt) {
|
||||
os.indent(numSpaces) << "for x = " << *stmt->getLowerBound();
|
||||
os.indent(numSpaces) << "for ";
|
||||
printOperand(stmt->getInductionVar());
|
||||
os << " = " << *stmt->getLowerBound();
|
||||
os << " to " << *stmt->getUpperBound();
|
||||
if (stmt->getStep()->getValue() != 1)
|
||||
os << " step " << *stmt->getStep();
|
||||
|
|
|
@ -161,7 +161,7 @@ ForStmt *MLFuncBuilder::createFor(AffineConstantExpr *lowerBound,
|
|||
AffineConstantExpr *step) {
|
||||
if (!step)
|
||||
step = getConstantExpr(1);
|
||||
auto *stmt = new ForStmt(lowerBound, upperBound, step);
|
||||
auto *stmt = new ForStmt(lowerBound, upperBound, step, context);
|
||||
block->getStatements().push_back(stmt);
|
||||
return stmt;
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "mlir/IR/MLFunction.h"
|
||||
#include "mlir/IR/Statements.h"
|
||||
#include "mlir/IR/StmtVisitor.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
using namespace mlir;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -190,6 +191,16 @@ OperationStmt *SSAValue::getDefiningStmt() {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ForStmt
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ForStmt::ForStmt(AffineConstantExpr *lowerBound, AffineConstantExpr *upperBound,
|
||||
AffineConstantExpr *step, MLIRContext *context)
|
||||
: Statement(Kind::For), StmtBlock(StmtBlockKind::For),
|
||||
MLValue(MLValueKind::InductionVar, Type::getAffineInt(context)),
|
||||
lowerBound(lowerBound), upperBound(upperBound), step(step) {}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// IfStmt
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -2086,10 +2086,8 @@ ParseResult MLFunctionParser::parseForStmt() {
|
|||
if (getToken().isNot(Token::percent_identifier))
|
||||
return emitError("expected SSA identifier for the loop variable");
|
||||
|
||||
// TODO: create SSA value definition from name
|
||||
StringRef name = getTokenSpelling().drop_front();
|
||||
(void)name;
|
||||
|
||||
auto loc = getToken().getLoc();
|
||||
StringRef inductionVariableName = getTokenSpelling().drop_front();
|
||||
consumeToken(Token::percent_identifier);
|
||||
|
||||
if (parseToken(Token::equal, "expected ="))
|
||||
|
@ -2116,14 +2114,20 @@ ParseResult MLFunctionParser::parseForStmt() {
|
|||
}
|
||||
|
||||
// Create for statement.
|
||||
ForStmt *stmt = builder.createFor(lowerBound, upperBound, step);
|
||||
ForStmt *forStmt = builder.createFor(lowerBound, upperBound, step);
|
||||
|
||||
// Create SSA value definition for the induction variable.
|
||||
addDefinition({inductionVariableName, 0, loc}, forStmt->getInductionVar());
|
||||
|
||||
// If parsing of the for statement body fails,
|
||||
// MLIR contains for statement with those nested statements that have been
|
||||
// successfully parsed.
|
||||
if (parseStmtBlock(static_cast<StmtBlock *>(stmt)))
|
||||
if (parseStmtBlock(forStmt))
|
||||
return ParseFailure;
|
||||
|
||||
// Reset insertion point to the current block.
|
||||
builder.setInsertionPoint(forStmt->getBlock());
|
||||
|
||||
return ParseSuccess;
|
||||
}
|
||||
|
||||
|
@ -2174,6 +2178,9 @@ ParseResult MLFunctionParser::parseIfStmt() {
|
|||
return ParseFailure;
|
||||
}
|
||||
|
||||
// Reset insertion point to the current block.
|
||||
builder.setInsertionPoint(ifStmt->getBlock());
|
||||
|
||||
return ParseSuccess;
|
||||
}
|
||||
|
||||
|
|
|
@ -281,3 +281,12 @@ mlfunc @undef() {
|
|||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
mlfunc @duplicate_induction_var() {
|
||||
for %i = 1 to 10 { // expected-error {{previously defined here}}
|
||||
for %i = 1 to 10 { // expected-error {{redefinition of SSA value 'i'}}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -116,7 +116,7 @@ mlfunc @mlfunc_with_args(%a : f16) {
|
|||
mlfunc @mlfunc_with_ops() {
|
||||
// CHECK: %0 = "foo"() : () -> i64
|
||||
%a = "foo"() : ()->i64
|
||||
// CHECK: for x = 1 to 10 {
|
||||
// CHECK: for %i0 = 1 to 10 {
|
||||
for %i = 1 to 10 {
|
||||
// CHECK: %1 = "doo"() : () -> f32
|
||||
%b = "doo"() : ()->f32
|
||||
|
@ -132,18 +132,34 @@ mlfunc @mlfunc_with_ops() {
|
|||
|
||||
// CHECK-LABEL: mlfunc @loops() {
|
||||
mlfunc @loops() {
|
||||
// CHECK: for x = 1 to 100 step 2 {
|
||||
// CHECK: for %i0 = 1 to 100 step 2 {
|
||||
for %i = 1 to 100 step 2 {
|
||||
// CHECK: for x = 1 to 200 {
|
||||
// CHECK: for %i1 = 1 to 200 {
|
||||
for %j = 1 to 200 {
|
||||
} // CHECK: }
|
||||
} // CHECK: }
|
||||
return // CHECK: return
|
||||
} // CHECK: }
|
||||
|
||||
// CHECK-LABEL: mlfunc @complex_loops() {
|
||||
mlfunc @complex_loops() {
|
||||
for %i1 = 1 to 100 { // CHECK: for %i0 = 1 to 100 {
|
||||
for %j1 = 1 to 100 { // CHECK: for %i1 = 1 to 100 {
|
||||
"foo"() : () -> () // CHECK: "foo"() : () -> ()
|
||||
} // CHECK: }
|
||||
"boo"() : () -> () // CHECK: "boo"() : () -> ()
|
||||
for %j2 = 1 to 10 { // CHECK: for %i2 = 1 to 10 {
|
||||
for %k2 = 1 to 10 { // CHECK: for %i3 = 1 to 10 {
|
||||
"goo"() : () -> () // CHECK: "goo"() : () -> ()
|
||||
} // CHECK: }
|
||||
} // CHECK: }
|
||||
} // CHECK: }
|
||||
return // CHECK: return
|
||||
} // CHECK: }
|
||||
|
||||
// CHECK-LABEL: mlfunc @ifstmt() {
|
||||
mlfunc @ifstmt() {
|
||||
for %i = 1 to 10 { // CHECK for x = 1 to 10 {
|
||||
for %i = 1 to 10 { // CHECK for %i0 = 1 to 10 {
|
||||
if () { // CHECK if () {
|
||||
} else if () { // CHECK } else if () {
|
||||
} else { // CHECK } else {
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
// CHECK-LABEL: mlfunc @loops() {
|
||||
mlfunc @loops() {
|
||||
// CHECK: for x = 1 to 100 step 2 {
|
||||
// CHECK: for %i0 = 1 to 100 step 2 {
|
||||
for %i = 1 to 100 step 2 {
|
||||
// CHECK: "custom"(){value: 1} : () -> ()
|
||||
// CHECK-NEXT: "custom"(){value: 1} : () -> ()
|
||||
|
|
Loading…
Reference in New Issue