forked from OSchip/llvm-project
Merge CFGFuncBuilder/MLFuncBuilder/FuncBuilder together into a single new
FuncBuilder class. Also rename SSAValue.cpp to Value.cpp This is step 12/n towards merging instructions and statements, NFC. PiperOrigin-RevId: 227067644
This commit is contained in:
parent
3f190312f8
commit
4c05f8cac6
|
@ -159,161 +159,33 @@ protected:
|
|||
MLIRContext *context;
|
||||
};
|
||||
|
||||
/// This class helps build a CFGFunction. Instructions that are created are
|
||||
/// automatically inserted at an insertion point or added to the current basic
|
||||
/// block.
|
||||
class CFGFuncBuilder : public Builder {
|
||||
/// This class helps build a Function. Instructions that are created are
|
||||
/// automatically inserted at an insertion point. The builder is copyable.
|
||||
class FuncBuilder : public Builder {
|
||||
public:
|
||||
CFGFuncBuilder(BasicBlock *block, BasicBlock::iterator insertPoint)
|
||||
: Builder(block->getFunction()->getContext()),
|
||||
function(block->getFunction()) {
|
||||
setInsertionPoint(block, insertPoint);
|
||||
/// Create an ML function builder and set the insertion point to the start of
|
||||
/// the function.
|
||||
FuncBuilder(Function *func) : Builder(func->getContext()), function(func) {
|
||||
setInsertionPoint(&func->front(), func->front().begin());
|
||||
}
|
||||
|
||||
CFGFuncBuilder(Instruction *insertBefore)
|
||||
: CFGFuncBuilder(insertBefore->getBlock(),
|
||||
BasicBlock::iterator(insertBefore)) {}
|
||||
|
||||
CFGFuncBuilder(BasicBlock *block)
|
||||
: Builder(block->getFunction()->getContext()),
|
||||
function(block->getFunction()) {
|
||||
setInsertionPoint(block);
|
||||
}
|
||||
|
||||
CFGFuncBuilder(CFGFunction *function)
|
||||
: Builder(function->getContext()), function(function) {}
|
||||
|
||||
/// Return the function this builder is referring to.
|
||||
CFGFunction *getFunction() const { return function; }
|
||||
|
||||
/// Reset the insertion point to no location. Creating an operation without a
|
||||
/// set insertion point is an error, but this can still be useful when the
|
||||
/// current insertion point a builder refers to is being removed.
|
||||
void clearInsertionPoint() {
|
||||
this->block = nullptr;
|
||||
insertPoint = BasicBlock::iterator();
|
||||
}
|
||||
|
||||
/// Return the block the current insertion point belongs to. Note that the
|
||||
/// the insertion point is not necessarily the end of the block.
|
||||
BasicBlock *getInsertionBlock() const { return block; }
|
||||
|
||||
/// Return the insert position as the BasicBlock iterator. The block itself
|
||||
/// can be obtained by calling getInsertionBlock.
|
||||
BasicBlock::iterator getInsertionPoint() const { return insertPoint; }
|
||||
|
||||
/// Set the insertion point to the specified location.
|
||||
void setInsertionPoint(BasicBlock *block, BasicBlock::iterator insertPoint) {
|
||||
assert(block->getFunction() == function &&
|
||||
"can't move to a different function");
|
||||
this->block = block;
|
||||
this->insertPoint = insertPoint;
|
||||
}
|
||||
|
||||
/// Set the insertion point to the specified operation.
|
||||
void setInsertionPoint(Instruction *inst) {
|
||||
setInsertionPoint(inst->getBlock(), BasicBlock::iterator(inst));
|
||||
}
|
||||
|
||||
/// Set the insertion point to the end of the specified block.
|
||||
void setInsertionPoint(BasicBlock *block) {
|
||||
setInsertionPoint(block, block->end());
|
||||
}
|
||||
|
||||
void insert(Instruction *opInst) {
|
||||
block->getStatements().insert(insertPoint, opInst);
|
||||
}
|
||||
|
||||
/// Add new basic block and set the insertion point to the end of it. If an
|
||||
/// 'insertBefore' basic block is passed, the block will be placed before the
|
||||
/// specified block. If not, the block will be appended to the end of the
|
||||
/// current function.
|
||||
BasicBlock *createBlock(BasicBlock *insertBefore = nullptr);
|
||||
|
||||
/// Create an operation given the fields represented as an OperationState.
|
||||
OperationStmt *createOperation(const OperationState &state);
|
||||
|
||||
/// Create operation of specific op type at the current insertion point
|
||||
/// without verifying to see if it is valid.
|
||||
template <typename OpTy, typename... Args>
|
||||
OpPointer<OpTy> create(Location location, Args... args) {
|
||||
OperationState state(getContext(), location, OpTy::getOperationName());
|
||||
OpTy::build(this, &state, args...);
|
||||
auto *inst = createOperation(state);
|
||||
auto result = inst->dyn_cast<OpTy>();
|
||||
assert(result && "Builder didn't return the right type");
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Creates an operation of specific op type at the current insertion point.
|
||||
/// If the result is an invalid op (the verifier hook fails), emit an error
|
||||
/// and return null.
|
||||
template <typename OpTy, typename... Args>
|
||||
OpPointer<OpTy> createChecked(Location location, Args... args) {
|
||||
OperationState state(getContext(), location, OpTy::getOperationName());
|
||||
OpTy::build(this, &state, args...);
|
||||
auto *inst = createOperation(state);
|
||||
|
||||
// If the Instruction we produce is valid, return it.
|
||||
if (!OpTy::verifyInvariants(inst)) {
|
||||
auto result = inst->dyn_cast<OpTy>();
|
||||
assert(result && "Builder didn't return the right type");
|
||||
return result;
|
||||
}
|
||||
|
||||
// Otherwise, the error message got emitted. Just remove the instruction
|
||||
// we made.
|
||||
inst->erase();
|
||||
return OpPointer<OpTy>();
|
||||
}
|
||||
|
||||
OperationStmt *cloneOperation(const OperationStmt &srcOpInst) {
|
||||
auto *op = cast<OperationStmt>(srcOpInst.clone(getContext()));
|
||||
insert(op);
|
||||
return op;
|
||||
}
|
||||
|
||||
private:
|
||||
CFGFunction *function;
|
||||
BasicBlock *block = nullptr;
|
||||
BasicBlock::iterator insertPoint;
|
||||
};
|
||||
|
||||
/// This class helps build an MLFunction. Statements that are created are
|
||||
/// automatically inserted at an insertion point or added to the current
|
||||
/// statement block. The builder has only two member variables and can be passed
|
||||
/// around by value.
|
||||
class MLFuncBuilder : public Builder {
|
||||
public:
|
||||
/// Create ML function builder and set insertion point to the given statement,
|
||||
/// Create a function builder and set insertion point to the given statement,
|
||||
/// which will cause subsequent insertions to go right before it.
|
||||
MLFuncBuilder(Statement *stmt)
|
||||
// TODO: Eliminate getFunction from this.
|
||||
: MLFuncBuilder(stmt->getFunction()) {
|
||||
FuncBuilder(Statement *stmt) : FuncBuilder(stmt->getFunction()) {
|
||||
setInsertionPoint(stmt);
|
||||
}
|
||||
|
||||
MLFuncBuilder(StmtBlock *block)
|
||||
// TODO: Eliminate getFunction from this.
|
||||
: MLFuncBuilder(block->getFunction()) {
|
||||
FuncBuilder(StmtBlock *block) : FuncBuilder(block->getFunction()) {
|
||||
setInsertionPoint(block, block->end());
|
||||
}
|
||||
|
||||
MLFuncBuilder(StmtBlock *block, StmtBlock::iterator insertPoint)
|
||||
// TODO: Eliminate getFunction from this.
|
||||
: MLFuncBuilder(block->getFunction()) {
|
||||
FuncBuilder(StmtBlock *block, StmtBlock::iterator insertPoint)
|
||||
: FuncBuilder(block->getFunction()) {
|
||||
setInsertionPoint(block, insertPoint);
|
||||
}
|
||||
|
||||
/// Create an ML function builder and set the insertion point to the start of
|
||||
/// the function.
|
||||
MLFuncBuilder(MLFunction *func)
|
||||
: Builder(func->getContext()), function(func) {
|
||||
setInsertionPoint(func->getBody(), func->getBody()->begin());
|
||||
}
|
||||
|
||||
/// Return the function this builder is referring to.
|
||||
MLFunction *getFunction() const { return function; }
|
||||
Function *getFunction() const { return function; }
|
||||
|
||||
/// Reset the insertion point to no location. Creating an operation without a
|
||||
/// set insertion point is an error, but this can still be useful when the
|
||||
|
@ -324,8 +196,6 @@ public:
|
|||
}
|
||||
|
||||
/// Set the insertion point to the specified location.
|
||||
/// Unlike CFGFuncBuilder, MLFuncBuilder allows to set insertion
|
||||
/// point to a different function.
|
||||
void setInsertionPoint(StmtBlock *block, StmtBlock::iterator insertPoint) {
|
||||
// TODO: check that insertPoint is in this rather than some other block.
|
||||
this->block = block;
|
||||
|
@ -348,14 +218,24 @@ public:
|
|||
setInsertionPoint(block, block->end());
|
||||
}
|
||||
|
||||
/// Returns a builder for the body of a for Stmt.
|
||||
static MLFuncBuilder getForStmtBodyBuilder(ForStmt *forStmt) {
|
||||
return MLFuncBuilder(forStmt->getBody(), forStmt->getBody()->end());
|
||||
}
|
||||
/// Return the block the current insertion point belongs to. Note that the
|
||||
/// the insertion point is not necessarily the end of the block.
|
||||
BasicBlock *getInsertionBlock() const { return block; }
|
||||
|
||||
/// Returns the current insertion point of the builder.
|
||||
StmtBlock::iterator getInsertionPoint() const { return insertPoint; }
|
||||
|
||||
/// Add new block and set the insertion point to the end of it. If an
|
||||
/// 'insertBefore' block is passed, the block will be placed before the
|
||||
/// specified block. If not, the block will be appended to the end of the
|
||||
/// current function.
|
||||
StmtBlock *createBlock(StmtBlock *insertBefore = nullptr);
|
||||
|
||||
/// Returns a builder for the body of a for Stmt.
|
||||
static FuncBuilder getForStmtBodyBuilder(ForStmt *forStmt) {
|
||||
return FuncBuilder(forStmt->getBody(), forStmt->getBody()->end());
|
||||
}
|
||||
|
||||
/// Returns the current block of the builder.
|
||||
StmtBlock *getBlock() const { return block; }
|
||||
|
||||
|
@ -421,84 +301,11 @@ public:
|
|||
IntegerSet set);
|
||||
|
||||
private:
|
||||
MLFunction *function;
|
||||
Function *function;
|
||||
StmtBlock *block = nullptr;
|
||||
StmtBlock::iterator insertPoint;
|
||||
};
|
||||
|
||||
// Wrapper around common CFGFuncBuilder and MLFuncBuilder functionality. Use
|
||||
// this wrapper for interfaces where operations need to be created in either a
|
||||
// CFG function or ML function.
|
||||
class FuncBuilder : public Builder {
|
||||
public:
|
||||
FuncBuilder(CFGFuncBuilder &cfgFuncBuilder)
|
||||
: Builder(cfgFuncBuilder.getContext()), builder(cfgFuncBuilder),
|
||||
kind(Function::Kind::CFGFunc) {}
|
||||
FuncBuilder(MLFuncBuilder &mlFuncBuilder)
|
||||
: Builder(mlFuncBuilder.getContext()), builder(mlFuncBuilder),
|
||||
kind(Function::Kind::MLFunc) {}
|
||||
FuncBuilder(Operation *op) : Builder(op->getContext()) {
|
||||
if (op->getOperationFunction()->isCFG()) {
|
||||
builder = builderUnion(CFGFuncBuilder(cast<OperationInst>(op)));
|
||||
kind = Function::Kind::CFGFunc;
|
||||
} else {
|
||||
builder = builderUnion(MLFuncBuilder(cast<OperationStmt>(op)));
|
||||
kind = Function::Kind::MLFunc;
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates an operation given the fields represented as an OperationState.
|
||||
Operation *createOperation(const OperationState &state) {
|
||||
if (kind == Function::Kind::CFGFunc)
|
||||
return builder.cfg.createOperation(state);
|
||||
return builder.ml.createOperation(state);
|
||||
}
|
||||
|
||||
/// Creates operation of specific op type at the current insertion point
|
||||
/// without verifying to see if it is valid.
|
||||
template <typename OpTy, typename... Args>
|
||||
OpPointer<OpTy> create(Location location, Args... args) {
|
||||
if (kind == Function::Kind::CFGFunc)
|
||||
return builder.cfg.create<OpTy, Args...>(location, args...);
|
||||
return builder.ml.create<OpTy, Args...>(location, args...);
|
||||
}
|
||||
|
||||
/// Creates an operation of specific op type at the current insertion point.
|
||||
/// If the result is an invalid op (the verifier hook fails), emit an error
|
||||
/// and return null.
|
||||
template <typename OpTy, typename... Args>
|
||||
OpPointer<OpTy> createChecked(Location location, Args... args) {
|
||||
if (kind == Function::Kind::CFGFunc)
|
||||
return builder.cfg.createChecked<OpTy, Args...>(location, args...);
|
||||
return builder.ml.createChecked<OpTy, Args...>(location, args...);
|
||||
}
|
||||
|
||||
/// Set the insertion point to the specified operation. This requires that the
|
||||
/// input operation is a Instruction when building a CFG function and a
|
||||
/// OperationStmt when building a ML function.
|
||||
void setInsertionPoint(Operation *op) {
|
||||
if (kind == Function::Kind::CFGFunc)
|
||||
builder.cfg.setInsertionPoint(cast<OperationStmt>(op));
|
||||
else
|
||||
builder.ml.setInsertionPoint(cast<OperationStmt>(op));
|
||||
}
|
||||
|
||||
private:
|
||||
// Wrapped builders for CFG and ML functions.
|
||||
union builderUnion {
|
||||
builderUnion(CFGFuncBuilder cfg) : cfg(cfg) {}
|
||||
builderUnion(MLFuncBuilder ml) : ml(ml) {}
|
||||
// Default initializer to allow deferring initialization of member.
|
||||
builderUnion() {}
|
||||
|
||||
CFGFuncBuilder cfg;
|
||||
MLFuncBuilder ml;
|
||||
} builder;
|
||||
|
||||
// The type of builder in the builderUnion.
|
||||
Function::Kind kind;
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
|
|
|
@ -36,10 +36,10 @@ class Value;
|
|||
using Instruction = Statement;
|
||||
using OperationInst = OperationStmt;
|
||||
|
||||
/// The operand of ML function statement contains a Value.
|
||||
/// Operands contain a Value.
|
||||
using StmtOperand = IROperandImpl<Value, Statement>;
|
||||
|
||||
/// This is the common base class for all values in the MLIR system,
|
||||
/// This is the common base class for all SSA values in the MLIR system,
|
||||
/// representing a computable value that has a type and a set of users.
|
||||
///
|
||||
class Value : public IRObjectWithUseList {
|
||||
|
@ -48,7 +48,7 @@ public:
|
|||
enum class Kind {
|
||||
BlockArgument, // block argument
|
||||
StmtResult, // statement result
|
||||
ForStmt, // for statement induction variable
|
||||
ForStmt, // 'for' statement induction variable
|
||||
};
|
||||
|
||||
~Value() {}
|
||||
|
|
|
@ -32,7 +32,7 @@ class AffineMap;
|
|||
class ForStmt;
|
||||
class Function;
|
||||
using MLFunction = Function;
|
||||
class MLFuncBuilder;
|
||||
class FuncBuilder;
|
||||
|
||||
// Values that can be used to signal success/failure. This can be implicitly
|
||||
// converted to/from boolean values, with false representing success and true
|
||||
|
@ -73,14 +73,13 @@ void promoteSingleIterationLoops(MLFunction *f);
|
|||
/// Returns the lower bound of the cleanup loop when unrolling a loop
|
||||
/// with the specified unroll factor.
|
||||
AffineMap getCleanupLoopLowerBound(const ForStmt &forStmt,
|
||||
unsigned unrollFactor,
|
||||
MLFuncBuilder *builder);
|
||||
unsigned unrollFactor, FuncBuilder *builder);
|
||||
|
||||
/// Returns the upper bound of an unrolled loop when unrolling with
|
||||
/// the specified trip count, stride, and unroll factor.
|
||||
AffineMap getUnrolledLoopUpperBound(const ForStmt &forStmt,
|
||||
unsigned unrollFactor,
|
||||
MLFuncBuilder *builder);
|
||||
FuncBuilder *builder);
|
||||
|
||||
/// Skew the statements in the body of a 'for' statement with the specified
|
||||
/// statement-wise shifts. The shifts are with respect to the original execution
|
||||
|
|
|
@ -32,10 +32,10 @@ namespace mlir {
|
|||
/// Specialization of the pattern rewriter to ML functions.
|
||||
class MLFuncLoweringRewriter : public PatternRewriter {
|
||||
public:
|
||||
explicit MLFuncLoweringRewriter(MLFuncBuilder *builder)
|
||||
explicit MLFuncLoweringRewriter(FuncBuilder *builder)
|
||||
: PatternRewriter(builder->getContext()), builder(builder) {}
|
||||
|
||||
MLFuncBuilder *getBuilder() { return builder; }
|
||||
FuncBuilder *getBuilder() { return builder; }
|
||||
|
||||
Operation *createOperation(const OperationState &state) override {
|
||||
auto *result = builder->createOperation(state);
|
||||
|
@ -43,7 +43,7 @@ public:
|
|||
}
|
||||
|
||||
private:
|
||||
MLFuncBuilder *builder;
|
||||
FuncBuilder *builder;
|
||||
};
|
||||
|
||||
/// Base class for the MLFunction-wise lowering state. A pointer to the same
|
||||
|
@ -140,7 +140,7 @@ PassResult MLPatternLoweringPass<Patterns...>::runOnMLFunction(MLFunction *f) {
|
|||
detail::ListAdder<Patterns...>::addPatternsToList(&patterns, f->getContext());
|
||||
auto funcWiseState = makeFuncWiseState(f);
|
||||
|
||||
MLFuncBuilder builder(f);
|
||||
FuncBuilder builder(f);
|
||||
MLFuncLoweringRewriter rewriter(&builder);
|
||||
|
||||
llvm::SmallVector<OperationStmt *, 0> ops;
|
||||
|
|
|
@ -173,7 +173,7 @@ bool mlir::getMemRefRegion(OperationStmt *opStmt, unsigned loopDepth,
|
|||
// Build the constraints for this region.
|
||||
FlatAffineConstraints *regionCst = region->getConstraints();
|
||||
|
||||
MLFuncBuilder b(opStmt);
|
||||
FuncBuilder b(opStmt);
|
||||
auto idMap = b.getMultiDimIdentityMap(rank);
|
||||
|
||||
// Initialize 'accessValueMap' and compose with reachable AffineApplyOps.
|
||||
|
@ -453,7 +453,7 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
|
|||
// Clone src loop nest and insert it a the beginning of the statement block
|
||||
// of the loop at 'dstLoopDepth' in 'dstLoopNest'.
|
||||
auto *dstForStmt = dstLoopNest[dstLoopDepth - 1];
|
||||
MLFuncBuilder b(dstForStmt->getBody(), dstForStmt->getBody()->begin());
|
||||
FuncBuilder b(dstForStmt->getBody(), dstForStmt->getBody()->begin());
|
||||
DenseMap<const Value *, Value *> operandMap;
|
||||
auto *sliceLoopNest = cast<ForStmt>(b.clone(*srcLoopNest[0], operandMap));
|
||||
|
||||
|
|
|
@ -268,15 +268,15 @@ AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CFG function elements.
|
||||
// Statements.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Add new basic block and set the insertion point to the end of it. If an
|
||||
/// 'insertBefore' basic block is passed, the block will be placed before the
|
||||
/// specified block. If not, the block will be appended to the end of the
|
||||
/// current function.
|
||||
BasicBlock *CFGFuncBuilder::createBlock(BasicBlock *insertBefore) {
|
||||
BasicBlock *b = new BasicBlock();
|
||||
StmtBlock *FuncBuilder::createBlock(StmtBlock *insertBefore) {
|
||||
StmtBlock *b = new StmtBlock();
|
||||
|
||||
// If we are supposed to insert before a specific block, do so, otherwise add
|
||||
// the block to the end of the function.
|
||||
|
@ -285,12 +285,12 @@ BasicBlock *CFGFuncBuilder::createBlock(BasicBlock *insertBefore) {
|
|||
else
|
||||
function->push_back(b);
|
||||
|
||||
setInsertionPoint(b);
|
||||
setInsertionPointToEnd(b);
|
||||
return b;
|
||||
}
|
||||
|
||||
/// Create an operation given the fields represented as an OperationState.
|
||||
OperationStmt *CFGFuncBuilder::createOperation(const OperationState &state) {
|
||||
OperationStmt *FuncBuilder::createOperation(const OperationState &state) {
|
||||
auto *op = OperationInst::create(state.location, state.name, state.operands,
|
||||
state.types, state.attributes,
|
||||
state.successors, context);
|
||||
|
@ -298,38 +298,24 @@ OperationStmt *CFGFuncBuilder::createOperation(const OperationState &state) {
|
|||
return op;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Statements.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Create an operation given the fields represented as an OperationState.
|
||||
OperationStmt *MLFuncBuilder::createOperation(const OperationState &state) {
|
||||
auto *op = OperationStmt::create(state.location, state.name, state.operands,
|
||||
state.types, state.attributes,
|
||||
state.successors, context);
|
||||
block->getStatements().insert(insertPoint, op);
|
||||
return op;
|
||||
}
|
||||
|
||||
ForStmt *MLFuncBuilder::createFor(Location location,
|
||||
ArrayRef<Value *> lbOperands, AffineMap lbMap,
|
||||
ArrayRef<Value *> ubOperands, AffineMap ubMap,
|
||||
int64_t step) {
|
||||
ForStmt *FuncBuilder::createFor(Location location, ArrayRef<Value *> lbOperands,
|
||||
AffineMap lbMap, ArrayRef<Value *> ubOperands,
|
||||
AffineMap ubMap, int64_t step) {
|
||||
auto *stmt =
|
||||
ForStmt::create(location, lbOperands, lbMap, ubOperands, ubMap, step);
|
||||
block->getStatements().insert(insertPoint, stmt);
|
||||
return stmt;
|
||||
}
|
||||
|
||||
ForStmt *MLFuncBuilder::createFor(Location location, int64_t lb, int64_t ub,
|
||||
int64_t step) {
|
||||
ForStmt *FuncBuilder::createFor(Location location, int64_t lb, int64_t ub,
|
||||
int64_t step) {
|
||||
auto lbMap = AffineMap::getConstantMap(lb, context);
|
||||
auto ubMap = AffineMap::getConstantMap(ub, context);
|
||||
return createFor(location, {}, lbMap, {}, ubMap, step);
|
||||
}
|
||||
|
||||
IfStmt *MLFuncBuilder::createIf(Location location, ArrayRef<Value *> operands,
|
||||
IntegerSet set) {
|
||||
IfStmt *FuncBuilder::createIf(Location location, ArrayRef<Value *> operands,
|
||||
IntegerSet set) {
|
||||
auto *stmt = IfStmt::create(location, operands, set);
|
||||
block->getStatements().insert(insertPoint, stmt);
|
||||
return stmt;
|
||||
|
|
|
@ -174,7 +174,7 @@ BasicBlock *BasicBlock::splitBasicBlock(iterator splitBefore) {
|
|||
|
||||
// Create an unconditional branch to the new block, and move our terminator
|
||||
// to the new block.
|
||||
CFGFuncBuilder(this).create<BranchOp>(branchLoc, newBB);
|
||||
FuncBuilder(this).create<BranchOp>(branchLoc, newBB);
|
||||
return newBB;
|
||||
}
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//===- SSAValue.cpp - MLIR ValueClasses ------------===//
|
||||
//===- Value.cpp - MLIR Value Classes -------------------------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
|
@ -15,10 +15,9 @@
|
|||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Statements.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
/// If this value is the result of an Instruction, return the instruction
|
|
@ -2578,7 +2578,7 @@ private:
|
|||
|
||||
/// This builder intentionally shadows the builder in the base class, with a
|
||||
/// more specific builder type.
|
||||
CFGFuncBuilder builder;
|
||||
FuncBuilder builder;
|
||||
|
||||
/// Get the basic block with the specified name, creating it if it doesn't
|
||||
/// already exist. The location specified is the point of use, which allows
|
||||
|
@ -2744,7 +2744,7 @@ ParseResult CFGFunctionParser::parseBasicBlock() {
|
|||
|
||||
// Set the insertion point to the block we want to insert new operations
|
||||
// into.
|
||||
builder.setInsertionPoint(block);
|
||||
builder.setInsertionPointToEnd(block);
|
||||
|
||||
auto createOpFunc = [&](const OperationState &result) -> Operation * {
|
||||
return builder.createOperation(result);
|
||||
|
@ -2782,7 +2782,7 @@ private:
|
|||
|
||||
/// This builder intentionally shadows the builder in the base class, with a
|
||||
/// more specific builder type.
|
||||
MLFuncBuilder builder;
|
||||
FuncBuilder builder;
|
||||
|
||||
ParseResult parseForStmt();
|
||||
ParseResult parseIntConstant(int64_t &val);
|
||||
|
|
|
@ -104,7 +104,7 @@ bool ConstantFold::foldOperation(Operation *op,
|
|||
// conditional branches, or anything else fancy.
|
||||
PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
|
||||
existingConstants.clear();
|
||||
CFGFuncBuilder builder(f);
|
||||
FuncBuilder builder(f);
|
||||
|
||||
for (auto &bb : *f) {
|
||||
for (auto instIt = bb.begin(), e = bb.end(); instIt != e;) {
|
||||
|
@ -141,7 +141,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
|
|||
// Override the walker's operation statement visit for constant folding.
|
||||
void ConstantFold::visitOperationStmt(OperationStmt *stmt) {
|
||||
auto constantFactory = [&](Attribute value, Type type) -> Value * {
|
||||
MLFuncBuilder builder(stmt);
|
||||
FuncBuilder builder(stmt);
|
||||
return builder.create<ConstantOp>(stmt->getLoc(), value, type);
|
||||
};
|
||||
if (!ConstantFold::foldOperation(stmt, existingConstants, constantFactory)) {
|
||||
|
|
|
@ -57,7 +57,7 @@ private:
|
|||
llvm::iterator_range<Operation::result_iterator> values);
|
||||
|
||||
CFGFunction *cfgFunc;
|
||||
CFGFuncBuilder builder;
|
||||
FuncBuilder builder;
|
||||
|
||||
// Mapping between original Values and lowered Values.
|
||||
llvm::DenseMap<const Value *, Value *> valueRemapping;
|
||||
|
@ -224,21 +224,21 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
|
|||
BasicBlock *loopBodyFirstBlock = builder.createBlock();
|
||||
|
||||
// At the loop insertion location, branch immediately to the loop init block.
|
||||
builder.setInsertionPoint(loopInsertionPoint);
|
||||
builder.setInsertionPointToEnd(loopInsertionPoint);
|
||||
builder.create<BranchOp>(builder.getUnknownLoc(), loopInitBlock);
|
||||
|
||||
// The loop condition block has an argument for loop induction variable.
|
||||
// Create it upfront and make the loop induction variable -> basic block
|
||||
// argument remapping available to the following instructions. ForStatement
|
||||
// is-a Value corresponding to the loop induction variable.
|
||||
builder.setInsertionPoint(loopConditionBlock);
|
||||
builder.setInsertionPointToEnd(loopConditionBlock);
|
||||
Value *iv = loopConditionBlock->addArgument(builder.getIndexType());
|
||||
valueRemapping.insert(std::make_pair(forStmt, iv));
|
||||
|
||||
// Recursively construct loop body region.
|
||||
// Walking manually because we need custom logic before and after traversing
|
||||
// the list of children.
|
||||
builder.setInsertionPoint(loopBodyFirstBlock);
|
||||
builder.setInsertionPointToEnd(loopBodyFirstBlock);
|
||||
visitStmtBlock(forStmt->getBody());
|
||||
|
||||
// Builder point is currently at the last block of the loop body. Append the
|
||||
|
@ -257,7 +257,7 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
|
|||
// Create post-loop block here so that it appears after all loop body blocks.
|
||||
BasicBlock *postLoopBlock = builder.createBlock();
|
||||
|
||||
builder.setInsertionPoint(loopInitBlock);
|
||||
builder.setInsertionPointToEnd(loopInitBlock);
|
||||
// Compute loop bounds using affine_apply after remapping its operands.
|
||||
auto remapOperands = [this](const Value *value) -> Value * {
|
||||
return valueRemapping.lookup(value);
|
||||
|
@ -276,7 +276,7 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
|
|||
builder.create<BranchOp>(builder.getUnknownLoc(), loopConditionBlock,
|
||||
lowerBound);
|
||||
|
||||
builder.setInsertionPoint(loopConditionBlock);
|
||||
builder.setInsertionPointToEnd(loopConditionBlock);
|
||||
auto comparisonOp = builder.create<CmpIOp>(
|
||||
forStmt->getLoc(), CmpIPredicate::SLT, iv, upperBound);
|
||||
auto comparisonResult = comparisonOp->getResult();
|
||||
|
@ -286,7 +286,7 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
|
|||
|
||||
// Finally, make sure building can continue by setting the post-loop block
|
||||
// (end of loop SESE region) as the insertion point.
|
||||
builder.setInsertionPoint(postLoopBlock);
|
||||
builder.setInsertionPointToEnd(postLoopBlock);
|
||||
}
|
||||
|
||||
// Convert an "if" statement into a flow of basic blocks.
|
||||
|
@ -388,7 +388,7 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) {
|
|||
}
|
||||
BasicBlock *thenBlock = builder.createBlock();
|
||||
BasicBlock *elseBlock = builder.createBlock();
|
||||
builder.setInsertionPoint(ifInsertionBlock);
|
||||
builder.setInsertionPointToEnd(ifInsertionBlock);
|
||||
|
||||
// Implement short-circuit logic. For each affine expression in the 'if'
|
||||
// condition, convert it into an affine map and call `affine_apply` to obtain
|
||||
|
@ -424,17 +424,17 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) {
|
|||
nextBlock, /*trueArgs*/ ArrayRef<Value *>(),
|
||||
elseBlock,
|
||||
/*falseArgs*/ ArrayRef<Value *>());
|
||||
builder.setInsertionPoint(nextBlock);
|
||||
builder.setInsertionPointToEnd(nextBlock);
|
||||
}
|
||||
ifConditionExtraBlocks.pop_back();
|
||||
|
||||
// Recursively traverse the 'then' block.
|
||||
builder.setInsertionPoint(thenBlock);
|
||||
builder.setInsertionPointToEnd(thenBlock);
|
||||
visitStmtBlock(ifStmt->getThen());
|
||||
BasicBlock *lastThenBlock = builder.getInsertionBlock();
|
||||
|
||||
// Recursively traverse the 'else' block if present.
|
||||
builder.setInsertionPoint(elseBlock);
|
||||
builder.setInsertionPointToEnd(elseBlock);
|
||||
if (ifStmt->hasElse())
|
||||
visitStmtBlock(ifStmt->getElse());
|
||||
BasicBlock *lastElseBlock = builder.getInsertionBlock();
|
||||
|
@ -443,14 +443,14 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) {
|
|||
// 'then' and 'else' blocks, branch from end of 'then' and 'else' SESE regions
|
||||
// to the continuation block.
|
||||
BasicBlock *continuationBlock = builder.createBlock();
|
||||
builder.setInsertionPoint(lastThenBlock);
|
||||
builder.setInsertionPointToEnd(lastThenBlock);
|
||||
builder.create<BranchOp>(ifStmt->getLoc(), continuationBlock);
|
||||
builder.setInsertionPoint(lastElseBlock);
|
||||
builder.setInsertionPointToEnd(lastElseBlock);
|
||||
builder.create<BranchOp>(ifStmt->getLoc(), continuationBlock);
|
||||
|
||||
// Make sure building can continue by setting up the continuation block as the
|
||||
// insertion point.
|
||||
builder.setInsertionPoint(continuationBlock);
|
||||
builder.setInsertionPointToEnd(continuationBlock);
|
||||
}
|
||||
|
||||
// Entry point of the function convertor.
|
||||
|
|
|
@ -173,14 +173,14 @@ static void getMultiLevelStrides(const MemRefRegion ®ion,
|
|||
bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForStmt *forStmt,
|
||||
uint64_t *sizeInBytes) {
|
||||
// DMAs for read regions are going to be inserted just before the for loop.
|
||||
MLFuncBuilder prologue(forStmt);
|
||||
FuncBuilder prologue(forStmt);
|
||||
// DMAs for write regions are going to be inserted just after the for loop.
|
||||
MLFuncBuilder epilogue(forStmt->getBlock(),
|
||||
std::next(StmtBlock::iterator(forStmt)));
|
||||
MLFuncBuilder *b = region.isWrite() ? &epilogue : &prologue;
|
||||
FuncBuilder epilogue(forStmt->getBlock(),
|
||||
std::next(StmtBlock::iterator(forStmt)));
|
||||
FuncBuilder *b = region.isWrite() ? &epilogue : &prologue;
|
||||
|
||||
// Builder to create constants at the top level.
|
||||
MLFuncBuilder top(forStmt->getFunction());
|
||||
FuncBuilder top(forStmt->getFunction());
|
||||
|
||||
auto loc = forStmt->getLoc();
|
||||
auto *memref = region.memref;
|
||||
|
|
|
@ -78,7 +78,7 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForStmt *> origLoops,
|
|||
assert(!origLoops.empty());
|
||||
assert(origLoops.size() == tileSizes.size());
|
||||
|
||||
MLFuncBuilder b(origLoops[0]);
|
||||
FuncBuilder b(origLoops[0]);
|
||||
unsigned width = origLoops.size();
|
||||
|
||||
// Bounds for tile space loops.
|
||||
|
@ -161,7 +161,7 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
|
|||
|
||||
// Add intra-tile (or point) loops.
|
||||
for (unsigned i = 0; i < width; i++) {
|
||||
MLFuncBuilder b(topLoop);
|
||||
FuncBuilder b(topLoop);
|
||||
// Loop bounds will be set later.
|
||||
auto *pointLoop = b.createFor(loc, 0, 0);
|
||||
pointLoop->getBody()->getStatements().splice(
|
||||
|
@ -175,7 +175,7 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
|
|||
|
||||
// Add tile space loops;
|
||||
for (unsigned i = width; i < 2 * width; i++) {
|
||||
MLFuncBuilder b(topLoop);
|
||||
FuncBuilder b(topLoop);
|
||||
// Loop bounds will be set later.
|
||||
auto *tileSpaceLoop = b.createFor(loc, 0, 0);
|
||||
tileSpaceLoop->getBody()->getStatements().splice(
|
||||
|
|
|
@ -193,8 +193,8 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
|
|||
mayBeConstantTripCount.getValue() % unrollJamFactor != 0) {
|
||||
DenseMap<const Value *, Value *> operandMap;
|
||||
// Insert the cleanup loop right after 'forStmt'.
|
||||
MLFuncBuilder builder(forStmt->getBlock(),
|
||||
std::next(StmtBlock::iterator(forStmt)));
|
||||
FuncBuilder builder(forStmt->getBlock(),
|
||||
std::next(StmtBlock::iterator(forStmt)));
|
||||
auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
|
||||
cleanupForStmt->setLowerBoundMap(
|
||||
getCleanupLoopLowerBound(*forStmt, unrollJamFactor, &builder));
|
||||
|
@ -214,8 +214,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
|
|||
for (auto &subBlock : subBlocks) {
|
||||
// Builder to insert unroll-jammed bodies. Insert right at the end of
|
||||
// sub-block.
|
||||
MLFuncBuilder builder(subBlock.first->getBlock(),
|
||||
std::next(subBlock.second));
|
||||
FuncBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
|
||||
|
||||
// Unroll and jam (appends unrollJamFactor-1 additional copies).
|
||||
for (unsigned i = 1; i < unrollJamFactor; i++) {
|
||||
|
|
|
@ -64,7 +64,7 @@ using namespace mlir;
|
|||
///
|
||||
/// Prerequisites:
|
||||
/// `a` and `b` must be of IndexType.
|
||||
static mlir::Value *add(MLFuncBuilder *b, Location loc, Value *v, Value *w) {
|
||||
static Value *add(FuncBuilder *b, Location loc, Value *v, Value *w) {
|
||||
assert(v->getType().isa<IndexType>() && "v must be of IndexType");
|
||||
assert(w->getType().isa<IndexType>() && "w must be of IndexType");
|
||||
auto *context = b->getContext();
|
||||
|
@ -114,7 +114,7 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer,
|
|||
// forward them or define a whole rewriting chain based on MLFunctionBuilder
|
||||
// instead of Builer, the code for it would be duplicate boilerplate. As we
|
||||
// go towards unifying ML and CFG functions, this separation will disappear.
|
||||
MLFuncBuilder &b = *rewriter->getBuilder();
|
||||
FuncBuilder &b = *rewriter->getBuilder();
|
||||
|
||||
// 1. First allocate the local buffer in fast memory.
|
||||
// TODO(ntv): CL memory space.
|
||||
|
@ -234,7 +234,7 @@ struct LowerVectorTransfersPass
|
|||
std::unique_ptr<MLFuncGlobalLoweringState>
|
||||
makeFuncWiseState(MLFunction *f) const override {
|
||||
auto state = llvm::make_unique<LowerVectorTransfersState>();
|
||||
auto builder = MLFuncBuilder(f);
|
||||
auto builder = FuncBuilder(f);
|
||||
builder.setInsertionPointToStart(f->getBody());
|
||||
state->zero = builder.create<ConstantIndexOp>(builder.getUnknownLoc(), 0);
|
||||
return state;
|
||||
|
|
|
@ -247,7 +247,7 @@ static SmallVector<unsigned, 8> delinearize(unsigned linearIndex,
|
|||
}
|
||||
|
||||
static OperationStmt *
|
||||
instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType,
|
||||
instantiate(FuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType,
|
||||
DenseMap<const Value *, Value *> *substitutionsMap);
|
||||
|
||||
/// Not all Values belong to a program slice scoped within the immediately
|
||||
|
@ -265,7 +265,7 @@ static Value *substitute(Value *v, VectorType hwVectorType,
|
|||
if (it == substitutionsMap->end()) {
|
||||
auto *opStmt = cast<OperationStmt>(v->getDefiningOperation());
|
||||
if (opStmt->isa<ConstantOp>()) {
|
||||
MLFuncBuilder b(opStmt);
|
||||
FuncBuilder b(opStmt);
|
||||
auto *inst = instantiate(&b, opStmt, hwVectorType, substitutionsMap);
|
||||
auto res =
|
||||
substitutionsMap->insert(std::make_pair(v, inst->getResult(0)));
|
||||
|
@ -334,7 +334,7 @@ static Value *substitute(Value *v, VectorType hwVectorType,
|
|||
/// TODO(ntv): these implementation details should be captured in a
|
||||
/// vectorization trait at the op level directly.
|
||||
static SmallVector<mlir::Value *, 8>
|
||||
reindexAffineIndices(MLFuncBuilder *b, VectorType hwVectorType,
|
||||
reindexAffineIndices(FuncBuilder *b, VectorType hwVectorType,
|
||||
ArrayRef<unsigned> hwVectorInstance,
|
||||
ArrayRef<Value *> memrefIndices) {
|
||||
auto vectorShape = hwVectorType.getShape();
|
||||
|
@ -405,7 +405,7 @@ materializeAttributes(OperationStmt *opStmt, VectorType hwVectorType) {
|
|||
///
|
||||
/// If the underlying substitution fails, this fails too and returns nullptr.
|
||||
static OperationStmt *
|
||||
instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType,
|
||||
instantiate(FuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType,
|
||||
DenseMap<const Value *, Value *> *substitutionsMap) {
|
||||
assert(!opStmt->isa<VectorTransferReadOp>() &&
|
||||
"Should call the function specialized for VectorTransferReadOp");
|
||||
|
@ -476,8 +476,8 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy *transfer,
|
|||
/// detailed description of the problem, see the description of
|
||||
/// reindexAffineIndices.
|
||||
static OperationStmt *
|
||||
instantiate(MLFuncBuilder *b, VectorTransferReadOp *read,
|
||||
VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance,
|
||||
instantiate(FuncBuilder *b, VectorTransferReadOp *read, VectorType hwVectorType,
|
||||
ArrayRef<unsigned> hwVectorInstance,
|
||||
DenseMap<const Value *, Value *> *substitutionsMap) {
|
||||
SmallVector<Value *, 8> indices =
|
||||
map(makePtrDynCaster<Value>(), read->getIndices());
|
||||
|
@ -496,7 +496,7 @@ instantiate(MLFuncBuilder *b, VectorTransferReadOp *read,
|
|||
/// detailed description of the problem, see the description of
|
||||
/// reindexAffineIndices.
|
||||
static OperationStmt *
|
||||
instantiate(MLFuncBuilder *b, VectorTransferWriteOp *write,
|
||||
instantiate(FuncBuilder *b, VectorTransferWriteOp *write,
|
||||
VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance,
|
||||
DenseMap<const Value *, Value *> *substitutionsMap) {
|
||||
SmallVector<Value *, 8> indices =
|
||||
|
@ -543,7 +543,7 @@ static bool instantiateMaterialization(Statement *stmt,
|
|||
return stmt->emitError("NYI path IfStmt");
|
||||
|
||||
// Create a builder here for unroll-and-jam effects.
|
||||
MLFuncBuilder b(stmt);
|
||||
FuncBuilder b(stmt);
|
||||
auto *opStmt = cast<OperationStmt>(stmt);
|
||||
if (auto write = opStmt->dyn_cast<VectorTransferWriteOp>()) {
|
||||
instantiate(&b, write, state->hwVectorType, state->hwVectorInstance,
|
||||
|
|
|
@ -81,7 +81,7 @@ static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
|
|||
/// a replacement cannot be performed.
|
||||
static bool doubleBuffer(Value *oldMemRef, ForStmt *forStmt) {
|
||||
auto *forBody = forStmt->getBody();
|
||||
MLFuncBuilder bInner(forBody, forBody->begin());
|
||||
FuncBuilder bInner(forBody, forBody->begin());
|
||||
bInner.setInsertionPoint(forBody, forBody->begin());
|
||||
|
||||
// Doubles the shape with a leading dimension extent of 2.
|
||||
|
@ -101,7 +101,7 @@ static bool doubleBuffer(Value *oldMemRef, ForStmt *forStmt) {
|
|||
auto newMemRefType = doubleShape(oldMemRefType);
|
||||
|
||||
// Put together alloc operands for the dynamic dimensions of the memref.
|
||||
MLFuncBuilder bOuter(forStmt);
|
||||
FuncBuilder bOuter(forStmt);
|
||||
SmallVector<Value *, 4> allocOperands;
|
||||
unsigned dynamicDimCount = 0;
|
||||
for (auto dimSize : oldMemRefType.getShape()) {
|
||||
|
@ -353,7 +353,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
|
|||
LLVM_DEBUG(
|
||||
// Tagging statements with shifts for debugging purposes.
|
||||
if (auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
|
||||
MLFuncBuilder b(opStmt);
|
||||
FuncBuilder b(opStmt);
|
||||
opStmt->setAttr(b.getIdentifier("shift"),
|
||||
b.getI64IntegerAttr(shifts[s - 1]));
|
||||
});
|
||||
|
|
|
@ -271,8 +271,8 @@ static void processMLFunction(MLFunction *fn,
|
|||
OwningRewritePatternList &&patterns) {
|
||||
class MLFuncRewriter : public WorklistRewriter {
|
||||
public:
|
||||
MLFuncRewriter(GreedyPatternRewriteDriver &driver, MLFuncBuilder &builder)
|
||||
: WorklistRewriter(driver, builder.getContext()), builder(builder) {}
|
||||
MLFuncRewriter(GreedyPatternRewriteDriver &theDriver, FuncBuilder &builder)
|
||||
: WorklistRewriter(theDriver, builder.getContext()), builder(builder) {}
|
||||
|
||||
// Implement the hook for creating operations, and make sure that newly
|
||||
// created ops are added to the worklist for processing.
|
||||
|
@ -288,13 +288,13 @@ static void processMLFunction(MLFunction *fn,
|
|||
}
|
||||
|
||||
private:
|
||||
MLFuncBuilder &builder;
|
||||
FuncBuilder &builder;
|
||||
};
|
||||
|
||||
GreedyPatternRewriteDriver driver(std::move(patterns));
|
||||
fn->walk([&](OperationStmt *stmt) { driver.addToWorklist(stmt); });
|
||||
|
||||
MLFuncBuilder mlBuilder(fn);
|
||||
FuncBuilder mlBuilder(fn);
|
||||
MLFuncRewriter rewriter(driver, mlBuilder);
|
||||
driver.simplifyFunction(fn, rewriter);
|
||||
}
|
||||
|
@ -303,8 +303,8 @@ static void processCFGFunction(CFGFunction *fn,
|
|||
OwningRewritePatternList &&patterns) {
|
||||
class CFGFuncRewriter : public WorklistRewriter {
|
||||
public:
|
||||
CFGFuncRewriter(GreedyPatternRewriteDriver &driver, CFGFuncBuilder &builder)
|
||||
: WorklistRewriter(driver, builder.getContext()), builder(builder) {}
|
||||
CFGFuncRewriter(GreedyPatternRewriteDriver &theDriver, FuncBuilder &builder)
|
||||
: WorklistRewriter(theDriver, builder.getContext()), builder(builder) {}
|
||||
|
||||
// Implement the hook for creating operations, and make sure that newly
|
||||
// created ops are added to the worklist for processing.
|
||||
|
@ -320,7 +320,7 @@ static void processCFGFunction(CFGFunction *fn,
|
|||
}
|
||||
|
||||
private:
|
||||
CFGFuncBuilder &builder;
|
||||
FuncBuilder &builder;
|
||||
};
|
||||
|
||||
GreedyPatternRewriteDriver driver(std::move(patterns));
|
||||
|
@ -329,7 +329,7 @@ static void processCFGFunction(CFGFunction *fn,
|
|||
if (auto *opInst = dyn_cast<OperationStmt>(&op))
|
||||
driver.addToWorklist(opInst);
|
||||
|
||||
CFGFuncBuilder cfgBuilder(fn);
|
||||
FuncBuilder cfgBuilder(fn);
|
||||
CFGFuncRewriter rewriter(driver, cfgBuilder);
|
||||
driver.simplifyFunction(fn, rewriter);
|
||||
}
|
||||
|
|
|
@ -40,7 +40,7 @@ using namespace mlir;
|
|||
/// the trip count can't be expressed as an affine expression.
|
||||
AffineMap mlir::getUnrolledLoopUpperBound(const ForStmt &forStmt,
|
||||
unsigned unrollFactor,
|
||||
MLFuncBuilder *builder) {
|
||||
FuncBuilder *builder) {
|
||||
auto lbMap = forStmt.getLowerBoundMap();
|
||||
|
||||
// Single result lower bound map only.
|
||||
|
@ -66,7 +66,7 @@ AffineMap mlir::getUnrolledLoopUpperBound(const ForStmt &forStmt,
|
|||
/// when the trip count can't be expressed as an affine expression.
|
||||
AffineMap mlir::getCleanupLoopLowerBound(const ForStmt &forStmt,
|
||||
unsigned unrollFactor,
|
||||
MLFuncBuilder *builder) {
|
||||
FuncBuilder *builder) {
|
||||
auto lbMap = forStmt.getLowerBoundMap();
|
||||
|
||||
// Single result lower bound map only.
|
||||
|
@ -101,14 +101,14 @@ bool mlir::promoteIfSingleIteration(ForStmt *forStmt) {
|
|||
if (!forStmt->use_empty()) {
|
||||
if (forStmt->hasConstantLowerBound()) {
|
||||
auto *mlFunc = forStmt->getFunction();
|
||||
MLFuncBuilder topBuilder(&mlFunc->getBody()->front());
|
||||
FuncBuilder topBuilder(&mlFunc->getBody()->front());
|
||||
auto constOp = topBuilder.create<ConstantIndexOp>(
|
||||
forStmt->getLoc(), forStmt->getConstantLowerBound());
|
||||
forStmt->replaceAllUsesWith(constOp);
|
||||
} else {
|
||||
const AffineBound lb = forStmt->getLowerBound();
|
||||
SmallVector<Value *, 4> lbOperands(lb.operand_begin(), lb.operand_end());
|
||||
MLFuncBuilder builder(forStmt->getBlock(), StmtBlock::iterator(forStmt));
|
||||
FuncBuilder builder(forStmt->getBlock(), StmtBlock::iterator(forStmt));
|
||||
auto affineApplyOp = builder.create<AffineApplyOp>(
|
||||
forStmt->getLoc(), lb.getMap(), lbOperands);
|
||||
forStmt->replaceAllUsesWith(affineApplyOp->getResult(0));
|
||||
|
@ -146,7 +146,7 @@ static ForStmt *
|
|||
generateLoop(AffineMap lbMap, AffineMap ubMap,
|
||||
const std::vector<std::pair<uint64_t, ArrayRef<Statement *>>>
|
||||
&stmtGroupQueue,
|
||||
unsigned offset, ForStmt *srcForStmt, MLFuncBuilder *b) {
|
||||
unsigned offset, ForStmt *srcForStmt, FuncBuilder *b) {
|
||||
SmallVector<Value *, 4> lbOperands(srcForStmt->getLowerBoundOperands());
|
||||
SmallVector<Value *, 4> ubOperands(srcForStmt->getUpperBoundOperands());
|
||||
|
||||
|
@ -167,7 +167,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
|
|||
// Generate the remapping if the shift is not zero: remappedIV = newIV -
|
||||
// shift.
|
||||
if (!srcForStmt->use_empty() && shift != 0) {
|
||||
auto b = MLFuncBuilder::getForStmtBodyBuilder(loopChunk);
|
||||
auto b = FuncBuilder::getForStmtBodyBuilder(loopChunk);
|
||||
auto *ivRemap = b.create<AffineApplyOp>(
|
||||
srcForStmt->getLoc(),
|
||||
b.getSingleDimShiftAffineMap(-static_cast<int64_t>(
|
||||
|
@ -261,7 +261,7 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
|
|||
|
||||
auto origLbMap = forStmt->getLowerBoundMap();
|
||||
uint64_t lbShift = 0;
|
||||
MLFuncBuilder b(forStmt);
|
||||
FuncBuilder b(forStmt);
|
||||
for (uint64_t d = 0, e = sortedStmtGroups.size(); d < e; ++d) {
|
||||
// If nothing is shifted by d, continue.
|
||||
if (sortedStmtGroups[d].empty())
|
||||
|
@ -379,7 +379,7 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
|
|||
// Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
|
||||
if (getLargestDivisorOfTripCount(*forStmt) % unrollFactor != 0) {
|
||||
DenseMap<const Value *, Value *> operandMap;
|
||||
MLFuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt));
|
||||
FuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt));
|
||||
auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
|
||||
auto clLbMap = getCleanupLoopLowerBound(*forStmt, unrollFactor, &builder);
|
||||
assert(clLbMap &&
|
||||
|
@ -404,7 +404,7 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
|
|||
|
||||
// Builder to insert unrolled bodies right after the last statement in the
|
||||
// body of 'forStmt'.
|
||||
MLFuncBuilder builder(forStmt->getBody(), forStmt->getBody()->end());
|
||||
FuncBuilder builder(forStmt->getBody(), forStmt->getBody()->end());
|
||||
|
||||
// Keep a pointer to the last statement in the original block so that we know
|
||||
// what to clone (since we are doing this in-place).
|
||||
|
|
|
@ -124,7 +124,7 @@ bool mlir::expandAffineApply(AffineApplyOp *op) {
|
|||
if (!op)
|
||||
return true;
|
||||
|
||||
FuncBuilder builder(op->getOperation());
|
||||
FuncBuilder builder(cast<OperationStmt>(op->getOperation()));
|
||||
auto affineMap = op->getAffineMap();
|
||||
for (auto numberedExpr : llvm::enumerate(affineMap.getResults())) {
|
||||
Value *expanded = expandAffineExpr(&builder, numberedExpr.value(), op);
|
||||
|
|
|
@ -842,7 +842,7 @@ static bool vectorizeRootOrTerminal(Value *iv, LoadOrStoreOpPointer memoryOp,
|
|||
makePermutationMap(opStmt, state->strategy->loopToVectorDim);
|
||||
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
|
||||
LLVM_DEBUG(permutationMap.print(dbgs()));
|
||||
MLFuncBuilder b(opStmt);
|
||||
FuncBuilder b(opStmt);
|
||||
auto transfer = b.create<VectorTransferReadOp>(
|
||||
opStmt->getLoc(), vectorType, memoryOp->getMemRef(),
|
||||
map(makePtrDynCaster<Value>(), memoryOp->getIndices()), permutationMap);
|
||||
|
@ -970,7 +970,7 @@ static Value *vectorizeConstant(Statement *stmt, const ConstantOp &constant,
|
|||
!VectorType::isValidElementType(constant.getType())) {
|
||||
return nullptr;
|
||||
}
|
||||
MLFuncBuilder b(stmt);
|
||||
FuncBuilder b(stmt);
|
||||
Location loc = stmt->getLoc();
|
||||
auto vectorType = type.cast<VectorType>();
|
||||
auto attr = SplatElementsAttr::get(vectorType, constant.getValue());
|
||||
|
@ -1068,7 +1068,7 @@ static Value *vectorizeOperand(Value *operand, Statement *stmt,
|
|||
/// TODO(ntv): consider adding a trait to Op to describe how it gets vectorized.
|
||||
/// Maybe some Ops are not vectorizable or require some tricky logic, we cannot
|
||||
/// do one-off logic here; ideally it would be TableGen'd.
|
||||
static OperationStmt *vectorizeOneOperationStmt(MLFuncBuilder *b,
|
||||
static OperationStmt *vectorizeOneOperationStmt(FuncBuilder *b,
|
||||
OperationStmt *opStmt,
|
||||
VectorizationState *state) {
|
||||
// Sanity checks.
|
||||
|
@ -1084,7 +1084,7 @@ static OperationStmt *vectorizeOneOperationStmt(MLFuncBuilder *b,
|
|||
auto *value = store->getValueToStore();
|
||||
auto *vectorValue = vectorizeOperand(value, opStmt, state);
|
||||
auto indices = map(makePtrDynCaster<Value>(), store->getIndices());
|
||||
MLFuncBuilder b(opStmt);
|
||||
FuncBuilder b(opStmt);
|
||||
auto permutationMap =
|
||||
makePermutationMap(opStmt, state->strategy->loopToVectorDim);
|
||||
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
|
||||
|
@ -1159,7 +1159,7 @@ static bool vectorizeOperations(VectorizationState *state) {
|
|||
|
||||
// 2. Create vectorized form of the statement.
|
||||
// Insert it just before stmt, on success register stmt as replaced.
|
||||
MLFuncBuilder b(stmt);
|
||||
FuncBuilder b(stmt);
|
||||
auto *vectorizedStmt = vectorizeOneOperationStmt(&b, stmt, state);
|
||||
if (!vectorizedStmt) {
|
||||
return true;
|
||||
|
@ -1200,7 +1200,7 @@ static bool vectorizeRootMatches(MLFunctionMatches matches,
|
|||
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ loop is not vectorizable");
|
||||
continue;
|
||||
}
|
||||
MLFuncBuilder builder(loop); // builder to insert in place of loop
|
||||
FuncBuilder builder(loop); // builder to insert in place of loop
|
||||
DenseMap<const Value *, Value *> nomap;
|
||||
ForStmt *clonedLoop = cast<ForStmt>(builder.clone(*loop, nomap));
|
||||
auto fail = doVectorize(m, &state);
|
||||
|
@ -1244,7 +1244,7 @@ static bool vectorizeRootMatches(MLFunctionMatches matches,
|
|||
if (fail) {
|
||||
return;
|
||||
}
|
||||
MLFuncBuilder b(stmt);
|
||||
FuncBuilder b(stmt);
|
||||
auto *res = vectorizeOneOperationStmt(&b, stmt, &state);
|
||||
if (res == nullptr) {
|
||||
fail = true;
|
||||
|
|
Loading…
Reference in New Issue