forked from OSchip/llvm-project
NFC: Remove the various "::getFunction" methods.
These methods assume that a function is a valid builtin top-level operation, and removing these methods allows for decoupling FuncOp and IR/. Utility "getParentOfType" methods have been added to Operation/OpState to allow for querying the first parent operation of a given type. PiperOrigin-RevId: 257018913
This commit is contained in:
parent
d3a85cc77d
commit
ce502af9cd
|
@ -136,7 +136,7 @@ public:
|
|||
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Get or create the declaration of the printf function in the module.
|
||||
Function printfFunc = getPrintf(op->getFunction().getModule());
|
||||
Function printfFunc = getPrintf(op->getParentOfType<ModuleOp>());
|
||||
|
||||
auto print = cast<toy::PrintOp>(op);
|
||||
auto loc = print.getLoc();
|
||||
|
|
|
@ -99,10 +99,6 @@ public:
|
|||
/// nullptr if this is a top-level block.
|
||||
Operation *getContainingOp();
|
||||
|
||||
/// Returns the function that this block is part of, even if the block is
|
||||
/// nested under an operation region.
|
||||
Function getFunction();
|
||||
|
||||
/// Insert this block (which must not already be in a function) right before
|
||||
/// the specified block.
|
||||
void insertBefore(Block *block);
|
||||
|
|
|
@ -71,6 +71,11 @@ public:
|
|||
/// Return the operation that this refers to.
|
||||
Operation *getOperation() { return state; }
|
||||
|
||||
/// Return the closes surrounding parent operation that is of type 'OpTy'.
|
||||
template <typename OpTy> OpTy getParentOfType() {
|
||||
return getOperation()->getParentOfType<OpTy>();
|
||||
}
|
||||
|
||||
/// Return the context this operation belongs to.
|
||||
MLIRContext *getContext() { return getOperation()->getContext(); }
|
||||
|
||||
|
|
|
@ -125,10 +125,14 @@ public:
|
|||
/// or nullptr if this is a top-level operation.
|
||||
Operation *getParentOp();
|
||||
|
||||
/// Returns the function that this operation is part of.
|
||||
/// The function is determined by traversing the chain of parent operations.
|
||||
/// Returns nullptr if the operation is unlinked.
|
||||
Function getFunction();
|
||||
/// Return the closest surrounding parent operation that is of type 'OpTy'.
|
||||
template <typename OpTy> OpTy getParentOfType() {
|
||||
auto *op = this;
|
||||
while ((op = op->getParentOp()))
|
||||
if (auto parentOp = llvm::dyn_cast<OpTy>(op))
|
||||
return parentOp;
|
||||
return OpTy();
|
||||
}
|
||||
|
||||
/// Replace any uses of 'from' with 'to' within this operation.
|
||||
void replaceUsesOfWith(Value *from, Value *to);
|
||||
|
|
|
@ -72,9 +72,6 @@ public:
|
|||
IRObjectWithUseList::replaceAllUsesWith(newValue);
|
||||
}
|
||||
|
||||
/// Return the function that this Value is defined in.
|
||||
Function getFunction();
|
||||
|
||||
/// If this value is the result of an operation, return the operation that
|
||||
/// defines it.
|
||||
Operation *getDefiningOp();
|
||||
|
@ -128,17 +125,11 @@ public:
|
|||
return const_cast<Value *>(value)->getKind() == Kind::BlockArgument;
|
||||
}
|
||||
|
||||
/// Return the function that this argument is defined in.
|
||||
Function getFunction();
|
||||
|
||||
Block *getOwner() { return owner; }
|
||||
|
||||
/// Returns the number of this argument.
|
||||
unsigned getArgNumber();
|
||||
|
||||
/// Returns if the current argument is a function argument.
|
||||
bool isFunctionArgument();
|
||||
|
||||
private:
|
||||
friend class Block; // For access to private constructor.
|
||||
BlockArgument(Type type, Block *owner)
|
||||
|
|
|
@ -307,7 +307,7 @@ AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value *v) {
|
|||
if (inserted) {
|
||||
reorderedDims.push_back(v);
|
||||
}
|
||||
return getAffineDimExpr(iterPos->second, v->getFunction().getContext())
|
||||
return getAffineDimExpr(iterPos->second, v->getContext())
|
||||
.cast<AffineDimExpr>();
|
||||
}
|
||||
|
||||
|
|
|
@ -442,13 +442,13 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
|
|||
|
||||
// Insert the `malloc` declaration if it is not already present.
|
||||
Function mallocFunc =
|
||||
op->getFunction().getModule().getNamedFunction("malloc");
|
||||
op->getParentOfType<FuncOp>().getModule().getNamedFunction("malloc");
|
||||
if (!mallocFunc) {
|
||||
auto mallocType =
|
||||
rewriter.getFunctionType(getIndexType(), getVoidPtrType());
|
||||
mallocFunc =
|
||||
Function::create(rewriter.getUnknownLoc(), "malloc", mallocType);
|
||||
op->getFunction().getModule().push_back(mallocFunc);
|
||||
op->getParentOfType<FuncOp>().getModule().push_back(mallocFunc);
|
||||
}
|
||||
|
||||
// Allocate the underlying buffer and store a pointer to it in the MemRef
|
||||
|
@ -503,11 +503,12 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
|
|||
OperandAdaptor<DeallocOp> transformed(operands);
|
||||
|
||||
// Insert the `free` declaration if it is not already present.
|
||||
Function freeFunc = op->getFunction().getModule().getNamedFunction("free");
|
||||
Function freeFunc =
|
||||
op->getParentOfType<FuncOp>().getModule().getNamedFunction("free");
|
||||
if (!freeFunc) {
|
||||
auto freeType = rewriter.getFunctionType(getVoidPtrType(), {});
|
||||
freeFunc = Function::create(rewriter.getUnknownLoc(), "free", freeType);
|
||||
op->getFunction().getModule().push_back(freeFunc);
|
||||
op->getParentOfType<FuncOp>().getModule().push_back(freeFunc);
|
||||
}
|
||||
|
||||
auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
|
||||
|
|
|
@ -98,6 +98,6 @@ mlir_attr_t makeBoolAttr(mlir_context_t context, bool value) {
|
|||
}
|
||||
|
||||
unsigned getFunctionArity(mlir_func_t function) {
|
||||
auto *f = reinterpret_cast<mlir::Function *>(function);
|
||||
return f->getNumArguments();
|
||||
auto f = mlir::Function::getFromOpaquePointer(function);
|
||||
return f.getNumArguments();
|
||||
}
|
||||
|
|
|
@ -426,7 +426,7 @@ LogicalResult LaunchFuncOp::verify() {
|
|||
return emitOpError("attribute 'kernel' must be a function");
|
||||
}
|
||||
|
||||
auto module = getOperation()->getFunction().getModule();
|
||||
auto module = getParentOfType<ModuleOp>();
|
||||
Function kernelFunc = module.getNamedFunction(kernel());
|
||||
if (!kernelFunc)
|
||||
return emitError() << "kernel function '" << kernelAttr << "' is undefined";
|
||||
|
|
|
@ -64,7 +64,7 @@ static Function outlineKernelFunc(gpu::LaunchOp launchOp) {
|
|||
FunctionType type =
|
||||
FunctionType::get(kernelOperandTypes, {}, launchOp.getContext());
|
||||
std::string kernelFuncName =
|
||||
Twine(launchOp.getOperation()->getFunction().getName(), "_kernel").str();
|
||||
Twine(launchOp.getParentOfType<FuncOp>().getName(), "_kernel").str();
|
||||
Function outlinedFunc = Function::create(loc, kernelFuncName, type);
|
||||
outlinedFunc.getBody().takeBody(launchOp.getBody());
|
||||
Builder builder(launchOp.getContext());
|
||||
|
|
|
@ -1421,8 +1421,8 @@ void OperationPrinter::print(Block *block, bool printBlockArgs,
|
|||
os << ':';
|
||||
|
||||
// Print out some context information about the predecessors of this block.
|
||||
if (!block->getFunction()) {
|
||||
os << "\t// block is not in a function!";
|
||||
if (!block->getParent()) {
|
||||
os << "\t// block is not in a region!";
|
||||
} else if (block->hasNoPredecessors()) {
|
||||
os << "\t// no predecessors";
|
||||
} else if (auto *pred = block->getSinglePredecessor()) {
|
||||
|
|
|
@ -50,13 +50,8 @@ Operation *Block::getContainingOp() {
|
|||
return getParent() ? getParent()->getContainingOp() : nullptr;
|
||||
}
|
||||
|
||||
Function Block::getFunction() {
|
||||
auto *parent = getParent();
|
||||
return parent ? parent->getParentOfType<FuncOp>() : nullptr;
|
||||
}
|
||||
|
||||
/// Insert this block (which must not already be in a function) right before
|
||||
/// the specified block.
|
||||
/// Insert this block (which must not already be in a region) right before the
|
||||
/// specified block.
|
||||
void Block::insertBefore(Block *block) {
|
||||
assert(!getParent() && "already inserted into a block!");
|
||||
assert(block->getParent() && "cannot insert before a block without a parent");
|
||||
|
@ -254,11 +249,11 @@ void Block::walk(Block::iterator begin, Block::iterator end,
|
|||
/// invalidated.
|
||||
Block *Block::splitBlock(iterator splitBefore) {
|
||||
// Start by creating a new basic block, and insert it immediate after this
|
||||
// one in the containing function.
|
||||
// one in the containing region.
|
||||
auto newBB = new Block();
|
||||
getParent()->getBlocks().insert(std::next(Region::iterator(this)), newBB);
|
||||
|
||||
// Move all of the operations from the split point to the end of the function
|
||||
// Move all of the operations from the split point to the end of the region
|
||||
// into the new block.
|
||||
newBB->getOperations().splice(newBB->end(), getOperations(), splitBefore,
|
||||
end());
|
||||
|
|
|
@ -281,10 +281,6 @@ Operation *Operation::getParentOp() {
|
|||
return block ? block->getContainingOp() : nullptr;
|
||||
}
|
||||
|
||||
Function Operation::getFunction() {
|
||||
return block ? block->getFunction() : nullptr;
|
||||
}
|
||||
|
||||
/// Replace any uses of 'from' with 'to' within this operation.
|
||||
void Operation::replaceUsesOfWith(Value *from, Value *to) {
|
||||
if (from == to)
|
||||
|
|
|
@ -29,21 +29,9 @@ Operation *Value::getDefiningOp() {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
/// Return the function that this Value is defined in.
|
||||
Function Value::getFunction() {
|
||||
switch (getKind()) {
|
||||
case Value::Kind::BlockArgument:
|
||||
return cast<BlockArgument>(this)->getFunction();
|
||||
case Value::Kind::OpResult:
|
||||
return getDefiningOp()->getFunction();
|
||||
}
|
||||
llvm_unreachable("Unknown Value Kind");
|
||||
}
|
||||
|
||||
Location Value::getLoc() {
|
||||
if (auto *op = getDefiningOp()) {
|
||||
if (auto *op = getDefiningOp())
|
||||
return op->getLoc();
|
||||
}
|
||||
return UnknownLoc::get(getContext());
|
||||
}
|
||||
|
||||
|
@ -78,20 +66,3 @@ void IRObjectWithUseList::dropAllUses() {
|
|||
use_begin()->drop();
|
||||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BlockArgument implementation.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Return the function that this argument is defined in.
|
||||
Function BlockArgument::getFunction() {
|
||||
if (auto *owner = getOwner())
|
||||
return owner->getFunction();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// Returns if the current argument is a function argument.
|
||||
bool BlockArgument::isFunctionArgument() {
|
||||
auto containingFn = getFunction();
|
||||
return containingFn && &containingFn.front() == getOwner();
|
||||
}
|
||||
|
|
|
@ -170,7 +170,7 @@ public:
|
|||
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
|
||||
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
|
||||
// Insert the `malloc` declaration if it is not already present.
|
||||
auto module = op->getFunction().getModule();
|
||||
auto module = op->getParentOfType<ModuleOp>();
|
||||
Function mallocFunc = module.getNamedFunction("malloc");
|
||||
if (!mallocFunc) {
|
||||
auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy);
|
||||
|
@ -231,7 +231,7 @@ public:
|
|||
auto voidPtrTy =
|
||||
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
|
||||
// Insert the `free` declaration if it is not already present.
|
||||
auto module = op->getFunction().getModule();
|
||||
auto module = op->getParentOfType<ModuleOp>();
|
||||
Function freeFunc = module.getNamedFunction("free");
|
||||
if (!freeFunc) {
|
||||
auto freeType = rewriter.getFunctionType(voidPtrTy, {});
|
||||
|
@ -602,7 +602,7 @@ static Function getLLVMLibraryCallDeclaration(Operation *op,
|
|||
PatternRewriter &rewriter) {
|
||||
assert(isa<LinalgOp>(op));
|
||||
auto fnName = LinalgOp::getLibraryCallName();
|
||||
auto module = op->getFunction().getModule();
|
||||
auto module = op->getParentOfType<ModuleOp>();
|
||||
if (auto f = module.getNamedFunction(fnName)) {
|
||||
return f;
|
||||
}
|
||||
|
|
|
@ -431,8 +431,7 @@ static LogicalResult verify(CallOp op) {
|
|||
auto fnAttr = op.getAttrOfType<FunctionAttr>("callee");
|
||||
if (!fnAttr)
|
||||
return op.emitOpError("requires a 'callee' function attribute");
|
||||
auto fn = op.getOperation()->getFunction().getModule().getNamedFunction(
|
||||
fnAttr.getValue());
|
||||
auto fn = op.getParentOfType<ModuleOp>().getNamedFunction(fnAttr.getValue());
|
||||
if (!fn)
|
||||
return op.emitOpError() << "'" << fnAttr.getValue()
|
||||
<< "' does not reference a valid function";
|
||||
|
@ -1098,8 +1097,8 @@ static LogicalResult verify(ConstantOp &op) {
|
|||
return op.emitOpError("requires 'value' to be a function reference");
|
||||
|
||||
// Try to find the referenced function.
|
||||
auto fn = op.getOperation()->getFunction().getModule().getNamedFunction(
|
||||
fnAttr.getValue());
|
||||
auto fn =
|
||||
op.getParentOfType<ModuleOp>().getNamedFunction(fnAttr.getValue());
|
||||
if (!fn)
|
||||
return op.emitOpError("reference to undefined function 'bar'");
|
||||
|
||||
|
@ -2029,7 +2028,9 @@ static void print(OpAsmPrinter *p, ReturnOp op) {
|
|||
}
|
||||
|
||||
static LogicalResult verify(ReturnOp op) {
|
||||
auto function = op.getOperation()->getFunction();
|
||||
// TODO(b/137008268): Return op should verify that it is nested directly
|
||||
// within a function operation.
|
||||
auto function = op.getParentOfType<FuncOp>();
|
||||
|
||||
// The operand number and types must match the function signature.
|
||||
const auto &results = function.getType().getResults();
|
||||
|
|
|
@ -217,8 +217,7 @@ static bool getFullMemRefAsRegion(Operation *opInst, unsigned numParamLoopIVs,
|
|||
|
||||
static InFlightDiagnostic LLVM_ATTRIBUTE_UNUSED
|
||||
emitRemarkForBlock(Block &block) {
|
||||
auto *op = block.getContainingOp();
|
||||
return op ? op->emitRemark() : block.getFunction().emitRemark();
|
||||
return block.getContainingOp()->emitRemark();
|
||||
}
|
||||
|
||||
/// Creates a buffer in the faster memory space for the specified region;
|
||||
|
@ -250,7 +249,7 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, Block *block,
|
|||
OpBuilder &b = region.isWrite() ? epilogue : prologue;
|
||||
|
||||
// Builder to create constants at the top level.
|
||||
auto func = block->getFunction();
|
||||
auto func = block->getParent()->getParentOfType<FuncOp>();
|
||||
OpBuilder top(func.getBody());
|
||||
|
||||
auto loc = region.loc;
|
||||
|
@ -765,10 +764,7 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) {
|
|||
if (totalDmaBuffersSizeInBytes > fastMemCapacityBytes) {
|
||||
StringRef str = "Total size of all DMA buffers' for this block "
|
||||
"exceeds fast memory capacity\n";
|
||||
if (auto *op = block->getContainingOp())
|
||||
op->emitError(str);
|
||||
else
|
||||
block->getFunction().emitError(str);
|
||||
block->getContainingOp()->emitError(str);
|
||||
}
|
||||
|
||||
return totalDmaBuffersSizeInBytes;
|
||||
|
|
|
@ -859,7 +859,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
|
|||
// Create builder to insert alloc op just before 'forOp'.
|
||||
OpBuilder b(forInst);
|
||||
// Builder to create constants at the top level.
|
||||
OpBuilder top(forInst->getFunction().getBody());
|
||||
OpBuilder top(forInst->getParentOfType<FuncOp>().getBody());
|
||||
// Create new memref type based on slice bounds.
|
||||
auto *oldMemRef = cast<AffineStoreOp>(srcStoreOpInst).getMemRef();
|
||||
auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
|
||||
|
@ -1750,7 +1750,7 @@ public:
|
|||
};
|
||||
|
||||
// Search for siblings which load the same memref function argument.
|
||||
auto fn = dstNode->op->getFunction();
|
||||
auto fn = dstNode->op->getParentOfType<FuncOp>();
|
||||
for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) {
|
||||
for (auto *user : fn.getArgument(i)->getUsers()) {
|
||||
if (auto loadOp = dyn_cast<AffineLoadOp>(user)) {
|
||||
|
|
|
@ -635,8 +635,8 @@ static bool emitSlice(MaterializationState *state,
|
|||
}
|
||||
}
|
||||
|
||||
LLVM_DEBUG(dbgs() << "\nMLFunction is now\n");
|
||||
LLVM_DEBUG((*slice)[0]->getFunction().print(dbgs()));
|
||||
LLVM_DEBUG(dbgs() << "\nFunction is now\n");
|
||||
LLVM_DEBUG((*slice)[0]->getParentOfType<FuncOp>().print(dbgs()));
|
||||
|
||||
// slice are topologically sorted, we can just erase them in reverse
|
||||
// order. Reverse iterator does not just work simply with an operator*
|
||||
|
|
|
@ -125,7 +125,7 @@ LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) {
|
|||
Operation *op = forOp.getOperation();
|
||||
if (!iv->use_empty()) {
|
||||
if (forOp.hasConstantLowerBound()) {
|
||||
OpBuilder topBuilder(op->getFunction().getBody());
|
||||
OpBuilder topBuilder(op->getParentOfType<FuncOp>().getBody());
|
||||
auto constOp = topBuilder.create<ConstantIndexOp>(
|
||||
forOp.getLoc(), forOp.getConstantLowerBound());
|
||||
iv->replaceAllUsesWith(constOp);
|
||||
|
|
|
@ -81,11 +81,12 @@ bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
|
|||
std::unique_ptr<DominanceInfo> domInfo;
|
||||
std::unique_ptr<PostDominanceInfo> postDomInfo;
|
||||
if (domInstFilter)
|
||||
domInfo = llvm::make_unique<DominanceInfo>(domInstFilter->getFunction());
|
||||
domInfo = llvm::make_unique<DominanceInfo>(
|
||||
domInstFilter->getParentOfType<FuncOp>());
|
||||
|
||||
if (postDomInstFilter)
|
||||
postDomInfo =
|
||||
llvm::make_unique<PostDominanceInfo>(postDomInstFilter->getFunction());
|
||||
postDomInfo = llvm::make_unique<PostDominanceInfo>(
|
||||
postDomInstFilter->getParentOfType<FuncOp>());
|
||||
|
||||
// The ops where memref replacement succeeds are replaced with new ones.
|
||||
SmallVector<Operation *, 8> opsToErase;
|
||||
|
|
Loading…
Reference in New Issue