Merge CFGFuncBuilder/MLFuncBuilder/FuncBuilder together into a single new

FuncBuilder class.  Also rename SSAValue.cpp to Value.cpp

This is step 12/n towards merging instructions and statements, NFC.

PiperOrigin-RevId: 227067644
This commit is contained in:
Chris Lattner 2018-12-27 15:06:22 -08:00 committed by jpienaar
parent 3f190312f8
commit 4c05f8cac6
21 changed files with 124 additions and 334 deletions

View File

@ -159,161 +159,33 @@ protected:
MLIRContext *context;
};
/// This class helps build a CFGFunction. Instructions that are created are
/// automatically inserted at an insertion point or added to the current basic
/// block.
class CFGFuncBuilder : public Builder {
/// This class helps build a Function. Instructions that are created are
/// automatically inserted at an insertion point. The builder is copyable.
class FuncBuilder : public Builder {
public:
CFGFuncBuilder(BasicBlock *block, BasicBlock::iterator insertPoint)
: Builder(block->getFunction()->getContext()),
function(block->getFunction()) {
setInsertionPoint(block, insertPoint);
/// Create an ML function builder and set the insertion point to the start of
/// the function.
FuncBuilder(Function *func) : Builder(func->getContext()), function(func) {
setInsertionPoint(&func->front(), func->front().begin());
}
CFGFuncBuilder(Instruction *insertBefore)
: CFGFuncBuilder(insertBefore->getBlock(),
BasicBlock::iterator(insertBefore)) {}
CFGFuncBuilder(BasicBlock *block)
: Builder(block->getFunction()->getContext()),
function(block->getFunction()) {
setInsertionPoint(block);
}
CFGFuncBuilder(CFGFunction *function)
: Builder(function->getContext()), function(function) {}
/// Return the function this builder is referring to.
CFGFunction *getFunction() const { return function; }
/// Reset the insertion point to no location. Creating an operation without a
/// set insertion point is an error, but this can still be useful when the
/// current insertion point a builder refers to is being removed.
void clearInsertionPoint() {
this->block = nullptr;
insertPoint = BasicBlock::iterator();
}
/// Return the block the current insertion point belongs to. Note that the
/// the insertion point is not necessarily the end of the block.
BasicBlock *getInsertionBlock() const { return block; }
/// Return the insert position as the BasicBlock iterator. The block itself
/// can be obtained by calling getInsertionBlock.
BasicBlock::iterator getInsertionPoint() const { return insertPoint; }
/// Set the insertion point to the specified location.
void setInsertionPoint(BasicBlock *block, BasicBlock::iterator insertPoint) {
assert(block->getFunction() == function &&
"can't move to a different function");
this->block = block;
this->insertPoint = insertPoint;
}
/// Set the insertion point to the specified operation.
void setInsertionPoint(Instruction *inst) {
setInsertionPoint(inst->getBlock(), BasicBlock::iterator(inst));
}
/// Set the insertion point to the end of the specified block.
void setInsertionPoint(BasicBlock *block) {
setInsertionPoint(block, block->end());
}
void insert(Instruction *opInst) {
block->getStatements().insert(insertPoint, opInst);
}
/// Add new basic block and set the insertion point to the end of it. If an
/// 'insertBefore' basic block is passed, the block will be placed before the
/// specified block. If not, the block will be appended to the end of the
/// current function.
BasicBlock *createBlock(BasicBlock *insertBefore = nullptr);
/// Create an operation given the fields represented as an OperationState.
OperationStmt *createOperation(const OperationState &state);
/// Create operation of specific op type at the current insertion point
/// without verifying to see if it is valid.
template <typename OpTy, typename... Args>
OpPointer<OpTy> create(Location location, Args... args) {
OperationState state(getContext(), location, OpTy::getOperationName());
OpTy::build(this, &state, args...);
auto *inst = createOperation(state);
auto result = inst->dyn_cast<OpTy>();
assert(result && "Builder didn't return the right type");
return result;
}
/// Creates an operation of specific op type at the current insertion point.
/// If the result is an invalid op (the verifier hook fails), emit an error
/// and return null.
template <typename OpTy, typename... Args>
OpPointer<OpTy> createChecked(Location location, Args... args) {
OperationState state(getContext(), location, OpTy::getOperationName());
OpTy::build(this, &state, args...);
auto *inst = createOperation(state);
// If the Instruction we produce is valid, return it.
if (!OpTy::verifyInvariants(inst)) {
auto result = inst->dyn_cast<OpTy>();
assert(result && "Builder didn't return the right type");
return result;
}
// Otherwise, the error message got emitted. Just remove the instruction
// we made.
inst->erase();
return OpPointer<OpTy>();
}
OperationStmt *cloneOperation(const OperationStmt &srcOpInst) {
auto *op = cast<OperationStmt>(srcOpInst.clone(getContext()));
insert(op);
return op;
}
private:
CFGFunction *function;
BasicBlock *block = nullptr;
BasicBlock::iterator insertPoint;
};
/// This class helps build an MLFunction. Statements that are created are
/// automatically inserted at an insertion point or added to the current
/// statement block. The builder has only two member variables and can be passed
/// around by value.
class MLFuncBuilder : public Builder {
public:
/// Create ML function builder and set insertion point to the given statement,
/// Create a function builder and set insertion point to the given statement,
/// which will cause subsequent insertions to go right before it.
MLFuncBuilder(Statement *stmt)
// TODO: Eliminate getFunction from this.
: MLFuncBuilder(stmt->getFunction()) {
FuncBuilder(Statement *stmt) : FuncBuilder(stmt->getFunction()) {
setInsertionPoint(stmt);
}
MLFuncBuilder(StmtBlock *block)
// TODO: Eliminate getFunction from this.
: MLFuncBuilder(block->getFunction()) {
FuncBuilder(StmtBlock *block) : FuncBuilder(block->getFunction()) {
setInsertionPoint(block, block->end());
}
MLFuncBuilder(StmtBlock *block, StmtBlock::iterator insertPoint)
// TODO: Eliminate getFunction from this.
: MLFuncBuilder(block->getFunction()) {
FuncBuilder(StmtBlock *block, StmtBlock::iterator insertPoint)
: FuncBuilder(block->getFunction()) {
setInsertionPoint(block, insertPoint);
}
/// Create an ML function builder and set the insertion point to the start of
/// the function.
MLFuncBuilder(MLFunction *func)
: Builder(func->getContext()), function(func) {
setInsertionPoint(func->getBody(), func->getBody()->begin());
}
/// Return the function this builder is referring to.
MLFunction *getFunction() const { return function; }
Function *getFunction() const { return function; }
/// Reset the insertion point to no location. Creating an operation without a
/// set insertion point is an error, but this can still be useful when the
@ -324,8 +196,6 @@ public:
}
/// Set the insertion point to the specified location.
/// Unlike CFGFuncBuilder, MLFuncBuilder allows to set insertion
/// point to a different function.
void setInsertionPoint(StmtBlock *block, StmtBlock::iterator insertPoint) {
// TODO: check that insertPoint is in this rather than some other block.
this->block = block;
@ -348,14 +218,24 @@ public:
setInsertionPoint(block, block->end());
}
/// Returns a builder for the body of a for Stmt.
static MLFuncBuilder getForStmtBodyBuilder(ForStmt *forStmt) {
return MLFuncBuilder(forStmt->getBody(), forStmt->getBody()->end());
}
/// Return the block the current insertion point belongs to. Note that the
/// the insertion point is not necessarily the end of the block.
BasicBlock *getInsertionBlock() const { return block; }
/// Returns the current insertion point of the builder.
StmtBlock::iterator getInsertionPoint() const { return insertPoint; }
/// Add new block and set the insertion point to the end of it. If an
/// 'insertBefore' block is passed, the block will be placed before the
/// specified block. If not, the block will be appended to the end of the
/// current function.
StmtBlock *createBlock(StmtBlock *insertBefore = nullptr);
/// Returns a builder for the body of a for Stmt.
static FuncBuilder getForStmtBodyBuilder(ForStmt *forStmt) {
return FuncBuilder(forStmt->getBody(), forStmt->getBody()->end());
}
/// Returns the current block of the builder.
StmtBlock *getBlock() const { return block; }
@ -421,84 +301,11 @@ public:
IntegerSet set);
private:
MLFunction *function;
Function *function;
StmtBlock *block = nullptr;
StmtBlock::iterator insertPoint;
};
// Wrapper around common CFGFuncBuilder and MLFuncBuilder functionality. Use
// this wrapper for interfaces where operations need to be created in either a
// CFG function or ML function.
class FuncBuilder : public Builder {
public:
FuncBuilder(CFGFuncBuilder &cfgFuncBuilder)
: Builder(cfgFuncBuilder.getContext()), builder(cfgFuncBuilder),
kind(Function::Kind::CFGFunc) {}
FuncBuilder(MLFuncBuilder &mlFuncBuilder)
: Builder(mlFuncBuilder.getContext()), builder(mlFuncBuilder),
kind(Function::Kind::MLFunc) {}
FuncBuilder(Operation *op) : Builder(op->getContext()) {
if (op->getOperationFunction()->isCFG()) {
builder = builderUnion(CFGFuncBuilder(cast<OperationInst>(op)));
kind = Function::Kind::CFGFunc;
} else {
builder = builderUnion(MLFuncBuilder(cast<OperationStmt>(op)));
kind = Function::Kind::MLFunc;
}
}
/// Creates an operation given the fields represented as an OperationState.
Operation *createOperation(const OperationState &state) {
if (kind == Function::Kind::CFGFunc)
return builder.cfg.createOperation(state);
return builder.ml.createOperation(state);
}
/// Creates operation of specific op type at the current insertion point
/// without verifying to see if it is valid.
template <typename OpTy, typename... Args>
OpPointer<OpTy> create(Location location, Args... args) {
if (kind == Function::Kind::CFGFunc)
return builder.cfg.create<OpTy, Args...>(location, args...);
return builder.ml.create<OpTy, Args...>(location, args...);
}
/// Creates an operation of specific op type at the current insertion point.
/// If the result is an invalid op (the verifier hook fails), emit an error
/// and return null.
template <typename OpTy, typename... Args>
OpPointer<OpTy> createChecked(Location location, Args... args) {
if (kind == Function::Kind::CFGFunc)
return builder.cfg.createChecked<OpTy, Args...>(location, args...);
return builder.ml.createChecked<OpTy, Args...>(location, args...);
}
/// Set the insertion point to the specified operation. This requires that the
/// input operation is a Instruction when building a CFG function and a
/// OperationStmt when building a ML function.
void setInsertionPoint(Operation *op) {
if (kind == Function::Kind::CFGFunc)
builder.cfg.setInsertionPoint(cast<OperationStmt>(op));
else
builder.ml.setInsertionPoint(cast<OperationStmt>(op));
}
private:
// Wrapped builders for CFG and ML functions.
union builderUnion {
builderUnion(CFGFuncBuilder cfg) : cfg(cfg) {}
builderUnion(MLFuncBuilder ml) : ml(ml) {}
// Default initializer to allow deferring initialization of member.
builderUnion() {}
CFGFuncBuilder cfg;
MLFuncBuilder ml;
} builder;
// The type of builder in the builderUnion.
Function::Kind kind;
};
} // namespace mlir
#endif

View File

@ -36,10 +36,10 @@ class Value;
using Instruction = Statement;
using OperationInst = OperationStmt;
/// The operand of ML function statement contains a Value.
/// Operands contain a Value.
using StmtOperand = IROperandImpl<Value, Statement>;
/// This is the common base class for all values in the MLIR system,
/// This is the common base class for all SSA values in the MLIR system,
/// representing a computable value that has a type and a set of users.
///
class Value : public IRObjectWithUseList {
@ -48,7 +48,7 @@ public:
enum class Kind {
BlockArgument, // block argument
StmtResult, // statement result
ForStmt, // for statement induction variable
ForStmt, // 'for' statement induction variable
};
~Value() {}

View File

@ -32,7 +32,7 @@ class AffineMap;
class ForStmt;
class Function;
using MLFunction = Function;
class MLFuncBuilder;
class FuncBuilder;
// Values that can be used to signal success/failure. This can be implicitly
// converted to/from boolean values, with false representing success and true
@ -73,14 +73,13 @@ void promoteSingleIterationLoops(MLFunction *f);
/// Returns the lower bound of the cleanup loop when unrolling a loop
/// with the specified unroll factor.
AffineMap getCleanupLoopLowerBound(const ForStmt &forStmt,
unsigned unrollFactor,
MLFuncBuilder *builder);
unsigned unrollFactor, FuncBuilder *builder);
/// Returns the upper bound of an unrolled loop when unrolling with
/// the specified trip count, stride, and unroll factor.
AffineMap getUnrolledLoopUpperBound(const ForStmt &forStmt,
unsigned unrollFactor,
MLFuncBuilder *builder);
FuncBuilder *builder);
/// Skew the statements in the body of a 'for' statement with the specified
/// statement-wise shifts. The shifts are with respect to the original execution

View File

@ -32,10 +32,10 @@ namespace mlir {
/// Specialization of the pattern rewriter to ML functions.
class MLFuncLoweringRewriter : public PatternRewriter {
public:
explicit MLFuncLoweringRewriter(MLFuncBuilder *builder)
explicit MLFuncLoweringRewriter(FuncBuilder *builder)
: PatternRewriter(builder->getContext()), builder(builder) {}
MLFuncBuilder *getBuilder() { return builder; }
FuncBuilder *getBuilder() { return builder; }
Operation *createOperation(const OperationState &state) override {
auto *result = builder->createOperation(state);
@ -43,7 +43,7 @@ public:
}
private:
MLFuncBuilder *builder;
FuncBuilder *builder;
};
/// Base class for the MLFunction-wise lowering state. A pointer to the same
@ -140,7 +140,7 @@ PassResult MLPatternLoweringPass<Patterns...>::runOnMLFunction(MLFunction *f) {
detail::ListAdder<Patterns...>::addPatternsToList(&patterns, f->getContext());
auto funcWiseState = makeFuncWiseState(f);
MLFuncBuilder builder(f);
FuncBuilder builder(f);
MLFuncLoweringRewriter rewriter(&builder);
llvm::SmallVector<OperationStmt *, 0> ops;

View File

@ -173,7 +173,7 @@ bool mlir::getMemRefRegion(OperationStmt *opStmt, unsigned loopDepth,
// Build the constraints for this region.
FlatAffineConstraints *regionCst = region->getConstraints();
MLFuncBuilder b(opStmt);
FuncBuilder b(opStmt);
auto idMap = b.getMultiDimIdentityMap(rank);
// Initialize 'accessValueMap' and compose with reachable AffineApplyOps.
@ -453,7 +453,7 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
// Clone src loop nest and insert it a the beginning of the statement block
// of the loop at 'dstLoopDepth' in 'dstLoopNest'.
auto *dstForStmt = dstLoopNest[dstLoopDepth - 1];
MLFuncBuilder b(dstForStmt->getBody(), dstForStmt->getBody()->begin());
FuncBuilder b(dstForStmt->getBody(), dstForStmt->getBody()->begin());
DenseMap<const Value *, Value *> operandMap;
auto *sliceLoopNest = cast<ForStmt>(b.clone(*srcLoopNest[0], operandMap));

View File

@ -268,15 +268,15 @@ AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
}
//===----------------------------------------------------------------------===//
// CFG function elements.
// Statements.
//===----------------------------------------------------------------------===//
/// Add new basic block and set the insertion point to the end of it. If an
/// 'insertBefore' basic block is passed, the block will be placed before the
/// specified block. If not, the block will be appended to the end of the
/// current function.
BasicBlock *CFGFuncBuilder::createBlock(BasicBlock *insertBefore) {
BasicBlock *b = new BasicBlock();
StmtBlock *FuncBuilder::createBlock(StmtBlock *insertBefore) {
StmtBlock *b = new StmtBlock();
// If we are supposed to insert before a specific block, do so, otherwise add
// the block to the end of the function.
@ -285,12 +285,12 @@ BasicBlock *CFGFuncBuilder::createBlock(BasicBlock *insertBefore) {
else
function->push_back(b);
setInsertionPoint(b);
setInsertionPointToEnd(b);
return b;
}
/// Create an operation given the fields represented as an OperationState.
OperationStmt *CFGFuncBuilder::createOperation(const OperationState &state) {
OperationStmt *FuncBuilder::createOperation(const OperationState &state) {
auto *op = OperationInst::create(state.location, state.name, state.operands,
state.types, state.attributes,
state.successors, context);
@ -298,38 +298,24 @@ OperationStmt *CFGFuncBuilder::createOperation(const OperationState &state) {
return op;
}
//===----------------------------------------------------------------------===//
// Statements.
//===----------------------------------------------------------------------===//
/// Create an operation given the fields represented as an OperationState.
OperationStmt *MLFuncBuilder::createOperation(const OperationState &state) {
auto *op = OperationStmt::create(state.location, state.name, state.operands,
state.types, state.attributes,
state.successors, context);
block->getStatements().insert(insertPoint, op);
return op;
}
ForStmt *MLFuncBuilder::createFor(Location location,
ArrayRef<Value *> lbOperands, AffineMap lbMap,
ArrayRef<Value *> ubOperands, AffineMap ubMap,
int64_t step) {
ForStmt *FuncBuilder::createFor(Location location, ArrayRef<Value *> lbOperands,
AffineMap lbMap, ArrayRef<Value *> ubOperands,
AffineMap ubMap, int64_t step) {
auto *stmt =
ForStmt::create(location, lbOperands, lbMap, ubOperands, ubMap, step);
block->getStatements().insert(insertPoint, stmt);
return stmt;
}
ForStmt *MLFuncBuilder::createFor(Location location, int64_t lb, int64_t ub,
int64_t step) {
ForStmt *FuncBuilder::createFor(Location location, int64_t lb, int64_t ub,
int64_t step) {
auto lbMap = AffineMap::getConstantMap(lb, context);
auto ubMap = AffineMap::getConstantMap(ub, context);
return createFor(location, {}, lbMap, {}, ubMap, step);
}
IfStmt *MLFuncBuilder::createIf(Location location, ArrayRef<Value *> operands,
IntegerSet set) {
IfStmt *FuncBuilder::createIf(Location location, ArrayRef<Value *> operands,
IntegerSet set) {
auto *stmt = IfStmt::create(location, operands, set);
block->getStatements().insert(insertPoint, stmt);
return stmt;

View File

@ -174,7 +174,7 @@ BasicBlock *BasicBlock::splitBasicBlock(iterator splitBefore) {
// Create an unconditional branch to the new block, and move our terminator
// to the new block.
CFGFuncBuilder(this).create<BranchOp>(branchLoc, newBB);
FuncBuilder(this).create<BranchOp>(branchLoc, newBB);
return newBB;
}

View File

@ -1,4 +1,4 @@
//===- SSAValue.cpp - MLIR ValueClasses ------------===//
//===- Value.cpp - MLIR Value Classes -------------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
@ -15,10 +15,9 @@
// limitations under the License.
// =============================================================================
#include "mlir/IR/Value.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/Value.h"
using namespace mlir;
/// If this value is the result of an Instruction, return the instruction

View File

@ -2578,7 +2578,7 @@ private:
/// This builder intentionally shadows the builder in the base class, with a
/// more specific builder type.
CFGFuncBuilder builder;
FuncBuilder builder;
/// Get the basic block with the specified name, creating it if it doesn't
/// already exist. The location specified is the point of use, which allows
@ -2744,7 +2744,7 @@ ParseResult CFGFunctionParser::parseBasicBlock() {
// Set the insertion point to the block we want to insert new operations
// into.
builder.setInsertionPoint(block);
builder.setInsertionPointToEnd(block);
auto createOpFunc = [&](const OperationState &result) -> Operation * {
return builder.createOperation(result);
@ -2782,7 +2782,7 @@ private:
/// This builder intentionally shadows the builder in the base class, with a
/// more specific builder type.
MLFuncBuilder builder;
FuncBuilder builder;
ParseResult parseForStmt();
ParseResult parseIntConstant(int64_t &val);

View File

@ -104,7 +104,7 @@ bool ConstantFold::foldOperation(Operation *op,
// conditional branches, or anything else fancy.
PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
existingConstants.clear();
CFGFuncBuilder builder(f);
FuncBuilder builder(f);
for (auto &bb : *f) {
for (auto instIt = bb.begin(), e = bb.end(); instIt != e;) {
@ -141,7 +141,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
// Override the walker's operation statement visit for constant folding.
void ConstantFold::visitOperationStmt(OperationStmt *stmt) {
auto constantFactory = [&](Attribute value, Type type) -> Value * {
MLFuncBuilder builder(stmt);
FuncBuilder builder(stmt);
return builder.create<ConstantOp>(stmt->getLoc(), value, type);
};
if (!ConstantFold::foldOperation(stmt, existingConstants, constantFactory)) {

View File

@ -57,7 +57,7 @@ private:
llvm::iterator_range<Operation::result_iterator> values);
CFGFunction *cfgFunc;
CFGFuncBuilder builder;
FuncBuilder builder;
// Mapping between original Values and lowered Values.
llvm::DenseMap<const Value *, Value *> valueRemapping;
@ -224,21 +224,21 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
BasicBlock *loopBodyFirstBlock = builder.createBlock();
// At the loop insertion location, branch immediately to the loop init block.
builder.setInsertionPoint(loopInsertionPoint);
builder.setInsertionPointToEnd(loopInsertionPoint);
builder.create<BranchOp>(builder.getUnknownLoc(), loopInitBlock);
// The loop condition block has an argument for loop induction variable.
// Create it upfront and make the loop induction variable -> basic block
// argument remapping available to the following instructions. ForStatement
// is-a Value corresponding to the loop induction variable.
builder.setInsertionPoint(loopConditionBlock);
builder.setInsertionPointToEnd(loopConditionBlock);
Value *iv = loopConditionBlock->addArgument(builder.getIndexType());
valueRemapping.insert(std::make_pair(forStmt, iv));
// Recursively construct loop body region.
// Walking manually because we need custom logic before and after traversing
// the list of children.
builder.setInsertionPoint(loopBodyFirstBlock);
builder.setInsertionPointToEnd(loopBodyFirstBlock);
visitStmtBlock(forStmt->getBody());
// Builder point is currently at the last block of the loop body. Append the
@ -257,7 +257,7 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
// Create post-loop block here so that it appears after all loop body blocks.
BasicBlock *postLoopBlock = builder.createBlock();
builder.setInsertionPoint(loopInitBlock);
builder.setInsertionPointToEnd(loopInitBlock);
// Compute loop bounds using affine_apply after remapping its operands.
auto remapOperands = [this](const Value *value) -> Value * {
return valueRemapping.lookup(value);
@ -276,7 +276,7 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
builder.create<BranchOp>(builder.getUnknownLoc(), loopConditionBlock,
lowerBound);
builder.setInsertionPoint(loopConditionBlock);
builder.setInsertionPointToEnd(loopConditionBlock);
auto comparisonOp = builder.create<CmpIOp>(
forStmt->getLoc(), CmpIPredicate::SLT, iv, upperBound);
auto comparisonResult = comparisonOp->getResult();
@ -286,7 +286,7 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
// Finally, make sure building can continue by setting the post-loop block
// (end of loop SESE region) as the insertion point.
builder.setInsertionPoint(postLoopBlock);
builder.setInsertionPointToEnd(postLoopBlock);
}
// Convert an "if" statement into a flow of basic blocks.
@ -388,7 +388,7 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) {
}
BasicBlock *thenBlock = builder.createBlock();
BasicBlock *elseBlock = builder.createBlock();
builder.setInsertionPoint(ifInsertionBlock);
builder.setInsertionPointToEnd(ifInsertionBlock);
// Implement short-circuit logic. For each affine expression in the 'if'
// condition, convert it into an affine map and call `affine_apply` to obtain
@ -424,17 +424,17 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) {
nextBlock, /*trueArgs*/ ArrayRef<Value *>(),
elseBlock,
/*falseArgs*/ ArrayRef<Value *>());
builder.setInsertionPoint(nextBlock);
builder.setInsertionPointToEnd(nextBlock);
}
ifConditionExtraBlocks.pop_back();
// Recursively traverse the 'then' block.
builder.setInsertionPoint(thenBlock);
builder.setInsertionPointToEnd(thenBlock);
visitStmtBlock(ifStmt->getThen());
BasicBlock *lastThenBlock = builder.getInsertionBlock();
// Recursively traverse the 'else' block if present.
builder.setInsertionPoint(elseBlock);
builder.setInsertionPointToEnd(elseBlock);
if (ifStmt->hasElse())
visitStmtBlock(ifStmt->getElse());
BasicBlock *lastElseBlock = builder.getInsertionBlock();
@ -443,14 +443,14 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) {
// 'then' and 'else' blocks, branch from end of 'then' and 'else' SESE regions
// to the continuation block.
BasicBlock *continuationBlock = builder.createBlock();
builder.setInsertionPoint(lastThenBlock);
builder.setInsertionPointToEnd(lastThenBlock);
builder.create<BranchOp>(ifStmt->getLoc(), continuationBlock);
builder.setInsertionPoint(lastElseBlock);
builder.setInsertionPointToEnd(lastElseBlock);
builder.create<BranchOp>(ifStmt->getLoc(), continuationBlock);
// Make sure building can continue by setting up the continuation block as the
// insertion point.
builder.setInsertionPoint(continuationBlock);
builder.setInsertionPointToEnd(continuationBlock);
}
// Entry point of the function convertor.

View File

@ -173,14 +173,14 @@ static void getMultiLevelStrides(const MemRefRegion &region,
bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
uint64_t *sizeInBytes) {
// DMAs for read regions are going to be inserted just before the for loop.
MLFuncBuilder prologue(forStmt);
FuncBuilder prologue(forStmt);
// DMAs for write regions are going to be inserted just after the for loop.
MLFuncBuilder epilogue(forStmt->getBlock(),
std::next(StmtBlock::iterator(forStmt)));
MLFuncBuilder *b = region.isWrite() ? &epilogue : &prologue;
FuncBuilder epilogue(forStmt->getBlock(),
std::next(StmtBlock::iterator(forStmt)));
FuncBuilder *b = region.isWrite() ? &epilogue : &prologue;
// Builder to create constants at the top level.
MLFuncBuilder top(forStmt->getFunction());
FuncBuilder top(forStmt->getFunction());
auto loc = forStmt->getLoc();
auto *memref = region.memref;

View File

@ -78,7 +78,7 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForStmt *> origLoops,
assert(!origLoops.empty());
assert(origLoops.size() == tileSizes.size());
MLFuncBuilder b(origLoops[0]);
FuncBuilder b(origLoops[0]);
unsigned width = origLoops.size();
// Bounds for tile space loops.
@ -161,7 +161,7 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
// Add intra-tile (or point) loops.
for (unsigned i = 0; i < width; i++) {
MLFuncBuilder b(topLoop);
FuncBuilder b(topLoop);
// Loop bounds will be set later.
auto *pointLoop = b.createFor(loc, 0, 0);
pointLoop->getBody()->getStatements().splice(
@ -175,7 +175,7 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
// Add tile space loops;
for (unsigned i = width; i < 2 * width; i++) {
MLFuncBuilder b(topLoop);
FuncBuilder b(topLoop);
// Loop bounds will be set later.
auto *tileSpaceLoop = b.createFor(loc, 0, 0);
tileSpaceLoop->getBody()->getStatements().splice(

View File

@ -193,8 +193,8 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
mayBeConstantTripCount.getValue() % unrollJamFactor != 0) {
DenseMap<const Value *, Value *> operandMap;
// Insert the cleanup loop right after 'forStmt'.
MLFuncBuilder builder(forStmt->getBlock(),
std::next(StmtBlock::iterator(forStmt)));
FuncBuilder builder(forStmt->getBlock(),
std::next(StmtBlock::iterator(forStmt)));
auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
cleanupForStmt->setLowerBoundMap(
getCleanupLoopLowerBound(*forStmt, unrollJamFactor, &builder));
@ -214,8 +214,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
for (auto &subBlock : subBlocks) {
// Builder to insert unroll-jammed bodies. Insert right at the end of
// sub-block.
MLFuncBuilder builder(subBlock.first->getBlock(),
std::next(subBlock.second));
FuncBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
// Unroll and jam (appends unrollJamFactor-1 additional copies).
for (unsigned i = 1; i < unrollJamFactor; i++) {

View File

@ -64,7 +64,7 @@ using namespace mlir;
///
/// Prerequisites:
/// `a` and `b` must be of IndexType.
static mlir::Value *add(MLFuncBuilder *b, Location loc, Value *v, Value *w) {
static Value *add(FuncBuilder *b, Location loc, Value *v, Value *w) {
assert(v->getType().isa<IndexType>() && "v must be of IndexType");
assert(w->getType().isa<IndexType>() && "w must be of IndexType");
auto *context = b->getContext();
@ -114,7 +114,7 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer,
// forward them or define a whole rewriting chain based on MLFunctionBuilder
// instead of Builer, the code for it would be duplicate boilerplate. As we
// go towards unifying ML and CFG functions, this separation will disappear.
MLFuncBuilder &b = *rewriter->getBuilder();
FuncBuilder &b = *rewriter->getBuilder();
// 1. First allocate the local buffer in fast memory.
// TODO(ntv): CL memory space.
@ -234,7 +234,7 @@ struct LowerVectorTransfersPass
std::unique_ptr<MLFuncGlobalLoweringState>
makeFuncWiseState(MLFunction *f) const override {
auto state = llvm::make_unique<LowerVectorTransfersState>();
auto builder = MLFuncBuilder(f);
auto builder = FuncBuilder(f);
builder.setInsertionPointToStart(f->getBody());
state->zero = builder.create<ConstantIndexOp>(builder.getUnknownLoc(), 0);
return state;

View File

@ -247,7 +247,7 @@ static SmallVector<unsigned, 8> delinearize(unsigned linearIndex,
}
static OperationStmt *
instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType,
instantiate(FuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType,
DenseMap<const Value *, Value *> *substitutionsMap);
/// Not all Values belong to a program slice scoped within the immediately
@ -265,7 +265,7 @@ static Value *substitute(Value *v, VectorType hwVectorType,
if (it == substitutionsMap->end()) {
auto *opStmt = cast<OperationStmt>(v->getDefiningOperation());
if (opStmt->isa<ConstantOp>()) {
MLFuncBuilder b(opStmt);
FuncBuilder b(opStmt);
auto *inst = instantiate(&b, opStmt, hwVectorType, substitutionsMap);
auto res =
substitutionsMap->insert(std::make_pair(v, inst->getResult(0)));
@ -334,7 +334,7 @@ static Value *substitute(Value *v, VectorType hwVectorType,
/// TODO(ntv): these implementation details should be captured in a
/// vectorization trait at the op level directly.
static SmallVector<mlir::Value *, 8>
reindexAffineIndices(MLFuncBuilder *b, VectorType hwVectorType,
reindexAffineIndices(FuncBuilder *b, VectorType hwVectorType,
ArrayRef<unsigned> hwVectorInstance,
ArrayRef<Value *> memrefIndices) {
auto vectorShape = hwVectorType.getShape();
@ -405,7 +405,7 @@ materializeAttributes(OperationStmt *opStmt, VectorType hwVectorType) {
///
/// If the underlying substitution fails, this fails too and returns nullptr.
static OperationStmt *
instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType,
instantiate(FuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType,
DenseMap<const Value *, Value *> *substitutionsMap) {
assert(!opStmt->isa<VectorTransferReadOp>() &&
"Should call the function specialized for VectorTransferReadOp");
@ -476,8 +476,8 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy *transfer,
/// detailed description of the problem, see the description of
/// reindexAffineIndices.
static OperationStmt *
instantiate(MLFuncBuilder *b, VectorTransferReadOp *read,
VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance,
instantiate(FuncBuilder *b, VectorTransferReadOp *read, VectorType hwVectorType,
ArrayRef<unsigned> hwVectorInstance,
DenseMap<const Value *, Value *> *substitutionsMap) {
SmallVector<Value *, 8> indices =
map(makePtrDynCaster<Value>(), read->getIndices());
@ -496,7 +496,7 @@ instantiate(MLFuncBuilder *b, VectorTransferReadOp *read,
/// detailed description of the problem, see the description of
/// reindexAffineIndices.
static OperationStmt *
instantiate(MLFuncBuilder *b, VectorTransferWriteOp *write,
instantiate(FuncBuilder *b, VectorTransferWriteOp *write,
VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance,
DenseMap<const Value *, Value *> *substitutionsMap) {
SmallVector<Value *, 8> indices =
@ -543,7 +543,7 @@ static bool instantiateMaterialization(Statement *stmt,
return stmt->emitError("NYI path IfStmt");
// Create a builder here for unroll-and-jam effects.
MLFuncBuilder b(stmt);
FuncBuilder b(stmt);
auto *opStmt = cast<OperationStmt>(stmt);
if (auto write = opStmt->dyn_cast<VectorTransferWriteOp>()) {
instantiate(&b, write, state->hwVectorType, state->hwVectorInstance,

View File

@ -81,7 +81,7 @@ static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
/// a replacement cannot be performed.
static bool doubleBuffer(Value *oldMemRef, ForStmt *forStmt) {
auto *forBody = forStmt->getBody();
MLFuncBuilder bInner(forBody, forBody->begin());
FuncBuilder bInner(forBody, forBody->begin());
bInner.setInsertionPoint(forBody, forBody->begin());
// Doubles the shape with a leading dimension extent of 2.
@ -101,7 +101,7 @@ static bool doubleBuffer(Value *oldMemRef, ForStmt *forStmt) {
auto newMemRefType = doubleShape(oldMemRefType);
// Put together alloc operands for the dynamic dimensions of the memref.
MLFuncBuilder bOuter(forStmt);
FuncBuilder bOuter(forStmt);
SmallVector<Value *, 4> allocOperands;
unsigned dynamicDimCount = 0;
for (auto dimSize : oldMemRefType.getShape()) {
@ -353,7 +353,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
LLVM_DEBUG(
// Tagging statements with shifts for debugging purposes.
if (auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
MLFuncBuilder b(opStmt);
FuncBuilder b(opStmt);
opStmt->setAttr(b.getIdentifier("shift"),
b.getI64IntegerAttr(shifts[s - 1]));
});

View File

@ -271,8 +271,8 @@ static void processMLFunction(MLFunction *fn,
OwningRewritePatternList &&patterns) {
class MLFuncRewriter : public WorklistRewriter {
public:
MLFuncRewriter(GreedyPatternRewriteDriver &driver, MLFuncBuilder &builder)
: WorklistRewriter(driver, builder.getContext()), builder(builder) {}
MLFuncRewriter(GreedyPatternRewriteDriver &theDriver, FuncBuilder &builder)
: WorklistRewriter(theDriver, builder.getContext()), builder(builder) {}
// Implement the hook for creating operations, and make sure that newly
// created ops are added to the worklist for processing.
@ -288,13 +288,13 @@ static void processMLFunction(MLFunction *fn,
}
private:
MLFuncBuilder &builder;
FuncBuilder &builder;
};
GreedyPatternRewriteDriver driver(std::move(patterns));
fn->walk([&](OperationStmt *stmt) { driver.addToWorklist(stmt); });
MLFuncBuilder mlBuilder(fn);
FuncBuilder mlBuilder(fn);
MLFuncRewriter rewriter(driver, mlBuilder);
driver.simplifyFunction(fn, rewriter);
}
@ -303,8 +303,8 @@ static void processCFGFunction(CFGFunction *fn,
OwningRewritePatternList &&patterns) {
class CFGFuncRewriter : public WorklistRewriter {
public:
CFGFuncRewriter(GreedyPatternRewriteDriver &driver, CFGFuncBuilder &builder)
: WorklistRewriter(driver, builder.getContext()), builder(builder) {}
CFGFuncRewriter(GreedyPatternRewriteDriver &theDriver, FuncBuilder &builder)
: WorklistRewriter(theDriver, builder.getContext()), builder(builder) {}
// Implement the hook for creating operations, and make sure that newly
// created ops are added to the worklist for processing.
@ -320,7 +320,7 @@ static void processCFGFunction(CFGFunction *fn,
}
private:
CFGFuncBuilder &builder;
FuncBuilder &builder;
};
GreedyPatternRewriteDriver driver(std::move(patterns));
@ -329,7 +329,7 @@ static void processCFGFunction(CFGFunction *fn,
if (auto *opInst = dyn_cast<OperationStmt>(&op))
driver.addToWorklist(opInst);
CFGFuncBuilder cfgBuilder(fn);
FuncBuilder cfgBuilder(fn);
CFGFuncRewriter rewriter(driver, cfgBuilder);
driver.simplifyFunction(fn, rewriter);
}

View File

@ -40,7 +40,7 @@ using namespace mlir;
/// the trip count can't be expressed as an affine expression.
AffineMap mlir::getUnrolledLoopUpperBound(const ForStmt &forStmt,
unsigned unrollFactor,
MLFuncBuilder *builder) {
FuncBuilder *builder) {
auto lbMap = forStmt.getLowerBoundMap();
// Single result lower bound map only.
@ -66,7 +66,7 @@ AffineMap mlir::getUnrolledLoopUpperBound(const ForStmt &forStmt,
/// when the trip count can't be expressed as an affine expression.
AffineMap mlir::getCleanupLoopLowerBound(const ForStmt &forStmt,
unsigned unrollFactor,
MLFuncBuilder *builder) {
FuncBuilder *builder) {
auto lbMap = forStmt.getLowerBoundMap();
// Single result lower bound map only.
@ -101,14 +101,14 @@ bool mlir::promoteIfSingleIteration(ForStmt *forStmt) {
if (!forStmt->use_empty()) {
if (forStmt->hasConstantLowerBound()) {
auto *mlFunc = forStmt->getFunction();
MLFuncBuilder topBuilder(&mlFunc->getBody()->front());
FuncBuilder topBuilder(&mlFunc->getBody()->front());
auto constOp = topBuilder.create<ConstantIndexOp>(
forStmt->getLoc(), forStmt->getConstantLowerBound());
forStmt->replaceAllUsesWith(constOp);
} else {
const AffineBound lb = forStmt->getLowerBound();
SmallVector<Value *, 4> lbOperands(lb.operand_begin(), lb.operand_end());
MLFuncBuilder builder(forStmt->getBlock(), StmtBlock::iterator(forStmt));
FuncBuilder builder(forStmt->getBlock(), StmtBlock::iterator(forStmt));
auto affineApplyOp = builder.create<AffineApplyOp>(
forStmt->getLoc(), lb.getMap(), lbOperands);
forStmt->replaceAllUsesWith(affineApplyOp->getResult(0));
@ -146,7 +146,7 @@ static ForStmt *
generateLoop(AffineMap lbMap, AffineMap ubMap,
const std::vector<std::pair<uint64_t, ArrayRef<Statement *>>>
&stmtGroupQueue,
unsigned offset, ForStmt *srcForStmt, MLFuncBuilder *b) {
unsigned offset, ForStmt *srcForStmt, FuncBuilder *b) {
SmallVector<Value *, 4> lbOperands(srcForStmt->getLowerBoundOperands());
SmallVector<Value *, 4> ubOperands(srcForStmt->getUpperBoundOperands());
@ -167,7 +167,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
// Generate the remapping if the shift is not zero: remappedIV = newIV -
// shift.
if (!srcForStmt->use_empty() && shift != 0) {
auto b = MLFuncBuilder::getForStmtBodyBuilder(loopChunk);
auto b = FuncBuilder::getForStmtBodyBuilder(loopChunk);
auto *ivRemap = b.create<AffineApplyOp>(
srcForStmt->getLoc(),
b.getSingleDimShiftAffineMap(-static_cast<int64_t>(
@ -261,7 +261,7 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
auto origLbMap = forStmt->getLowerBoundMap();
uint64_t lbShift = 0;
MLFuncBuilder b(forStmt);
FuncBuilder b(forStmt);
for (uint64_t d = 0, e = sortedStmtGroups.size(); d < e; ++d) {
// If nothing is shifted by d, continue.
if (sortedStmtGroups[d].empty())
@ -379,7 +379,7 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
// Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
if (getLargestDivisorOfTripCount(*forStmt) % unrollFactor != 0) {
DenseMap<const Value *, Value *> operandMap;
MLFuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt));
FuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt));
auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
auto clLbMap = getCleanupLoopLowerBound(*forStmt, unrollFactor, &builder);
assert(clLbMap &&
@ -404,7 +404,7 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
// Builder to insert unrolled bodies right after the last statement in the
// body of 'forStmt'.
MLFuncBuilder builder(forStmt->getBody(), forStmt->getBody()->end());
FuncBuilder builder(forStmt->getBody(), forStmt->getBody()->end());
// Keep a pointer to the last statement in the original block so that we know
// what to clone (since we are doing this in-place).

View File

@ -124,7 +124,7 @@ bool mlir::expandAffineApply(AffineApplyOp *op) {
if (!op)
return true;
FuncBuilder builder(op->getOperation());
FuncBuilder builder(cast<OperationStmt>(op->getOperation()));
auto affineMap = op->getAffineMap();
for (auto numberedExpr : llvm::enumerate(affineMap.getResults())) {
Value *expanded = expandAffineExpr(&builder, numberedExpr.value(), op);

View File

@ -842,7 +842,7 @@ static bool vectorizeRootOrTerminal(Value *iv, LoadOrStoreOpPointer memoryOp,
makePermutationMap(opStmt, state->strategy->loopToVectorDim);
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
LLVM_DEBUG(permutationMap.print(dbgs()));
MLFuncBuilder b(opStmt);
FuncBuilder b(opStmt);
auto transfer = b.create<VectorTransferReadOp>(
opStmt->getLoc(), vectorType, memoryOp->getMemRef(),
map(makePtrDynCaster<Value>(), memoryOp->getIndices()), permutationMap);
@ -970,7 +970,7 @@ static Value *vectorizeConstant(Statement *stmt, const ConstantOp &constant,
!VectorType::isValidElementType(constant.getType())) {
return nullptr;
}
MLFuncBuilder b(stmt);
FuncBuilder b(stmt);
Location loc = stmt->getLoc();
auto vectorType = type.cast<VectorType>();
auto attr = SplatElementsAttr::get(vectorType, constant.getValue());
@ -1068,7 +1068,7 @@ static Value *vectorizeOperand(Value *operand, Statement *stmt,
/// TODO(ntv): consider adding a trait to Op to describe how it gets vectorized.
/// Maybe some Ops are not vectorizable or require some tricky logic, we cannot
/// do one-off logic here; ideally it would be TableGen'd.
static OperationStmt *vectorizeOneOperationStmt(MLFuncBuilder *b,
static OperationStmt *vectorizeOneOperationStmt(FuncBuilder *b,
OperationStmt *opStmt,
VectorizationState *state) {
// Sanity checks.
@ -1084,7 +1084,7 @@ static OperationStmt *vectorizeOneOperationStmt(MLFuncBuilder *b,
auto *value = store->getValueToStore();
auto *vectorValue = vectorizeOperand(value, opStmt, state);
auto indices = map(makePtrDynCaster<Value>(), store->getIndices());
MLFuncBuilder b(opStmt);
FuncBuilder b(opStmt);
auto permutationMap =
makePermutationMap(opStmt, state->strategy->loopToVectorDim);
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
@ -1159,7 +1159,7 @@ static bool vectorizeOperations(VectorizationState *state) {
// 2. Create vectorized form of the statement.
// Insert it just before stmt, on success register stmt as replaced.
MLFuncBuilder b(stmt);
FuncBuilder b(stmt);
auto *vectorizedStmt = vectorizeOneOperationStmt(&b, stmt, state);
if (!vectorizedStmt) {
return true;
@ -1200,7 +1200,7 @@ static bool vectorizeRootMatches(MLFunctionMatches matches,
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ loop is not vectorizable");
continue;
}
MLFuncBuilder builder(loop); // builder to insert in place of loop
FuncBuilder builder(loop); // builder to insert in place of loop
DenseMap<const Value *, Value *> nomap;
ForStmt *clonedLoop = cast<ForStmt>(builder.clone(*loop, nomap));
auto fail = doVectorize(m, &state);
@ -1244,7 +1244,7 @@ static bool vectorizeRootMatches(MLFunctionMatches matches,
if (fail) {
return;
}
MLFuncBuilder b(stmt);
FuncBuilder b(stmt);
auto *res = vectorizeOneOperationStmt(&b, stmt, &state);
if (res == nullptr) {
fail = true;