NFC: Remove the various "::getFunction" methods.

These methods assume that a function is a valid builtin top-level operation, and removing these methods allows for decoupling FuncOp and IR/. Utility "getParentOfType" methods have been added to Operation/OpState to allow for querying the first parent operation of a given type.

PiperOrigin-RevId: 257018913
This commit is contained in:
River Riddle 2019-07-08 11:20:26 -07:00 committed by jpienaar
parent d3a85cc77d
commit ce502af9cd
21 changed files with 52 additions and 95 deletions

View File

@ -136,7 +136,7 @@ public:
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
// Get or create the declaration of the printf function in the module.
Function printfFunc = getPrintf(op->getFunction().getModule());
Function printfFunc = getPrintf(op->getParentOfType<ModuleOp>());
auto print = cast<toy::PrintOp>(op);
auto loc = print.getLoc();

View File

@ -99,10 +99,6 @@ public:
/// nullptr if this is a top-level block.
Operation *getContainingOp();
/// Returns the function that this block is part of, even if the block is
/// nested under an operation region.
Function getFunction();
/// Insert this block (which must not already be in a function) right before
/// the specified block.
void insertBefore(Block *block);

View File

@ -71,6 +71,11 @@ public:
/// Return the operation that this refers to.
Operation *getOperation() { return state; }
/// Return the closes surrounding parent operation that is of type 'OpTy'.
template <typename OpTy> OpTy getParentOfType() {
return getOperation()->getParentOfType<OpTy>();
}
/// Return the context this operation belongs to.
MLIRContext *getContext() { return getOperation()->getContext(); }

View File

@ -125,10 +125,14 @@ public:
/// or nullptr if this is a top-level operation.
Operation *getParentOp();
/// Returns the function that this operation is part of.
/// The function is determined by traversing the chain of parent operations.
/// Returns nullptr if the operation is unlinked.
Function getFunction();
/// Return the closest surrounding parent operation that is of type 'OpTy'.
template <typename OpTy> OpTy getParentOfType() {
auto *op = this;
while ((op = op->getParentOp()))
if (auto parentOp = llvm::dyn_cast<OpTy>(op))
return parentOp;
return OpTy();
}
/// Replace any uses of 'from' with 'to' within this operation.
void replaceUsesOfWith(Value *from, Value *to);

View File

@ -72,9 +72,6 @@ public:
IRObjectWithUseList::replaceAllUsesWith(newValue);
}
/// Return the function that this Value is defined in.
Function getFunction();
/// If this value is the result of an operation, return the operation that
/// defines it.
Operation *getDefiningOp();
@ -128,17 +125,11 @@ public:
return const_cast<Value *>(value)->getKind() == Kind::BlockArgument;
}
/// Return the function that this argument is defined in.
Function getFunction();
Block *getOwner() { return owner; }
/// Returns the number of this argument.
unsigned getArgNumber();
/// Returns if the current argument is a function argument.
bool isFunctionArgument();
private:
friend class Block; // For access to private constructor.
BlockArgument(Type type, Block *owner)

View File

@ -307,7 +307,7 @@ AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value *v) {
if (inserted) {
reorderedDims.push_back(v);
}
return getAffineDimExpr(iterPos->second, v->getFunction().getContext())
return getAffineDimExpr(iterPos->second, v->getContext())
.cast<AffineDimExpr>();
}

View File

@ -442,13 +442,13 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
// Insert the `malloc` declaration if it is not already present.
Function mallocFunc =
op->getFunction().getModule().getNamedFunction("malloc");
op->getParentOfType<FuncOp>().getModule().getNamedFunction("malloc");
if (!mallocFunc) {
auto mallocType =
rewriter.getFunctionType(getIndexType(), getVoidPtrType());
mallocFunc =
Function::create(rewriter.getUnknownLoc(), "malloc", mallocType);
op->getFunction().getModule().push_back(mallocFunc);
op->getParentOfType<FuncOp>().getModule().push_back(mallocFunc);
}
// Allocate the underlying buffer and store a pointer to it in the MemRef
@ -503,11 +503,12 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
OperandAdaptor<DeallocOp> transformed(operands);
// Insert the `free` declaration if it is not already present.
Function freeFunc = op->getFunction().getModule().getNamedFunction("free");
Function freeFunc =
op->getParentOfType<FuncOp>().getModule().getNamedFunction("free");
if (!freeFunc) {
auto freeType = rewriter.getFunctionType(getVoidPtrType(), {});
freeFunc = Function::create(rewriter.getUnknownLoc(), "free", freeType);
op->getFunction().getModule().push_back(freeFunc);
op->getParentOfType<FuncOp>().getModule().push_back(freeFunc);
}
auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();

View File

@ -98,6 +98,6 @@ mlir_attr_t makeBoolAttr(mlir_context_t context, bool value) {
}
unsigned getFunctionArity(mlir_func_t function) {
auto *f = reinterpret_cast<mlir::Function *>(function);
return f->getNumArguments();
auto f = mlir::Function::getFromOpaquePointer(function);
return f.getNumArguments();
}

View File

@ -426,7 +426,7 @@ LogicalResult LaunchFuncOp::verify() {
return emitOpError("attribute 'kernel' must be a function");
}
auto module = getOperation()->getFunction().getModule();
auto module = getParentOfType<ModuleOp>();
Function kernelFunc = module.getNamedFunction(kernel());
if (!kernelFunc)
return emitError() << "kernel function '" << kernelAttr << "' is undefined";

View File

@ -64,7 +64,7 @@ static Function outlineKernelFunc(gpu::LaunchOp launchOp) {
FunctionType type =
FunctionType::get(kernelOperandTypes, {}, launchOp.getContext());
std::string kernelFuncName =
Twine(launchOp.getOperation()->getFunction().getName(), "_kernel").str();
Twine(launchOp.getParentOfType<FuncOp>().getName(), "_kernel").str();
Function outlinedFunc = Function::create(loc, kernelFuncName, type);
outlinedFunc.getBody().takeBody(launchOp.getBody());
Builder builder(launchOp.getContext());

View File

@ -1421,8 +1421,8 @@ void OperationPrinter::print(Block *block, bool printBlockArgs,
os << ':';
// Print out some context information about the predecessors of this block.
if (!block->getFunction()) {
os << "\t// block is not in a function!";
if (!block->getParent()) {
os << "\t// block is not in a region!";
} else if (block->hasNoPredecessors()) {
os << "\t// no predecessors";
} else if (auto *pred = block->getSinglePredecessor()) {

View File

@ -50,13 +50,8 @@ Operation *Block::getContainingOp() {
return getParent() ? getParent()->getContainingOp() : nullptr;
}
Function Block::getFunction() {
auto *parent = getParent();
return parent ? parent->getParentOfType<FuncOp>() : nullptr;
}
/// Insert this block (which must not already be in a function) right before
/// the specified block.
/// Insert this block (which must not already be in a region) right before the
/// specified block.
void Block::insertBefore(Block *block) {
assert(!getParent() && "already inserted into a block!");
assert(block->getParent() && "cannot insert before a block without a parent");
@ -254,11 +249,11 @@ void Block::walk(Block::iterator begin, Block::iterator end,
/// invalidated.
Block *Block::splitBlock(iterator splitBefore) {
// Start by creating a new basic block, and insert it immediate after this
// one in the containing function.
// one in the containing region.
auto newBB = new Block();
getParent()->getBlocks().insert(std::next(Region::iterator(this)), newBB);
// Move all of the operations from the split point to the end of the function
// Move all of the operations from the split point to the end of the region
// into the new block.
newBB->getOperations().splice(newBB->end(), getOperations(), splitBefore,
end());

View File

@ -281,10 +281,6 @@ Operation *Operation::getParentOp() {
return block ? block->getContainingOp() : nullptr;
}
Function Operation::getFunction() {
return block ? block->getFunction() : nullptr;
}
/// Replace any uses of 'from' with 'to' within this operation.
void Operation::replaceUsesOfWith(Value *from, Value *to) {
if (from == to)

View File

@ -29,21 +29,9 @@ Operation *Value::getDefiningOp() {
return nullptr;
}
/// Return the function that this Value is defined in.
Function Value::getFunction() {
switch (getKind()) {
case Value::Kind::BlockArgument:
return cast<BlockArgument>(this)->getFunction();
case Value::Kind::OpResult:
return getDefiningOp()->getFunction();
}
llvm_unreachable("Unknown Value Kind");
}
Location Value::getLoc() {
if (auto *op = getDefiningOp()) {
if (auto *op = getDefiningOp())
return op->getLoc();
}
return UnknownLoc::get(getContext());
}
@ -78,20 +66,3 @@ void IRObjectWithUseList::dropAllUses() {
use_begin()->drop();
}
}
//===----------------------------------------------------------------------===//
// BlockArgument implementation.
//===----------------------------------------------------------------------===//
/// Return the function that this argument is defined in.
Function BlockArgument::getFunction() {
if (auto *owner = getOwner())
return owner->getFunction();
return nullptr;
}
/// Returns if the current argument is a function argument.
bool BlockArgument::isFunctionArgument() {
auto containingFn = getFunction();
return containingFn && &containingFn.front() == getOwner();
}

View File

@ -170,7 +170,7 @@ public:
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
// Insert the `malloc` declaration if it is not already present.
auto module = op->getFunction().getModule();
auto module = op->getParentOfType<ModuleOp>();
Function mallocFunc = module.getNamedFunction("malloc");
if (!mallocFunc) {
auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy);
@ -231,7 +231,7 @@ public:
auto voidPtrTy =
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
// Insert the `free` declaration if it is not already present.
auto module = op->getFunction().getModule();
auto module = op->getParentOfType<ModuleOp>();
Function freeFunc = module.getNamedFunction("free");
if (!freeFunc) {
auto freeType = rewriter.getFunctionType(voidPtrTy, {});
@ -602,7 +602,7 @@ static Function getLLVMLibraryCallDeclaration(Operation *op,
PatternRewriter &rewriter) {
assert(isa<LinalgOp>(op));
auto fnName = LinalgOp::getLibraryCallName();
auto module = op->getFunction().getModule();
auto module = op->getParentOfType<ModuleOp>();
if (auto f = module.getNamedFunction(fnName)) {
return f;
}

View File

@ -431,8 +431,7 @@ static LogicalResult verify(CallOp op) {
auto fnAttr = op.getAttrOfType<FunctionAttr>("callee");
if (!fnAttr)
return op.emitOpError("requires a 'callee' function attribute");
auto fn = op.getOperation()->getFunction().getModule().getNamedFunction(
fnAttr.getValue());
auto fn = op.getParentOfType<ModuleOp>().getNamedFunction(fnAttr.getValue());
if (!fn)
return op.emitOpError() << "'" << fnAttr.getValue()
<< "' does not reference a valid function";
@ -1098,8 +1097,8 @@ static LogicalResult verify(ConstantOp &op) {
return op.emitOpError("requires 'value' to be a function reference");
// Try to find the referenced function.
auto fn = op.getOperation()->getFunction().getModule().getNamedFunction(
fnAttr.getValue());
auto fn =
op.getParentOfType<ModuleOp>().getNamedFunction(fnAttr.getValue());
if (!fn)
return op.emitOpError("reference to undefined function 'bar'");
@ -2029,7 +2028,9 @@ static void print(OpAsmPrinter *p, ReturnOp op) {
}
static LogicalResult verify(ReturnOp op) {
auto function = op.getOperation()->getFunction();
// TODO(b/137008268): Return op should verify that it is nested directly
// within a function operation.
auto function = op.getParentOfType<FuncOp>();
// The operand number and types must match the function signature.
const auto &results = function.getType().getResults();

View File

@ -217,8 +217,7 @@ static bool getFullMemRefAsRegion(Operation *opInst, unsigned numParamLoopIVs,
static InFlightDiagnostic LLVM_ATTRIBUTE_UNUSED
emitRemarkForBlock(Block &block) {
auto *op = block.getContainingOp();
return op ? op->emitRemark() : block.getFunction().emitRemark();
return block.getContainingOp()->emitRemark();
}
/// Creates a buffer in the faster memory space for the specified region;
@ -250,7 +249,7 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, Block *block,
OpBuilder &b = region.isWrite() ? epilogue : prologue;
// Builder to create constants at the top level.
auto func = block->getFunction();
auto func = block->getParent()->getParentOfType<FuncOp>();
OpBuilder top(func.getBody());
auto loc = region.loc;
@ -765,10 +764,7 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) {
if (totalDmaBuffersSizeInBytes > fastMemCapacityBytes) {
StringRef str = "Total size of all DMA buffers' for this block "
"exceeds fast memory capacity\n";
if (auto *op = block->getContainingOp())
op->emitError(str);
else
block->getFunction().emitError(str);
block->getContainingOp()->emitError(str);
}
return totalDmaBuffersSizeInBytes;

View File

@ -859,7 +859,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
// Create builder to insert alloc op just before 'forOp'.
OpBuilder b(forInst);
// Builder to create constants at the top level.
OpBuilder top(forInst->getFunction().getBody());
OpBuilder top(forInst->getParentOfType<FuncOp>().getBody());
// Create new memref type based on slice bounds.
auto *oldMemRef = cast<AffineStoreOp>(srcStoreOpInst).getMemRef();
auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
@ -1750,7 +1750,7 @@ public:
};
// Search for siblings which load the same memref function argument.
auto fn = dstNode->op->getFunction();
auto fn = dstNode->op->getParentOfType<FuncOp>();
for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) {
for (auto *user : fn.getArgument(i)->getUsers()) {
if (auto loadOp = dyn_cast<AffineLoadOp>(user)) {

View File

@ -635,8 +635,8 @@ static bool emitSlice(MaterializationState *state,
}
}
LLVM_DEBUG(dbgs() << "\nMLFunction is now\n");
LLVM_DEBUG((*slice)[0]->getFunction().print(dbgs()));
LLVM_DEBUG(dbgs() << "\nFunction is now\n");
LLVM_DEBUG((*slice)[0]->getParentOfType<FuncOp>().print(dbgs()));
// slice are topologically sorted, we can just erase them in reverse
// order. Reverse iterator does not just work simply with an operator*

View File

@ -125,7 +125,7 @@ LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) {
Operation *op = forOp.getOperation();
if (!iv->use_empty()) {
if (forOp.hasConstantLowerBound()) {
OpBuilder topBuilder(op->getFunction().getBody());
OpBuilder topBuilder(op->getParentOfType<FuncOp>().getBody());
auto constOp = topBuilder.create<ConstantIndexOp>(
forOp.getLoc(), forOp.getConstantLowerBound());
iv->replaceAllUsesWith(constOp);

View File

@ -81,11 +81,12 @@ bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
std::unique_ptr<DominanceInfo> domInfo;
std::unique_ptr<PostDominanceInfo> postDomInfo;
if (domInstFilter)
domInfo = llvm::make_unique<DominanceInfo>(domInstFilter->getFunction());
domInfo = llvm::make_unique<DominanceInfo>(
domInstFilter->getParentOfType<FuncOp>());
if (postDomInstFilter)
postDomInfo =
llvm::make_unique<PostDominanceInfo>(postDomInstFilter->getFunction());
postDomInfo = llvm::make_unique<PostDominanceInfo>(
postDomInstFilter->getParentOfType<FuncOp>());
// The ops where memref replacement succeeds are replaced with new ones.
SmallVector<Operation *, 8> opsToErase;