Rename findFunction from the ML side of the house to be named getFunction(),

making it more similar to the CFG side of things.  It is true that in a deeply
nested case that this is not a guaranteed O(1) time operation, and that 'get'
could lead compiler hackers to think this is cheap, but we need to merge these
and we can look into solutions for this in the future if it becomes a problem
in practice.

This is step 9/n towards merging instructions and statements, NFC.

PiperOrigin-RevId: 226983931
This commit is contained in:
Chris Lattner 2018-12-26 21:13:45 -08:00 committed by jpienaar
parent 4e5337601e
commit abf72a8bb1
13 changed files with 29 additions and 29 deletions

View File

@ -289,20 +289,20 @@ public:
/// Create ML function builder and set insertion point to the given statement,
/// which will cause subsequent insertions to go right before it.
MLFuncBuilder(Statement *stmt)
// TODO: Eliminate findFunction from this.
: MLFuncBuilder(stmt->findFunction()) {
// TODO: Eliminate getFunction from this.
: MLFuncBuilder(stmt->getFunction()) {
setInsertionPoint(stmt);
}
MLFuncBuilder(StmtBlock *block)
// TODO: Eliminate findFunction from this.
: MLFuncBuilder(block->findFunction()) {
// TODO: Eliminate getFunction from this.
: MLFuncBuilder(block->getFunction()) {
setInsertionPoint(block, block->end());
}
MLFuncBuilder(StmtBlock *block, StmtBlock::iterator insertPoint)
// TODO: Eliminate findFunction from this.
: MLFuncBuilder(block->findFunction()) {
// TODO: Eliminate getFunction from this.
: MLFuncBuilder(block->getFunction()) {
setInsertionPoint(block, insertPoint);
}

View File

@ -104,7 +104,7 @@ public:
/// Returns the function that this statement is part of.
/// The function is determined by traversing the chain of parent statements.
/// Returns nullptr if the statement is unlinked.
MLFunction *findFunction() const;
MLFunction *getFunction() const;
/// Destroys this statement and its subclass data.
void destroy();

View File

@ -290,7 +290,7 @@ public:
}
/// Resolve base class ambiguity.
using Statement::findFunction;
using Statement::getFunction;
/// Operand iterators.
using operand_iterator = OperandIterator<ForStmt, MLValue>;

View File

@ -62,11 +62,11 @@ public:
return const_cast<StmtBlock *>(this)->getContainingStmt();
}
/// Returns the function that this statement block is part of.
/// The function is determined by traversing the chain of parent statements.
MLFunction *findFunction();
const MLFunction *findFunction() const {
return const_cast<StmtBlock *>(this)->findFunction();
/// Returns the function that this statement block is part of. The function
/// is determined by traversing the chain of parent statements.
MLFunction *getFunction();
const MLFunction *getFunction() const {
return const_cast<StmtBlock *>(this)->getFunction();
}
//===--------------------------------------------------------------------===//

View File

@ -39,7 +39,7 @@ bool mlir::properlyDominates(const Statement &a, const Statement &b) {
if (&a == &b)
return false;
if (a.findFunction() != b.findFunction())
if (a.getFunction() != b.getFunction())
return false;
if (a.getBlock() == b.getBlock()) {

View File

@ -1015,7 +1015,7 @@ protected:
case SSAValueKind::BlockArgument:
// If this is an argument to the function, give it an 'arg' name.
if (auto *block = cast<BlockArgument>(value)->getOwner())
if (auto *fn = block->findFunction())
if (auto *fn = block->getFunction())
if (&fn->getBlockList().front() == block) {
specialName << "arg" << nextArgumentID++;
break;
@ -1639,7 +1639,7 @@ void BasicBlock::printAsOperand(raw_ostream &os, bool printType) {
}
void Statement::print(raw_ostream &os) const {
MLFunction *function = findFunction();
MLFunction *function = getFunction();
if (!function) {
os << "<<UNLINKED STATEMENT>>\n";
return;
@ -1653,7 +1653,7 @@ void Statement::print(raw_ostream &os) const {
void Statement::dump() const { print(llvm::errs()); }
void StmtBlock::printBlock(raw_ostream &os) const {
const MLFunction *function = findFunction();
const MLFunction *function = getFunction();
ModuleState state(function->getContext());
ModulePrinter modulePrinter(os, state);
MLFunctionPrinter(function, modulePrinter).print(this);

View File

@ -474,7 +474,7 @@ void ReturnOp::print(OpAsmPrinter *p) const {
bool ReturnOp::verify() const {
const Function *function;
if (auto *stmt = dyn_cast<OperationStmt>(getOperation()))
function = stmt->getBlock()->findFunction();
function = stmt->getFunction();
else
function = cast<Instruction>(getOperation())->getFunction();

View File

@ -99,7 +99,7 @@ void Operation::setLoc(Location loc) {
Function *Operation::getOperationFunction() {
if (auto *inst = llvm::dyn_cast<Instruction>(this))
return inst->getFunction();
return llvm::cast<OperationStmt>(this)->findFunction();
return llvm::cast<OperationStmt>(this)->getFunction();
}
/// Return the number of operands this operation has.

View File

@ -57,9 +57,9 @@ Function *SSAValue::getFunction() {
case SSAValueKind::BlockArgument:
return cast<BlockArgument>(this)->getFunction();
case SSAValueKind::StmtResult:
return getDefiningStmt()->findFunction();
return getDefiningStmt()->getFunction();
case SSAValueKind::ForStmt:
return cast<ForStmt>(this)->findFunction();
return cast<ForStmt>(this)->getFunction();
}
}
@ -121,6 +121,6 @@ MLFunction *MLValue::getFunction() {
/// Return the function that this argument is defined in.
MLFunction *BlockArgument::getFunction() {
if (auto *owner = getOwner())
return owner->findFunction();
return owner->getFunction();
return nullptr;
}

View File

@ -81,8 +81,8 @@ Statement *Statement::getParentStmt() const {
return block ? block->getContainingStmt() : nullptr;
}
MLFunction *Statement::findFunction() const {
return block ? block->findFunction() : nullptr;
MLFunction *Statement::getFunction() const {
return block ? block->getFunction() : nullptr;
}
MLValue *Statement::getOperand(unsigned idx) {
@ -368,7 +368,7 @@ MLIRContext *OperationStmt::getContext() const {
// In the very odd case where we have no operands or results, fall back to
// doing a find.
return findFunction()->getContext();
return getFunction()->getContext();
}
bool OperationStmt::isReturn() const { return isa<ReturnOp>(); }
@ -560,7 +560,7 @@ MLIRContext *IfStmt::getContext() const {
// Check for degenerate case of if statement with no operands.
// This is unlikely, but legal.
if (operands.empty())
return findFunction()->getContext();
return getFunction()->getContext();
return getOperand(0)->getType().getContext();
}

View File

@ -32,7 +32,7 @@ Statement *StmtBlock::getContainingStmt() {
return parent ? parent->getContainingStmt() : nullptr;
}
MLFunction *StmtBlock::findFunction() {
MLFunction *StmtBlock::getFunction() {
StmtBlock *block = this;
while (auto *stmt = block->getContainingStmt()) {
block = stmt->getBlock();

View File

@ -180,7 +180,7 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
MLFuncBuilder *b = region.isWrite() ? &epilogue : &prologue;
// Builder to create constants at the top level.
MLFuncBuilder top(forStmt->findFunction());
MLFuncBuilder top(forStmt->getFunction());
auto loc = forStmt->getLoc();
auto *memref = region.memref;

View File

@ -101,7 +101,7 @@ bool mlir::promoteIfSingleIteration(ForStmt *forStmt) {
// Replaces all IV uses to its single iteration value.
if (!forStmt->use_empty()) {
if (forStmt->hasConstantLowerBound()) {
auto *mlFunc = forStmt->findFunction();
auto *mlFunc = forStmt->getFunction();
MLFuncBuilder topBuilder(&mlFunc->getBody()->front());
auto constOp = topBuilder.create<ConstantIndexOp>(
forStmt->getLoc(), forStmt->getConstantLowerBound());