forked from OSchip/llvm-project
MLStmt cloning and IV replacement for loop unrolling, add constant pool to
MLFunctions. - MLStmt cloning and IV replacement - While at this, fix the innermostLoopGatherer to actually gather all the innermost loops (it was stopping its walk at the first innermost loop it found) - Improve comments for MLFunction statement classes, fix inheritance order. - Fixed StmtBlock destructor. PiperOrigin-RevId: 207049173
This commit is contained in:
parent
b92378e8fa
commit
2a003256ae
|
@ -18,6 +18,7 @@
|
|||
#ifndef MLIR_IR_BUILDERS_H
|
||||
#define MLIR_IR_BUILDERS_H
|
||||
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/CFGFunction.h"
|
||||
#include "mlir/IR/MLFunction.h"
|
||||
#include "mlir/IR/Statements.h"
|
||||
|
@ -162,6 +163,12 @@ public:
|
|||
return op;
|
||||
}
|
||||
|
||||
OperationInst *cloneOperation(const OperationInst &srcOpInst) {
|
||||
auto *op = srcOpInst.clone();
|
||||
block->getOperations().insert(insertPoint, op);
|
||||
return op;
|
||||
}
|
||||
|
||||
// Terminators.
|
||||
|
||||
ReturnInst *createReturnInst(ArrayRef<CFGValue *> operands) {
|
||||
|
@ -232,6 +239,12 @@ public:
|
|||
insertPoint = block->end();
|
||||
}
|
||||
|
||||
/// Set the insertion point at the beginning of the specified block.
|
||||
void setInsertionPointAtStart(StmtBlock *block) {
|
||||
this->block = block;
|
||||
insertPoint = block->begin();
|
||||
}
|
||||
|
||||
OperationStmt *createOperation(Identifier name, ArrayRef<MLValue *> operands,
|
||||
ArrayRef<Type *> resultTypes,
|
||||
ArrayRef<NamedAttribute> attributes) {
|
||||
|
@ -241,6 +254,12 @@ public:
|
|||
return op;
|
||||
}
|
||||
|
||||
OperationStmt *cloneOperation(const OperationStmt &srcOpStmt) {
|
||||
auto *op = srcOpStmt.clone();
|
||||
block->getStatements().insert(insertPoint, op);
|
||||
return op;
|
||||
}
|
||||
|
||||
// Creates for statement. When step is not specified, it is set to 1.
|
||||
ForStmt *createFor(AffineConstantExpr *lowerBound,
|
||||
AffineConstantExpr *upperBound,
|
||||
|
@ -252,6 +271,15 @@ public:
|
|||
return stmt;
|
||||
}
|
||||
|
||||
// TODO: subsume with a generate create<ConstantInt>() method.
|
||||
OperationStmt *createConstInt32Op(int value) {
|
||||
std::pair<Identifier, Attribute *> namedAttr(
|
||||
Identifier::get("value", context), getIntegerAttr(value));
|
||||
auto *mlconst = createOperation(Identifier::get("constant", context), {},
|
||||
{getIntegerType(32)}, {namedAttr});
|
||||
return mlconst;
|
||||
}
|
||||
|
||||
private:
|
||||
StmtBlock *block = nullptr;
|
||||
StmtBlock::iterator insertPoint;
|
||||
|
|
|
@ -158,6 +158,8 @@ public:
|
|||
/// Return the context this operation is associated with.
|
||||
MLIRContext *getContext() const { return Instruction::getContext(); }
|
||||
|
||||
OperationInst *clone() const;
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Operands
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
|
|
@ -34,8 +34,8 @@ class MLIRContext;
|
|||
|
||||
/// Statement is a basic unit of execution within an ML function.
|
||||
/// Statements can be nested within for and if statements effectively
|
||||
/// forming a tree. Statements are organized into statement blocks
|
||||
/// represented by StmtBlock class.
|
||||
/// forming a tree. Child statements are organized into statement blocks
|
||||
/// represented by a 'StmtBlock' class.
|
||||
class Statement : public llvm::ilist_node_with_parent<Statement, StmtBlock> {
|
||||
public:
|
||||
enum class Kind {
|
||||
|
@ -77,6 +77,7 @@ protected:
|
|||
|
||||
private:
|
||||
Kind kind;
|
||||
/// The statement block that containts this statement.
|
||||
StmtBlock *block = nullptr;
|
||||
|
||||
// allow ilist_traits access to 'block' field.
|
||||
|
|
|
@ -47,6 +47,8 @@ public:
|
|||
/// Return the context this operation is associated with.
|
||||
MLIRContext *getContext() const;
|
||||
|
||||
OperationStmt *clone() const;
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Operands
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
@ -190,7 +192,7 @@ private:
|
|||
};
|
||||
|
||||
/// For statement represents an affine loop nest.
|
||||
class ForStmt : public Statement, public StmtBlock, public MLValue {
|
||||
class ForStmt : public Statement, public MLValue, public StmtBlock {
|
||||
public:
|
||||
// TODO: lower and upper bounds should be affine maps with
|
||||
// dimension and symbol use lists.
|
||||
|
@ -199,6 +201,10 @@ public:
|
|||
MLIRContext *context);
|
||||
|
||||
// Loop bounds and step are immortal objects and don't need to be deleted.
|
||||
// With this dtor, ForStmt needs to inherit from MLValue before it does from
|
||||
// StmtBlock since an MLValue can't be destroyed before the StmtBlock is ---
|
||||
// the latter has uses for the induction variables, which is actually the
|
||||
// MLValue here. FIXME: this dtor.
|
||||
~ForStmt() {}
|
||||
|
||||
AffineConstantExpr *getLowerBound() const { return lowerBound; }
|
||||
|
|
|
@ -26,10 +26,12 @@
|
|||
#include "mlir/IR/Statement.h"
|
||||
|
||||
namespace mlir {
|
||||
class MLFunction;
|
||||
class IfStmt;
|
||||
class MLFunction;
|
||||
class IfStmt;
|
||||
|
||||
/// Statement block represents an ordered list of statements.
|
||||
/// Statement block represents an ordered list of statements, with the order
|
||||
/// being the contiguous lexical order in which the statements appear as
|
||||
/// children of a parent statement in the ML Function.
|
||||
class StmtBlock {
|
||||
public:
|
||||
enum class StmtBlockKind {
|
||||
|
@ -54,7 +56,7 @@ public:
|
|||
|
||||
/// This is the list of statements in the block.
|
||||
typedef llvm::iplist<Statement> StmtListType;
|
||||
StmtListType &getStatements() { return statements; }
|
||||
StmtListType &getStatements() { return statements; }
|
||||
const StmtListType &getStatements() const { return statements; }
|
||||
|
||||
// Iteration over the statements in the block.
|
||||
|
@ -82,14 +84,14 @@ public:
|
|||
}
|
||||
Statement &front() { return statements.front(); }
|
||||
const Statement &front() const {
|
||||
return const_cast<StmtBlock*>(this)->front();
|
||||
return const_cast<StmtBlock *>(this)->front();
|
||||
}
|
||||
|
||||
void print(raw_ostream &os) const;
|
||||
void dump() const;
|
||||
|
||||
/// getSublistAccess() - Returns pointer to member of statement list
|
||||
static StmtListType StmtBlock::*getSublistAccess(Statement*) {
|
||||
static StmtListType StmtBlock::*getSublistAccess(Statement *) {
|
||||
return &StmtBlock::statements;
|
||||
}
|
||||
|
||||
|
@ -101,9 +103,8 @@ private:
|
|||
/// This is the list of statements in the block.
|
||||
StmtListType statements;
|
||||
|
||||
StmtBlock(const StmtBlock&) = delete;
|
||||
void operator=(const StmtBlock&) = delete;
|
||||
|
||||
StmtBlock(const StmtBlock &) = delete;
|
||||
void operator=(const StmtBlock &) = delete;
|
||||
};
|
||||
|
||||
} //end namespace mlir
|
||||
|
|
|
@ -145,6 +145,21 @@ OperationInst *OperationInst::create(Identifier name,
|
|||
return inst;
|
||||
}
|
||||
|
||||
OperationInst *OperationInst::clone() const {
|
||||
SmallVector<CFGValue *, 8> operands;
|
||||
SmallVector<Type *, 8> resultTypes;
|
||||
|
||||
// TODO(clattner): switch to iterator logic.
|
||||
// Put together the operands and results.
|
||||
for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
|
||||
operands.push_back(getInstOperand(i).get());
|
||||
|
||||
for (unsigned i = 0, e = getNumResults(); i != e; ++i)
|
||||
resultTypes.push_back(getInstResult(i).getType());
|
||||
|
||||
return create(getName(), operands, resultTypes, getAttrs(), getContext());
|
||||
}
|
||||
|
||||
OperationInst::OperationInst(Identifier name, unsigned numOperands,
|
||||
unsigned numResults,
|
||||
ArrayRef<NamedAttribute> attributes,
|
||||
|
|
|
@ -120,7 +120,6 @@ void llvm::ilist_traits<::mlir::Statement>::transferNodesFromList(
|
|||
|
||||
/// Remove this statement (and its descendants) from its StmtBlock and delete
|
||||
/// all of them.
|
||||
/// TODO: erase all descendents for ForStmt/IfStmt.
|
||||
void Statement::eraseFromBlock() {
|
||||
assert(getBlock() && "Statement has no block");
|
||||
getBlock()->getStatements().erase(this);
|
||||
|
@ -155,6 +154,22 @@ OperationStmt *OperationStmt::create(Identifier name,
|
|||
return stmt;
|
||||
}
|
||||
|
||||
/// Clone an existing OperationStmt.
|
||||
OperationStmt *OperationStmt::clone() const {
|
||||
SmallVector<MLValue *, 8> operands;
|
||||
SmallVector<Type *, 8> resultTypes;
|
||||
|
||||
// TODO(clattner): switch this to iterator logic.
|
||||
// Put together operands and results.
|
||||
for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
|
||||
operands.push_back(getStmtOperand(i).get());
|
||||
|
||||
for (unsigned i = 0, e = getNumResults(); i != e; ++i)
|
||||
resultTypes.push_back(getStmtResult(i).getType());
|
||||
|
||||
return create(getName(), operands, resultTypes, getAttrs(), getContext());
|
||||
}
|
||||
|
||||
OperationStmt::OperationStmt(Identifier name, unsigned numOperands,
|
||||
unsigned numResults,
|
||||
ArrayRef<NamedAttribute> attributes,
|
||||
|
@ -205,9 +220,10 @@ void OperationStmt::dropAllReferences() {
|
|||
|
||||
ForStmt::ForStmt(AffineConstantExpr *lowerBound, AffineConstantExpr *upperBound,
|
||||
AffineConstantExpr *step, MLIRContext *context)
|
||||
: Statement(Kind::For), StmtBlock(StmtBlockKind::For),
|
||||
: Statement(Kind::For),
|
||||
MLValue(MLValueKind::ForStmt, Type::getAffineInt(context)),
|
||||
lowerBound(lowerBound), upperBound(upperBound), step(step) {}
|
||||
StmtBlock(StmtBlockKind::For), lowerBound(lowerBound),
|
||||
upperBound(upperBound), step(step) {}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// IfStmt
|
||||
|
@ -215,6 +231,6 @@ ForStmt::ForStmt(AffineConstantExpr *lowerBound, AffineConstantExpr *upperBound,
|
|||
|
||||
IfStmt::~IfStmt() {
|
||||
delete thenClause;
|
||||
if (elseClause != nullptr)
|
||||
if (elseClause)
|
||||
delete elseClause;
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/CFGFunction.h"
|
||||
#include "mlir/IR/MLFunction.h"
|
||||
|
@ -54,10 +55,13 @@ void LoopUnroll::runOnMLFunction(MLFunction *f) {
|
|||
typedef llvm::iplist<Statement> StmtListType;
|
||||
bool walkPostOrder(StmtListType::iterator Start,
|
||||
StmtListType::iterator End) {
|
||||
bool hasInnerLoops = false;
|
||||
// We need to walk all elements since all innermost loops need to be
|
||||
// gathered as opposed to determining whether this list has any inner
|
||||
// loops or not.
|
||||
while (Start != End)
|
||||
if (walkPostOrder(&(*Start++)))
|
||||
return true;
|
||||
return false;
|
||||
hasInnerLoops |= walkPostOrder(&(*Start++));
|
||||
return hasInnerLoops;
|
||||
}
|
||||
|
||||
// FIXME: can't use base class method for this because that in turn would
|
||||
|
@ -73,12 +77,11 @@ void LoopUnroll::runOnMLFunction(MLFunction *f) {
|
|||
}
|
||||
|
||||
bool walkIfStmtPostOrder(IfStmt *ifStmt) {
|
||||
if (walkPostOrder(ifStmt->getThenClause()->begin(),
|
||||
ifStmt->getThenClause()->end()) ||
|
||||
walkPostOrder(ifStmt->getElseClause()->begin(),
|
||||
ifStmt->getElseClause()->end()))
|
||||
return true;
|
||||
return false;
|
||||
bool hasInnerLoops = walkPostOrder(ifStmt->getThenClause()->begin(),
|
||||
ifStmt->getThenClause()->end());
|
||||
hasInnerLoops |= walkPostOrder(ifStmt->getElseClause()->begin(),
|
||||
ifStmt->getElseClause()->end());
|
||||
return hasInnerLoops;
|
||||
}
|
||||
|
||||
bool walkOpStmt(OperationStmt *opStmt) { return false; }
|
||||
|
@ -93,17 +96,45 @@ void LoopUnroll::runOnMLFunction(MLFunction *f) {
|
|||
runOnForStmt(forStmt);
|
||||
}
|
||||
|
||||
/// Unrolls this loop completely. Returns true if the unrolling happens.
|
||||
/// Replace an IV with a constant value.
|
||||
static void replaceIterator(Statement *stmt, const ForStmt &iv,
|
||||
MLValue *constVal) {
|
||||
struct ReplaceIterator : public StmtWalker<ReplaceIterator> {
|
||||
// IV to be replaced.
|
||||
const ForStmt *iv;
|
||||
// Constant to be replaced with.
|
||||
MLValue *constVal;
|
||||
|
||||
ReplaceIterator(const ForStmt &iv, MLValue *constVal)
|
||||
: iv(&iv), constVal(constVal){};
|
||||
|
||||
void visitOperationStmt(OperationStmt *os) {
|
||||
for (auto &operand : os->getStmtOperands()) {
|
||||
if (operand.get() == static_cast<const MLValue *>(iv)) {
|
||||
operand.set(constVal);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
ReplaceIterator ri(iv, constVal);
|
||||
ri.walk(stmt);
|
||||
}
|
||||
|
||||
/// Unrolls this loop completely.
|
||||
void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
|
||||
auto lb = forStmt->getLowerBound()->getValue();
|
||||
auto ub = forStmt->getUpperBound()->getValue();
|
||||
auto step = forStmt->getStep()->getValue();
|
||||
auto trip_count = (ub - lb + 1) / step;
|
||||
|
||||
auto *block = forStmt->getBlock();
|
||||
MLFuncBuilder builder(block);
|
||||
auto *mlFunc = forStmt->Statement::findFunction();
|
||||
MLFuncBuilder funcTopBuilder(mlFunc);
|
||||
funcTopBuilder.setInsertionPointAtStart(mlFunc);
|
||||
|
||||
MLFuncBuilder builder(forStmt->getBlock());
|
||||
for (int i = 0; i < trip_count; i++) {
|
||||
auto *ivUnrolledVal = funcTopBuilder.createConstInt32Op(i)->getResult(0);
|
||||
for (auto &stmt : forStmt->getStatements()) {
|
||||
switch (stmt.getKind()) {
|
||||
case Statement::Kind::For:
|
||||
|
@ -113,16 +144,13 @@ void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
|
|||
llvm_unreachable("unrolling loops that have only operations");
|
||||
break;
|
||||
case Statement::Kind::Operation:
|
||||
auto *op = cast<OperationStmt>(&stmt);
|
||||
// TODO: clone operands and result types.
|
||||
builder.createOperation(op->getName(), /*operands*/ {},
|
||||
/*resultTypes*/ {}, op->getAttrs());
|
||||
// TODO: loop iterator parsing not yet implemented; replace loop
|
||||
// iterator uses in unrolled body appropriately.
|
||||
auto *cloneOp = builder.cloneOperation(*cast<OperationStmt>(&stmt));
|
||||
// TODO(bondhugula): only generate constants when the IV actually
|
||||
// appears in the body.
|
||||
replaceIterator(cloneOp, *forStmt, ivUnrolledVal);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
forStmt->eraseFromBlock();
|
||||
}
|
||||
|
|
|
@ -1,16 +1,64 @@
|
|||
// RUN: %S/../../mlir-opt %s -o - -unroll-innermost-loops | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: mlfunc @loops() {
|
||||
mlfunc @loops() {
|
||||
// CHECK: for %i0 = 1 to 100 step 2 {
|
||||
// CHECK-LABEL: mlfunc @loops1() {
|
||||
mlfunc @loops1() {
|
||||
// CHECK: %c0_i32 = constant 0 : i32
|
||||
// CHECK-NEXT: %c1_i32 = constant 1 : i32
|
||||
// CHECK-NEXT: %c2_i32 = constant 2 : i32
|
||||
// CHECK-NEXT: %c3_i32 = constant 3 : i32
|
||||
// CHECK-NEXT: for %i0 = 1 to 100 step 2 {
|
||||
for %i = 1 to 100 step 2 {
|
||||
// CHECK: "custom"(){value: 1} : () -> ()
|
||||
// CHECK-NEXT: "custom"(){value: 1} : () -> ()
|
||||
// CHECK-NEXT: "custom"(){value: 1} : () -> ()
|
||||
// CHECK-NEXT: "custom"(){value: 1} : () -> ()
|
||||
// CHECK: %c1_i32_0 = constant 1 : i32
|
||||
// CHECK-NEXT: %c1_i32_1 = constant 1 : i32
|
||||
// CHECK-NEXT: %c1_i32_2 = constant 1 : i32
|
||||
// CHECK-NEXT: %c1_i32_3 = constant 1 : i32
|
||||
for %j = 1 to 4 {
|
||||
"custom"(){value: 1} : () -> f32
|
||||
%x = constant 1 : i32
|
||||
}
|
||||
} // CHECK: }
|
||||
return // CHECK: return
|
||||
} // CHECK }
|
||||
|
||||
// CHECK-LABEL: mlfunc @loops2() {
|
||||
mlfunc @loops2() {
|
||||
// CHECK: %c0_i32 = constant 0 : i32
|
||||
// CHECK-NEXT: %c1_i32 = constant 1 : i32
|
||||
// CHECK-NEXT: %c2_i32 = constant 2 : i32
|
||||
// CHECK-NEXT: %c3_i32 = constant 3 : i32
|
||||
// CHECK-NEXT: %c0_i32_0 = constant 0 : i32
|
||||
// CHECK-NEXT: %c1_i32_1 = constant 1 : i32
|
||||
// CHECK-NEXT: %c2_i32_2 = constant 2 : i32
|
||||
// CHECK-NEXT: %c3_i32_3 = constant 3 : i32
|
||||
// CHECK-NEXT: for %i0 = 1 to 100 step 2 {
|
||||
for %i = 1 to 100 step 2 {
|
||||
// CHECK: %0 = affine_apply (d0) -> (d0 + 1)(%c0_i32_0)
|
||||
// CHECK-NEXT: %1 = affine_apply (d0) -> (d0 + 1)(%c1_i32_1)
|
||||
// CHECK-NEXT: %2 = affine_apply (d0) -> (d0 + 1)(%c2_i32_2)
|
||||
// CHECK-NEXT: %3 = affine_apply (d0) -> (d0 + 1)(%c3_i32_3)
|
||||
for %j = 1 to 4 {
|
||||
%x = "affine_apply" (%j) { map: (d0) -> (d0 + 1) } :
|
||||
(affineint) -> (affineint)
|
||||
}
|
||||
} // CHECK: }
|
||||
|
||||
// CHECK: %c99 = constant 99 : affineint
|
||||
%k = "constant"(){value: 99} : () -> affineint
|
||||
// CHECK: for %i1 = 1 to 100 step 2 {
|
||||
for %m = 1 to 100 step 2 {
|
||||
// CHECK: %4 = affine_apply (d0) -> (d0 + 1)(%c0_i32)
|
||||
// CHECK-NEXT: %5 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c0_i32)[%c99]
|
||||
// CHECK-NEXT: %6 = affine_apply (d0) -> (d0 + 1)(%c1_i32)
|
||||
// CHECK-NEXT: %7 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c1_i32)[%c99]
|
||||
// CHECK-NEXT: %8 = affine_apply (d0) -> (d0 + 1)(%c2_i32)
|
||||
// CHECK-NEXT: %9 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c2_i32)[%c99]
|
||||
// CHECK-NEXT: %10 = affine_apply (d0) -> (d0 + 1)(%c3_i32)
|
||||
// CHECK-NEXT: %11 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c3_i32)[%c99]
|
||||
for %n = 1 to 4 {
|
||||
%y = "affine_apply" (%n) { map: (d0) -> (d0 + 1) } :
|
||||
(affineint) -> (affineint)
|
||||
%z = "affine_apply" (%n, %k) { map: (d0) [s0] -> (d0 + s0 + 1) } :
|
||||
(affineint, affineint) -> (affineint)
|
||||
} // CHECK }
|
||||
} // CHECK }
|
||||
return // CHECK: return
|
||||
} // CHECK }
|
||||
|
|
Loading…
Reference in New Issue