MLStmt cloning and IV replacement for loop unrolling, add constant pool to

MLFunctions.

- MLStmt cloning and IV replacement
- While at this, fix the innermostLoopGatherer to actually gather all the
  innermost loops (it was stopping its walk at the first innermost loop it
  found)
- Improve comments for MLFunction statement classes, fix inheritance order.

- Fixed StmtBlock destructor.

PiperOrigin-RevId: 207049173
This commit is contained in:
Uday Bondhugula 2018-08-01 22:36:12 -07:00 committed by jpienaar
parent b92378e8fa
commit 2a003256ae
9 changed files with 188 additions and 43 deletions

View File

@ -18,6 +18,7 @@
#ifndef MLIR_IR_BUILDERS_H
#define MLIR_IR_BUILDERS_H
#include "mlir/IR/Attributes.h"
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Statements.h"
@ -162,6 +163,12 @@ public:
return op;
}
OperationInst *cloneOperation(const OperationInst &srcOpInst) {
auto *op = srcOpInst.clone();
block->getOperations().insert(insertPoint, op);
return op;
}
// Terminators.
ReturnInst *createReturnInst(ArrayRef<CFGValue *> operands) {
@ -232,6 +239,12 @@ public:
insertPoint = block->end();
}
/// Set the insertion point at the beginning of the specified block.
void setInsertionPointAtStart(StmtBlock *block) {
this->block = block;
insertPoint = block->begin();
}
OperationStmt *createOperation(Identifier name, ArrayRef<MLValue *> operands,
ArrayRef<Type *> resultTypes,
ArrayRef<NamedAttribute> attributes) {
@ -241,6 +254,12 @@ public:
return op;
}
OperationStmt *cloneOperation(const OperationStmt &srcOpStmt) {
auto *op = srcOpStmt.clone();
block->getStatements().insert(insertPoint, op);
return op;
}
// Creates for statement. When step is not specified, it is set to 1.
ForStmt *createFor(AffineConstantExpr *lowerBound,
AffineConstantExpr *upperBound,
@ -252,6 +271,15 @@ public:
return stmt;
}
// TODO: subsume with a generate create<ConstantInt>() method.
OperationStmt *createConstInt32Op(int value) {
std::pair<Identifier, Attribute *> namedAttr(
Identifier::get("value", context), getIntegerAttr(value));
auto *mlconst = createOperation(Identifier::get("constant", context), {},
{getIntegerType(32)}, {namedAttr});
return mlconst;
}
private:
StmtBlock *block = nullptr;
StmtBlock::iterator insertPoint;

View File

@ -158,6 +158,8 @@ public:
/// Return the context this operation is associated with.
MLIRContext *getContext() const { return Instruction::getContext(); }
OperationInst *clone() const;
//===--------------------------------------------------------------------===//
// Operands
//===--------------------------------------------------------------------===//

View File

@ -34,8 +34,8 @@ class MLIRContext;
/// Statement is a basic unit of execution within an ML function.
/// Statements can be nested within for and if statements effectively
/// forming a tree. Statements are organized into statement blocks
/// represented by StmtBlock class.
/// forming a tree. Child statements are organized into statement blocks
/// represented by a 'StmtBlock' class.
class Statement : public llvm::ilist_node_with_parent<Statement, StmtBlock> {
public:
enum class Kind {
@ -77,6 +77,7 @@ protected:
private:
Kind kind;
/// The statement block that containts this statement.
StmtBlock *block = nullptr;
// allow ilist_traits access to 'block' field.

View File

@ -47,6 +47,8 @@ public:
/// Return the context this operation is associated with.
MLIRContext *getContext() const;
OperationStmt *clone() const;
//===--------------------------------------------------------------------===//
// Operands
//===--------------------------------------------------------------------===//
@ -190,7 +192,7 @@ private:
};
/// For statement represents an affine loop nest.
class ForStmt : public Statement, public StmtBlock, public MLValue {
class ForStmt : public Statement, public MLValue, public StmtBlock {
public:
// TODO: lower and upper bounds should be affine maps with
// dimension and symbol use lists.
@ -199,6 +201,10 @@ public:
MLIRContext *context);
// Loop bounds and step are immortal objects and don't need to be deleted.
// With this dtor, ForStmt needs to inherit from MLValue before it does from
// StmtBlock since an MLValue can't be destroyed before the StmtBlock is ---
// the latter has uses for the induction variables, which is actually the
// MLValue here. FIXME: this dtor.
~ForStmt() {}
AffineConstantExpr *getLowerBound() const { return lowerBound; }

View File

@ -26,10 +26,12 @@
#include "mlir/IR/Statement.h"
namespace mlir {
class MLFunction;
class IfStmt;
class MLFunction;
class IfStmt;
/// Statement block represents an ordered list of statements.
/// Statement block represents an ordered list of statements, with the order
/// being the contiguous lexical order in which the statements appear as
/// children of a parent statement in the ML Function.
class StmtBlock {
public:
enum class StmtBlockKind {
@ -54,7 +56,7 @@ public:
/// This is the list of statements in the block.
typedef llvm::iplist<Statement> StmtListType;
StmtListType &getStatements() { return statements; }
StmtListType &getStatements() { return statements; }
const StmtListType &getStatements() const { return statements; }
// Iteration over the statements in the block.
@ -82,14 +84,14 @@ public:
}
Statement &front() { return statements.front(); }
const Statement &front() const {
return const_cast<StmtBlock*>(this)->front();
return const_cast<StmtBlock *>(this)->front();
}
void print(raw_ostream &os) const;
void dump() const;
/// getSublistAccess() - Returns pointer to member of statement list
static StmtListType StmtBlock::*getSublistAccess(Statement*) {
static StmtListType StmtBlock::*getSublistAccess(Statement *) {
return &StmtBlock::statements;
}
@ -101,9 +103,8 @@ private:
/// This is the list of statements in the block.
StmtListType statements;
StmtBlock(const StmtBlock&) = delete;
void operator=(const StmtBlock&) = delete;
StmtBlock(const StmtBlock &) = delete;
void operator=(const StmtBlock &) = delete;
};
} //end namespace mlir

View File

@ -145,6 +145,21 @@ OperationInst *OperationInst::create(Identifier name,
return inst;
}
OperationInst *OperationInst::clone() const {
SmallVector<CFGValue *, 8> operands;
SmallVector<Type *, 8> resultTypes;
// TODO(clattner): switch to iterator logic.
// Put together the operands and results.
for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
operands.push_back(getInstOperand(i).get());
for (unsigned i = 0, e = getNumResults(); i != e; ++i)
resultTypes.push_back(getInstResult(i).getType());
return create(getName(), operands, resultTypes, getAttrs(), getContext());
}
OperationInst::OperationInst(Identifier name, unsigned numOperands,
unsigned numResults,
ArrayRef<NamedAttribute> attributes,

View File

@ -120,7 +120,6 @@ void llvm::ilist_traits<::mlir::Statement>::transferNodesFromList(
/// Remove this statement (and its descendants) from its StmtBlock and delete
/// all of them.
/// TODO: erase all descendents for ForStmt/IfStmt.
void Statement::eraseFromBlock() {
assert(getBlock() && "Statement has no block");
getBlock()->getStatements().erase(this);
@ -155,6 +154,22 @@ OperationStmt *OperationStmt::create(Identifier name,
return stmt;
}
/// Clone an existing OperationStmt.
OperationStmt *OperationStmt::clone() const {
SmallVector<MLValue *, 8> operands;
SmallVector<Type *, 8> resultTypes;
// TODO(clattner): switch this to iterator logic.
// Put together operands and results.
for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
operands.push_back(getStmtOperand(i).get());
for (unsigned i = 0, e = getNumResults(); i != e; ++i)
resultTypes.push_back(getStmtResult(i).getType());
return create(getName(), operands, resultTypes, getAttrs(), getContext());
}
OperationStmt::OperationStmt(Identifier name, unsigned numOperands,
unsigned numResults,
ArrayRef<NamedAttribute> attributes,
@ -205,9 +220,10 @@ void OperationStmt::dropAllReferences() {
ForStmt::ForStmt(AffineConstantExpr *lowerBound, AffineConstantExpr *upperBound,
AffineConstantExpr *step, MLIRContext *context)
: Statement(Kind::For), StmtBlock(StmtBlockKind::For),
: Statement(Kind::For),
MLValue(MLValueKind::ForStmt, Type::getAffineInt(context)),
lowerBound(lowerBound), upperBound(upperBound), step(step) {}
StmtBlock(StmtBlockKind::For), lowerBound(lowerBound),
upperBound(upperBound), step(step) {}
//===----------------------------------------------------------------------===//
// IfStmt
@ -215,6 +231,6 @@ ForStmt::ForStmt(AffineConstantExpr *lowerBound, AffineConstantExpr *upperBound,
IfStmt::~IfStmt() {
delete thenClause;
if (elseClause != nullptr)
if (elseClause)
delete elseClause;
}

View File

@ -19,6 +19,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/MLFunction.h"
@ -54,10 +55,13 @@ void LoopUnroll::runOnMLFunction(MLFunction *f) {
typedef llvm::iplist<Statement> StmtListType;
bool walkPostOrder(StmtListType::iterator Start,
StmtListType::iterator End) {
bool hasInnerLoops = false;
// We need to walk all elements since all innermost loops need to be
// gathered as opposed to determining whether this list has any inner
// loops or not.
while (Start != End)
if (walkPostOrder(&(*Start++)))
return true;
return false;
hasInnerLoops |= walkPostOrder(&(*Start++));
return hasInnerLoops;
}
// FIXME: can't use base class method for this because that in turn would
@ -73,12 +77,11 @@ void LoopUnroll::runOnMLFunction(MLFunction *f) {
}
bool walkIfStmtPostOrder(IfStmt *ifStmt) {
if (walkPostOrder(ifStmt->getThenClause()->begin(),
ifStmt->getThenClause()->end()) ||
walkPostOrder(ifStmt->getElseClause()->begin(),
ifStmt->getElseClause()->end()))
return true;
return false;
bool hasInnerLoops = walkPostOrder(ifStmt->getThenClause()->begin(),
ifStmt->getThenClause()->end());
hasInnerLoops |= walkPostOrder(ifStmt->getElseClause()->begin(),
ifStmt->getElseClause()->end());
return hasInnerLoops;
}
bool walkOpStmt(OperationStmt *opStmt) { return false; }
@ -93,17 +96,45 @@ void LoopUnroll::runOnMLFunction(MLFunction *f) {
runOnForStmt(forStmt);
}
/// Unrolls this loop completely. Returns true if the unrolling happens.
/// Replace an IV with a constant value.
static void replaceIterator(Statement *stmt, const ForStmt &iv,
MLValue *constVal) {
struct ReplaceIterator : public StmtWalker<ReplaceIterator> {
// IV to be replaced.
const ForStmt *iv;
// Constant to be replaced with.
MLValue *constVal;
ReplaceIterator(const ForStmt &iv, MLValue *constVal)
: iv(&iv), constVal(constVal){};
void visitOperationStmt(OperationStmt *os) {
for (auto &operand : os->getStmtOperands()) {
if (operand.get() == static_cast<const MLValue *>(iv)) {
operand.set(constVal);
}
}
}
};
ReplaceIterator ri(iv, constVal);
ri.walk(stmt);
}
/// Unrolls this loop completely.
void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
auto lb = forStmt->getLowerBound()->getValue();
auto ub = forStmt->getUpperBound()->getValue();
auto step = forStmt->getStep()->getValue();
auto trip_count = (ub - lb + 1) / step;
auto *block = forStmt->getBlock();
MLFuncBuilder builder(block);
auto *mlFunc = forStmt->Statement::findFunction();
MLFuncBuilder funcTopBuilder(mlFunc);
funcTopBuilder.setInsertionPointAtStart(mlFunc);
MLFuncBuilder builder(forStmt->getBlock());
for (int i = 0; i < trip_count; i++) {
auto *ivUnrolledVal = funcTopBuilder.createConstInt32Op(i)->getResult(0);
for (auto &stmt : forStmt->getStatements()) {
switch (stmt.getKind()) {
case Statement::Kind::For:
@ -113,16 +144,13 @@ void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
llvm_unreachable("unrolling loops that have only operations");
break;
case Statement::Kind::Operation:
auto *op = cast<OperationStmt>(&stmt);
// TODO: clone operands and result types.
builder.createOperation(op->getName(), /*operands*/ {},
/*resultTypes*/ {}, op->getAttrs());
// TODO: loop iterator parsing not yet implemented; replace loop
// iterator uses in unrolled body appropriately.
auto *cloneOp = builder.cloneOperation(*cast<OperationStmt>(&stmt));
// TODO(bondhugula): only generate constants when the IV actually
// appears in the body.
replaceIterator(cloneOp, *forStmt, ivUnrolledVal);
break;
}
}
}
forStmt->eraseFromBlock();
}

View File

@ -1,16 +1,64 @@
// RUN: %S/../../mlir-opt %s -o - -unroll-innermost-loops | FileCheck %s
// CHECK-LABEL: mlfunc @loops() {
mlfunc @loops() {
// CHECK: for %i0 = 1 to 100 step 2 {
// CHECK-LABEL: mlfunc @loops1() {
mlfunc @loops1() {
// CHECK: %c0_i32 = constant 0 : i32
// CHECK-NEXT: %c1_i32 = constant 1 : i32
// CHECK-NEXT: %c2_i32 = constant 2 : i32
// CHECK-NEXT: %c3_i32 = constant 3 : i32
// CHECK-NEXT: for %i0 = 1 to 100 step 2 {
for %i = 1 to 100 step 2 {
// CHECK: "custom"(){value: 1} : () -> ()
// CHECK-NEXT: "custom"(){value: 1} : () -> ()
// CHECK-NEXT: "custom"(){value: 1} : () -> ()
// CHECK-NEXT: "custom"(){value: 1} : () -> ()
// CHECK: %c1_i32_0 = constant 1 : i32
// CHECK-NEXT: %c1_i32_1 = constant 1 : i32
// CHECK-NEXT: %c1_i32_2 = constant 1 : i32
// CHECK-NEXT: %c1_i32_3 = constant 1 : i32
for %j = 1 to 4 {
"custom"(){value: 1} : () -> f32
%x = constant 1 : i32
}
} // CHECK: }
return // CHECK: return
} // CHECK }
// CHECK-LABEL: mlfunc @loops2() {
mlfunc @loops2() {
// CHECK: %c0_i32 = constant 0 : i32
// CHECK-NEXT: %c1_i32 = constant 1 : i32
// CHECK-NEXT: %c2_i32 = constant 2 : i32
// CHECK-NEXT: %c3_i32 = constant 3 : i32
// CHECK-NEXT: %c0_i32_0 = constant 0 : i32
// CHECK-NEXT: %c1_i32_1 = constant 1 : i32
// CHECK-NEXT: %c2_i32_2 = constant 2 : i32
// CHECK-NEXT: %c3_i32_3 = constant 3 : i32
// CHECK-NEXT: for %i0 = 1 to 100 step 2 {
for %i = 1 to 100 step 2 {
// CHECK: %0 = affine_apply (d0) -> (d0 + 1)(%c0_i32_0)
// CHECK-NEXT: %1 = affine_apply (d0) -> (d0 + 1)(%c1_i32_1)
// CHECK-NEXT: %2 = affine_apply (d0) -> (d0 + 1)(%c2_i32_2)
// CHECK-NEXT: %3 = affine_apply (d0) -> (d0 + 1)(%c3_i32_3)
for %j = 1 to 4 {
%x = "affine_apply" (%j) { map: (d0) -> (d0 + 1) } :
(affineint) -> (affineint)
}
} // CHECK: }
// CHECK: %c99 = constant 99 : affineint
%k = "constant"(){value: 99} : () -> affineint
// CHECK: for %i1 = 1 to 100 step 2 {
for %m = 1 to 100 step 2 {
// CHECK: %4 = affine_apply (d0) -> (d0 + 1)(%c0_i32)
// CHECK-NEXT: %5 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c0_i32)[%c99]
// CHECK-NEXT: %6 = affine_apply (d0) -> (d0 + 1)(%c1_i32)
// CHECK-NEXT: %7 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c1_i32)[%c99]
// CHECK-NEXT: %8 = affine_apply (d0) -> (d0 + 1)(%c2_i32)
// CHECK-NEXT: %9 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c2_i32)[%c99]
// CHECK-NEXT: %10 = affine_apply (d0) -> (d0 + 1)(%c3_i32)
// CHECK-NEXT: %11 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c3_i32)[%c99]
for %n = 1 to 4 {
%y = "affine_apply" (%n) { map: (d0) -> (d0 + 1) } :
(affineint) -> (affineint)
%z = "affine_apply" (%n, %k) { map: (d0) [s0] -> (d0 + s0 + 1) } :
(affineint, affineint) -> (affineint)
} // CHECK }
} // CHECK }
return // CHECK: return
} // CHECK }